diff --git a/.compatibility b/.compatibility
index c8ac4083d2a2b2985a02b6ec281a0e33ab4f23b2..32da32be5521539f90b6198e72089e14f8d0bd8e 100644
--- a/.compatibility
+++ b/.compatibility
@@ -1,3 +1,3 @@
1.12.0-11.3.0
-1.11.0-11.3.0
-1.10.1-11.3.0
+1.13.0-11.6.0
+2.0.0-11.7.0
diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 0000000000000000000000000000000000000000..b065e6eb9b772cc3a0253622c11986ee8e5613eb
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,4 @@
+[run]
+concurrency = multiprocessing
+parallel = true
+sigterm = true
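
The new options only pay off if the per-process data files are merged before reporting. A minimal sketch of that flow, assuming a plain local `coverage` run rather than the exact CI invocation:

```python
import subprocess

# Hypothetical local reproduction of the CI flow: with parallel = true each
# process writes its own ".coverage.<host>.<pid>.<random>" data file, and
# sigterm = true lets worker processes flush data even when terminated.
subprocess.run(["coverage", "run", "-m", "pytest", "tests/"], check=False)

# Merge the per-process data files into one .coverage database, then emit
# the coverage.xml that report_test_coverage.yml turns into a PR comment.
subprocess.run(["coverage", "combine"], check=True)
subprocess.run(["coverage", "xml"], check=True)
```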
diff --git a/.flake8 b/.flake8
deleted file mode 100644
index 229856aa4366c092e94eb573dbbd67756974a3ae..0000000000000000000000000000000000000000
--- a/.flake8
+++ /dev/null
@@ -1,22 +0,0 @@
-[flake8]
-ignore =
- ;W503 line break before binary operator
- W503,
- ;E203 whitespace before ':'
- E203,
-
-; exclude file
-exclude =
- .tox,
- .git,
- __pycache__,
- build,
- dist,
- *.pyc,
- *.egg-info,
- .cache,
- .eggs
-
-max-line-length = 120
-
-per-file-ignores = __init__.py:F401
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
index b310fcfefc154e3c8b3e941a363181743c95b5ab..436bdf887c69326dba7b758db008cc1b1e1f6551 100644
--- a/.github/ISSUE_TEMPLATE/config.yml
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -1,7 +1,7 @@
blank_issues_enabled: true
contact_links:
- name: ❓ Simple question - Slack Chat
- url: https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w
+ url: https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack
about: This issue tracker is not for technical support. Please use our Slack chat, and ask the community for help.
- name: ❓ Simple question - WeChat
url: https://github.com/hpcaitech/ColossalAI/blob/main/docs/images/WeChat.png
diff --git a/.github/workflows/README.md b/.github/workflows/README.md
index a46d8b1c24d05804320acf09eff25fd2fcab3fa9..3fad7e36f14c6473d403197534bbf26a93db8095 100644
--- a/.github/workflows/README.md
+++ b/.github/workflows/README.md
@@ -14,7 +14,7 @@
- [Compatibility Test on Dispatch](#compatibility-test-on-dispatch)
- [Release](#release)
- [User Friendliness](#user-friendliness)
- - [Commmunity](#commmunity)
+ - [Community](#community)
- [Configuration](#configuration)
- [Progress Log](#progress-log)
@@ -43,10 +43,18 @@ I will provide the details of each workflow below.
| Workflow Name | File name | Description |
| ---------------------- | -------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `Build on PR` | `build_on_pr.yml` | This workflow is triggered when the label `Run build and Test` is assigned to a PR. It will run all the unit tests in the repository with 4 GPUs. |
+| `Build on PR`          | `build_on_pr.yml`          | This workflow is triggered when a PR changes essential files, or when a branch is created or deleted. It will run all the unit tests in the repository with 4 GPUs. |
| `Build on Schedule`    | `build_on_schedule.yml`    | This workflow will run the unit tests every day with 8 GPUs. The result is sent to Lark.                                                           |
| `Report test coverage` | `report_test_coverage.yml` | This workflow will put up a comment to report the test coverage results when `Build` is done.                                                      |
+To reduce the average unit test time on PRs, the `Build on PR` workflow manages a testmon cache.
+
+1. When creating a new branch, it copies `cache/main/.testmondata*` to `cache/<branch>/`.
+2. When creating a new PR or changing the base branch of a PR, it copies `cache/<base>/.testmondata*` to `cache/_pull/<pr_number>/`.
+3. When running unit tests for each PR, it restores the testmon cache from `cache/_pull/<pr_number>/`. After the test, it stores the cache back to `cache/_pull/<pr_number>/`.
+4. When a PR is closed, if it's merged, it copies `cache/_pull/<pr_number>/.testmondata*` to `cache/<base>/`. Otherwise, it just removes `cache/_pull/<pr_number>`.
+5. When a branch is deleted, it removes `cache/<branch>/`.
+
### Example Test
| Workflow Name | File name | Description |
@@ -97,7 +105,7 @@ This workflow is triggered by manually dispatching the workflow. It has the foll
| `Synchronize submodule` | `submodule.yml` | This workflow will check if any git submodule is updated. If so, it will create a PR to update the submodule pointers. |
| `Close inactive issues` | `close_inactive.yml` | This workflow will close issues which are stale for 14 days. |
-### Commmunity
+### Community
| Workflow Name | File name | Description |
| -------------------------------------------- | -------------------------------- | -------------------------------------------------------------------------------- |
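
For readers skimming the workflow changes that follow, here is a compact restatement of the cache lifecycle documented above. It is only an illustration: `<branch>`, `<base>`, and `<pr_number>` stand in for the values the jobs compute at runtime.

```python
# Sketch of the testmon cache lifecycle described in the README above.
# The root path mirrors the jobs added to build_on_pr.yml below.
CACHE = "/github/home/testmon_cache"

LIFECYCLE = {
    "branch created":           f"cp {CACHE}/main/.testmondata* {CACHE}/<branch>/",
    "PR opened / base changed": f"cp {CACHE}/<base>/.testmondata* {CACHE}/_pull/<pr_number>/",
    "PR tested":                f"restore from, then store back to, {CACHE}/_pull/<pr_number>/",
    "PR merged":                f"cp {CACHE}/_pull/<pr_number>/.testmondata* {CACHE}/<base>/",
    "PR closed":                f"rm -rf {CACHE}/_pull/<pr_number>",
    "branch deleted":           f"rm -rf {CACHE}/<branch>",
}

for event, action in LIFECYCLE.items():
    print(f"{event:>26}: {action}")
```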
diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml
index e6febeeb4d87256efc3219b18b9b53875e54aa21..e2114d43bcd0a0da0f36c375d93490f20b9a465f 100644
--- a/.github/workflows/build_on_pr.yml
+++ b/.github/workflows/build_on_pr.yml
@@ -2,22 +2,93 @@ name: Build on PR
on:
pull_request:
- types: [synchronize, labeled]
+ types: [synchronize, opened, reopened, ready_for_review, closed, edited]
+ branches:
+ - "main"
+ - "develop"
+ - "feature/**"
+ paths:
+ - ".github/workflows/build_on_pr.yml" # run command & env variables change
+ - "colossalai/**" # source code change
+ - "!colossalai/**.md" # ignore doc change
+ - "op_builder/**" # cuda extension change
+ - "!op_builder/**.md" # ignore doc change
+ - "requirements/**" # requirements change
+ - "tests/**" # test change
+ - "!tests/**.md" # ignore doc change
+ - "pytest.ini" # test config change
+ - "setup.py" # install command change
+ create:
+ delete:
jobs:
+ prepare_cache:
+ name: Prepare testmon cache
+ if: |
+ github.event_name == 'create' &&
+ github.event.ref_type == 'branch' &&
+ github.event.repository.full_name == 'hpcaitech/ColossalAI'
+ runs-on: [self-hosted, gpu]
+ container:
+ image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
+ options: --rm
+ timeout-minutes: 5
+ defaults:
+ run:
+ shell: bash
+ steps:
+ - name: Copy testmon cache
+        run: | # branch names may contain a slash, so replace it with a space
+ export REF_BRANCH=$(echo ${{ github.event.ref }} | sed "s/\// /")
+ if [ -d /github/home/testmon_cache/${MAIN_BRANCH} ]; then
+ cp -p -r /github/home/testmon_cache/${MAIN_BRANCH} "/github/home/testmon_cache/${REF_BRANCH}"
+ fi
+ env:
+ MAIN_BRANCH: ${{ github.event.master_branch }}
+
+ prepare_cache_for_pr:
+ name: Prepare testmon cache for PR
+ if: |
+ github.event_name == 'pull_request' &&
+ (github.event.action == 'opened' || github.event.action == 'reopened' || (github.event.action == 'edited' && github.event.changes.base != null)) &&
+ github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
+ runs-on: [self-hosted, gpu]
+ container:
+ image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
+ options: --rm
+ timeout-minutes: 5
+ defaults:
+ run:
+ shell: bash
+ concurrency:
+      group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-prepare-cache
+ cancel-in-progress: true
+ steps:
+ - name: Copy testmon cache
+        run: | # branch names may contain a slash, so replace it with a space
+ export BASE=$(echo ${{ github.event.pull_request.base.ref }} | sed "s/\// /")
+ if [ -d "/github/home/testmon_cache/${BASE}" ] && [ ! -z "$(ls -A "/github/home/testmon_cache/${BASE}")" ]; then
+ mkdir -p /github/home/testmon_cache/_pull/${PR_NUMBER} && cp -p -r "/github/home/testmon_cache/${BASE}"/.testmondata* /github/home/testmon_cache/_pull/${PR_NUMBER}
+ fi
+ env:
+ PR_NUMBER: ${{ github.event.number }}
+
detect:
name: Detect file change
if: |
+ github.event_name == 'pull_request' &&
+ (github.event.action == 'synchronize' || github.event.action == 'opened' || github.event.action == 'reopened' || github.event.action == 'ready_for_review') &&
github.event.pull_request.draft == false &&
- github.base_ref == 'main' &&
- github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' &&
- contains( github.event.pull_request.labels.*.name, 'Run Build and Test')
+ github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
outputs:
changedExtenisonFiles: ${{ steps.find-extension-change.outputs.all_changed_files }}
anyExtensionFileChanged: ${{ steps.find-extension-change.outputs.any_changed }}
changedLibraryFiles: ${{ steps.find-lib-change.outputs.all_changed_files }}
anyLibraryFileChanged: ${{ steps.find-lib-change.outputs.any_changed }}
runs-on: ubuntu-latest
+ concurrency:
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-detect-change
+ cancel-in-progress: true
steps:
- uses: actions/checkout@v2
with:
@@ -66,14 +137,18 @@ jobs:
build:
name: Build and Test Colossal-AI
needs: detect
+ if: needs.detect.outputs.anyLibraryFileChanged == 'true'
runs-on: [self-hosted, gpu]
container:
- image: hpcaitech/pytorch-cuda:1.11.0-11.3.0
- options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
- timeout-minutes: 40
+ image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
+ options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
+ timeout-minutes: 60
defaults:
run:
shell: bash
+ concurrency:
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-test
+ cancel-in-progress: true
steps:
- name: Checkout TensorNVMe
uses: actions/checkout@v2
@@ -84,7 +159,9 @@ jobs:
- name: Restore TensorNVMe Cache
run: |
- [ ! -z "$(ls -A /github/home/tensornvme_cache/)" ] && cp -p -r /github/home/tensornvme_cache/* /__w/ColossalAI/ColossalAI/TensorNVMe
+ if [ -d /github/home/tensornvme_cache ] && [ ! -z "$(ls -A /github/home/tensornvme_cache/)" ]; then
+ cp -p -r /github/home/tensornvme_cache/* /__w/ColossalAI/ColossalAI/TensorNVMe
+ fi
- name: Install TensorNVMe
run: |
@@ -107,10 +184,11 @@ jobs:
if: needs.detect.outputs.anyExtensionFileChanged != 'true'
run: |
# -p flag is required to preserve the file timestamp to avoid ninja rebuild
- [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -p -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
+ if [ -d /github/home/cuda_ext_cache ] && [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ]; then
+ cp -p -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
+ fi
- name: Install Colossal-AI
- if: needs.detect.outputs.anyLibraryFileChanged == 'true'
run: |
CUDA_EXT=1 pip install -v -e .
pip install -r requirements/requirements-test.txt
@@ -120,14 +198,30 @@ jobs:
# -p flag is required to preserve the file timestamp to avoid ninja rebuild
cp -p -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
+ - name: Restore Testmon Cache
+ run: |
+ if [ -d /github/home/testmon_cache/_pull/${PR_NUMBER} ] && [ ! -z "$(ls -A /github/home/testmon_cache/_pull/${PR_NUMBER})" ]; then
+ cp -p -r /github/home/testmon_cache/_pull/${PR_NUMBER}/.testmondata* /__w/ColossalAI/ColossalAI/
+ fi
+ env:
+ PR_NUMBER: ${{ github.event.number }}
+
- name: Execute Unit Testing
- if: needs.detect.outputs.anyLibraryFileChanged == 'true'
run: |
- CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests/
+ CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-forceselect --testmon-cov=. --durations=10 tests/
env:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+ TESTMON_CORE_PKGS: /__w/ColossalAI/ColossalAI/requirements/requirements.txt,/__w/ColossalAI/ColossalAI/requirements/requirements-test.txt
+ LLAMA_PATH: /data/scratch/llama-tiny
+
+ - name: Store Testmon Cache
+ run: |
+ mkdir -p /github/home/testmon_cache/_pull/${PR_NUMBER}
+ cp -p -r /__w/ColossalAI/ColossalAI/.testmondata* /github/home/testmon_cache/_pull/${PR_NUMBER}/
+ env:
+ PR_NUMBER: ${{ github.event.number }}
- name: Collate artifact
env:
@@ -140,7 +234,7 @@ jobs:
echo $PR_NUMBER > ./report/pr_number
# generate coverage.xml if any
- if [ "$anyLibraryFileChanged" == "true" ]; then
+ if [ "$anyLibraryFileChanged" == "true" ] && [ -e .coverage ]; then
allFiles=""
for file in $changedLibraryFiles; do
if [ "$allFiles" == "" ]; then
@@ -165,3 +259,54 @@ jobs:
with:
name: report
path: report/
+
+ store_cache:
+ name: Store testmon cache for PR
+ if: |
+ github.event_name == 'pull_request' &&
+ github.event.action == 'closed' &&
+ github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
+ runs-on: [self-hosted, gpu]
+ container:
+ image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
+ options: --rm
+ timeout-minutes: 5
+ defaults:
+ run:
+ shell: bash
+ steps:
+ - name: Store testmon cache if possible
+ if: github.event.pull_request.merged == true
+        run: | # branch names may contain a slash, so replace it with a space
+ export BASE=$(echo ${{ github.event.pull_request.base.ref }} | sed "s/\// /")
+ if [ -d /github/home/testmon_cache/_pull/${PR_NUMBER} ] && [ ! -z "$(ls -A /github/home/testmon_cache/_pull/${PR_NUMBER})" ]; then
+ cp -p -r /github/home/testmon_cache/_pull/${PR_NUMBER}/.testmondata* "/github/home/testmon_cache/${BASE}/"
+ fi
+ env:
+ PR_NUMBER: ${{ github.event.pull_request.number }}
+
+ - name: Remove testmon cache
+ run: |
+ rm -rf /github/home/testmon_cache/_pull/${PR_NUMBER}
+ env:
+ PR_NUMBER: ${{ github.event.pull_request.number }}
+
+ remove_cache:
+ name: Remove testmon cache
+ if: |
+ github.event_name == 'delete' &&
+ github.event.ref_type == 'branch' &&
+ github.event.repository.full_name == 'hpcaitech/ColossalAI'
+ runs-on: [self-hosted, gpu]
+ container:
+ image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
+ options: --rm
+ timeout-minutes: 5
+ defaults:
+ run:
+ shell: bash
+ steps:
+ - name: Remove testmon cache
+        run: | # branch names may contain a slash, so replace it with a space
+ export BASE=$(echo ${{ github.event.ref }} | sed "s/\// /")
+ rm -rf "/github/home/testmon_cache/${BASE}"
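
All four cache jobs in this file sanitize ref names the same way before using them as directory names. A one-line Python mirror of the `sed "s/\// /"` call, for illustration (note that without the `g` flag, sed replaces only the first slash, which suffices for `feature/**` branches):

```python
def cache_dir_name(ref: str) -> str:
    # Mirror of the workflows' `sed "s/\// /"`: only the FIRST slash is
    # replaced, so "feature/foo" maps to the flat name "feature foo".
    return ref.replace("/", " ", 1)


assert cache_dir_name("feature/foo") == "feature foo"
assert cache_dir_name("main") == "main"
```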
diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml
index 6afdf581e6ca42e078118ef238e3d4199de9aa8f..6c77377be34f993bf409a7332d8cadb367fdd61e 100644
--- a/.github/workflows/build_on_schedule.yml
+++ b/.github/workflows/build_on_schedule.yml
@@ -3,7 +3,7 @@ name: Build on Schedule
on:
schedule:
    # run at 00:00 UTC every day
-    - cron: '0 0 * * *'
+    - cron: "0 0 * * *"
workflow_dispatch:
jobs:
@@ -12,8 +12,8 @@ jobs:
if: github.repository == 'hpcaitech/ColossalAI'
runs-on: [self-hosted, 8-gpu]
container:
- image: hpcaitech/pytorch-cuda:1.11.0-11.3.0
- options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+ image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
+ options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
timeout-minutes: 40
steps:
- name: Check GPU Availability # ensure all GPUs have enough memory
@@ -60,10 +60,11 @@ jobs:
- name: Unit Testing
if: steps.check-avai.outputs.avai == 'true'
run: |
- PYTHONPATH=$PWD pytest tests
+ PYTHONPATH=$PWD pytest --durations=0 tests
env:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+ LLAMA_PATH: /data/scratch/llama-tiny
- name: Notify Lark
id: message-preparation
diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml
index 717cf729b3f3fb13d29daa075129c6c653808eb4..5083212993cc40a70bbededf60919a483bdea7ab 100644
--- a/.github/workflows/compatiblity_test_on_dispatch.yml
+++ b/.github/workflows/compatiblity_test_on_dispatch.yml
@@ -19,38 +19,38 @@ jobs:
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- - id: set-matrix
- env:
- TORCH_VERSIONS: ${{ inputs.torch_version }}
- CUDA_VERSIONS: ${{ inputs.cuda_version }}
- run: |
- IFS=','
- DOCKER_IMAGE=()
+ - id: set-matrix
+ env:
+ TORCH_VERSIONS: ${{ inputs.torch_version }}
+ CUDA_VERSIONS: ${{ inputs.cuda_version }}
+ run: |
+ IFS=','
+ DOCKER_IMAGE=()
- for tv in $TORCH_VERSIONS
- do
- for cv in $CUDA_VERSIONS
- do
- DOCKER_IMAGE+=("\"hpcaitech/pytorch-cuda:${tv}-${cv}\"")
- done
- done
+ for tv in $TORCH_VERSIONS
+ do
+ for cv in $CUDA_VERSIONS
+ do
+ DOCKER_IMAGE+=("\"hpcaitech/pytorch-cuda:${tv}-${cv}\"")
+ done
+ done
- container=$( IFS=',' ; echo "${DOCKER_IMAGE[*]}" )
- container="[${container}]"
- echo "$container"
- echo "::set-output name=matrix::{\"container\":$(echo "$container")}"
+ container=$( IFS=',' ; echo "${DOCKER_IMAGE[*]}" )
+ container="[${container}]"
+ echo "$container"
+ echo "::set-output name=matrix::{\"container\":$(echo "$container")}"
build:
name: Test for PyTorch Compatibility
needs: matrix_preparation
if: github.repository == 'hpcaitech/ColossalAI'
- runs-on: [self-hosted, gpu]
+ runs-on: [self-hosted, 8-gpu]
strategy:
fail-fast: false
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
- options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+ options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
timeout-minutes: 120
steps:
- name: Install dependencies
@@ -64,16 +64,26 @@ jobs:
- name: Install tensornvme
run: |
cd TensorNVMe
- conda install cmake
+ apt update && apt install -y cmake
pip install -r requirements.txt
pip install -v .
- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
+ - name: Download cub for CUDA 10.2
+ run: |
+ CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')
+
+ # check if it is CUDA 10.2
+ # download cub
+ if [ "$CUDA_VERSION" = "10.2" ]; then
+ wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
+ unzip 1.8.0.zip
+ cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
+ fi
- name: Install Colossal-AI
run: |
- pip install -r requirements/requirements.txt
- pip install -v --no-cache-dir .
+ CUDA_EXT=1 pip install -v .
pip install -r requirements/requirements-test.txt
- name: Unit Testing
run: |
@@ -82,3 +92,4 @@ jobs:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+ LLAMA_PATH: /data/scratch/llama-tiny
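
The `set-matrix` step above expands the dispatch inputs into one container image per torch/CUDA pair. A Python sketch of the same computation, with made-up example inputs:

```python
import json


def build_matrix(torch_versions: str, cuda_versions: str) -> str:
    # Cross product of the comma-separated inputs, one docker image per
    # (torch, cuda) pair, matching the shell loop in the set-matrix step.
    images = [
        f"hpcaitech/pytorch-cuda:{tv}-{cv}"
        for tv in torch_versions.split(",")
        for cv in cuda_versions.split(",")
    ]
    return json.dumps({"container": images})


# Example inputs; the real values come from the workflow_dispatch form.
print(build_matrix("1.12.0,1.13.0", "11.3.0,11.6.0"))
```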
diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml
index 2fca67b820a1d7cf411f1ee9b854f93a150c639d..cc17c66f9c3afb57f08233bcbc6ca9e1c8e7df5f 100644
--- a/.github/workflows/compatiblity_test_on_pr.yml
+++ b/.github/workflows/compatiblity_test_on_pr.yml
@@ -3,8 +3,8 @@ name: Compatibility Test on PR
on:
pull_request:
paths:
- - 'version.txt'
- - '.compatibility'
+ - "version.txt"
+ - ".compatibility"
jobs:
matrix_preparation:
@@ -12,6 +12,9 @@ jobs:
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
+ concurrency:
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-prepare-matrix
+ cancel-in-progress: true
steps:
- uses: actions/checkout@v3
- id: set-matrix
@@ -32,14 +35,17 @@ jobs:
name: Test for PyTorch Compatibility
needs: matrix_preparation
if: github.repository == 'hpcaitech/ColossalAI'
- runs-on: [self-hosted, gpu]
+ runs-on: [self-hosted, 8-gpu]
strategy:
fail-fast: false
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
- options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+ options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
timeout-minutes: 120
+ concurrency:
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-test-${{ matrix.container }}
+ cancel-in-progress: true
steps:
- name: Install dependencies
run: |
@@ -52,15 +58,27 @@ jobs:
- name: Install tensornvme
run: |
cd TensorNVMe
- conda install cmake
+ apt update && apt install -y cmake
pip install -r requirements.txt
pip install -v .
- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
+ - name: Download cub for CUDA 10.2
+ run: |
+ CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')
+
+ # check if it is CUDA 10.2
+ # download cub
+ if [ "$CUDA_VERSION" = "10.2" ]; then
+ wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
+ unzip 1.8.0.zip
+ cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
+ fi
+
- name: Install Colossal-AI
run: |
- pip install -v --no-cache-dir .
+ CUDA_EXT=1 pip install -v .
pip install -r requirements/requirements-test.txt
- name: Unit Testing
run: |
@@ -69,3 +87,4 @@ jobs:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+ LLAMA_PATH: /data/scratch/llama-tiny
diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml
index 9802795fad246864db5c7a86d9f4273f2ce99ff9..158fe751bf2efd7d8449d7df694351128dbe8100 100644
--- a/.github/workflows/compatiblity_test_on_schedule.yml
+++ b/.github/workflows/compatiblity_test_on_schedule.yml
@@ -32,13 +32,13 @@ jobs:
name: Test for PyTorch Compatibility
needs: matrix_preparation
if: github.repository == 'hpcaitech/ColossalAI'
- runs-on: [self-hosted, gpu]
+ runs-on: [self-hosted, 8-gpu]
strategy:
fail-fast: false
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
- options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+ options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
timeout-minutes: 120
steps:
- name: Install dependencies
@@ -54,16 +54,28 @@ jobs:
- name: Install tensornvme
run: |
cd TensorNVMe
- conda install cmake
+ apt update && apt install -y cmake
pip install -r requirements.txt
pip install -v .
- uses: actions/checkout@v2
with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
+ - name: Download cub for CUDA 10.2
+ run: |
+ CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')
+
+ # check if it is CUDA 10.2
+ # download cub
+ if [ "$CUDA_VERSION" = "10.2" ]; then
+ wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
+ unzip 1.8.0.zip
+ cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
+ fi
+
- name: Install Colossal-AI
run: |
- pip install -v --no-cache-dir .
+ CUDA_EXT=1 pip install -v .
pip install -r requirements/requirements-test.txt
- name: Unit Testing
@@ -73,6 +85,7 @@ jobs:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+ LLAMA_PATH: /data/scratch/llama-tiny
- name: Notify Lark
id: message-preparation
diff --git a/.github/workflows/cuda_ext_check_before_merge.yml b/.github/workflows/cuda_ext_check_before_merge.yml
index eba5bb98ec07998314708d1b8a159de0d574fcf2..686f0f395c733f6bbab586048195c7fad8db785d 100644
--- a/.github/workflows/cuda_ext_check_before_merge.yml
+++ b/.github/workflows/cuda_ext_check_before_merge.yml
@@ -37,6 +37,18 @@ jobs:
- name: Install PyTorch
run: eval ${{ matrix.build.torch_command }}
+ - name: Download cub for CUDA 10.2
+ run: |
+ CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')
+
+ # check if it is CUDA 10.2
+ # download cub
+ if [ "$CUDA_VERSION" = "10.2" ]; then
+ wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
+ unzip 1.8.0.zip
+ cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
+ fi
+
- name: Build
run: |
CUDA_EXT=1 pip install -v .
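
The cub download step now repeated across these workflows hinges on the CUDA release parsed from `nvcc -V`; CUDA 10.2 predates the bundled cub headers, so they are vendored in. A rough Python equivalent of the version check:

```python
import re
import subprocess


def cuda_release() -> str:
    # Extract the "release X.Y" token from `nvcc -V`; the workflow's awk
    # one-liner (field 6 with FS=',| ') picks out the same substring.
    out = subprocess.run(["nvcc", "-V"], capture_output=True, text=True).stdout
    match = re.search(r"release (\d+\.\d+)", out)
    return match.group(1) if match else ""


if cuda_release() == "10.2":
    print("vendor cub 1.8.0 into colossalai/kernel/cuda_native/csrc/kernels/include/")
```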
diff --git a/.github/workflows/doc_build_after_merge.yml b/.github/workflows/doc_build_after_merge.yml
deleted file mode 100644
index ede04b336620870e3739458e68bfd31afb73ce8c..0000000000000000000000000000000000000000
--- a/.github/workflows/doc_build_after_merge.yml
+++ /dev/null
@@ -1,28 +0,0 @@
-name: Build Documentation After Merge
-
-on:
- workflow_dispatch:
- pull_request:
- paths:
- - 'version.txt'
- - 'docs/**'
- types:
- - closed
-
-jobs:
- build-doc:
- name: Trigger Documentation Build Workflow
- if: ( github.event_name == 'workflow_dispatch' || github.event.pull_request.merged == true ) && github.repository == 'hpcaitech/ColossalAI'
- runs-on: ubuntu-latest
- steps:
- - name: trigger workflow in ColossalAI-Documentation
- run: |
- curl \
- -X POST \
- -H "Accept: application/vnd.github+json" \
- -H "Authorization: Bearer ${GH_TOKEN}"\
- -H "X-GitHub-Api-Version: 2022-11-28" \
- https://api.github.com/repos/hpcaitech/ColossalAI-Documentation/actions/workflows/deploy.yml/dispatches \
- -d '{"ref":"main"}'
- env:
- GH_TOKEN: ${{secrets.DOC_REPO_TOKEN}}
diff --git a/.github/workflows/doc_build_on_schedule_after_release.yml b/.github/workflows/doc_build_on_schedule_after_release.yml
new file mode 100644
index 0000000000000000000000000000000000000000..62dfdc67257c7ac03015652fb0decaab6e04d588
--- /dev/null
+++ b/.github/workflows/doc_build_on_schedule_after_release.yml
@@ -0,0 +1,26 @@
+name: Build Documentation On Schedule & After Release
+
+on:
+ workflow_dispatch:
+ schedule:
+ - cron: "0 12 * * *" # build doc every day at 8pm Singapore time (12pm UTC time)
+ release:
+ types: [published]
+
+jobs:
+ build-doc:
+ name: Trigger Documentation Build Workflow
+ if: github.repository == 'hpcaitech/ColossalAI'
+ runs-on: ubuntu-latest
+ steps:
+ - name: trigger workflow in ColossalAI-Documentation
+ run: |
+ curl \
+ -X POST \
+ -H "Accept: application/vnd.github+json" \
+ -H "Authorization: Bearer ${GH_TOKEN}"\
+ -H "X-GitHub-Api-Version: 2022-11-28" \
+ https://api.github.com/repos/hpcaitech/ColossalAI-Documentation/actions/workflows/deploy.yml/dispatches \
+ -d '{"ref":"main"}'
+ env:
+ GH_TOKEN: ${{secrets.DOC_REPO_TOKEN}}
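
The trigger step is a plain `workflow_dispatch` call against `deploy.yml` in the Documentation repo. For reference, the same request expressed in Python; `GH_TOKEN` is assumed to carry the `DOC_REPO_TOKEN` secret:

```python
import os

import requests

# Python equivalent of the curl call above. GH_TOKEN must hold a token
# allowed to dispatch workflows in ColossalAI-Documentation.
url = (
    "https://api.github.com/repos/hpcaitech/ColossalAI-Documentation"
    "/actions/workflows/deploy.yml/dispatches"
)
headers = {
    "Accept": "application/vnd.github+json",
    "Authorization": f"Bearer {os.environ['GH_TOKEN']}",
    "X-GitHub-Api-Version": "2022-11-28",
}
response = requests.post(url, headers=headers, json={"ref": "main"})
response.raise_for_status()  # GitHub answers 204 when the dispatch is accepted
```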
diff --git a/.github/workflows/doc_check_on_pr.yml b/.github/workflows/doc_check_on_pr.yml
index 2022c957fba837904ba2a2efe678edd1686eece8..ee8a82128dd775a72b8dc71c9326a4c7c922a55a 100644
--- a/.github/workflows/doc_check_on_pr.yml
+++ b/.github/workflows/doc_check_on_pr.yml
@@ -2,57 +2,68 @@ name: Check Documentation on PR
on:
pull_request:
+ branches:
+ - "main"
+ - "develop"
+ - "feature/**"
paths:
- - 'docs/**'
+ - "docs/**"
jobs:
check-i18n:
name: Check docs in diff languages
if: |
- github.event.pull_request.draft == false &&
- github.base_ref == 'main' &&
- github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
+ github.event.pull_request.draft == false &&
+ github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
runs-on: ubuntu-latest
+ concurrency:
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-check-i18n
+ cancel-in-progress: true
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
- python-version: '3.8.14'
+ python-version: "3.8.14"
- run: python .github/workflows/scripts/check_doc_i18n.py -d docs/source
check-doc-build:
name: Test if the docs can be built
if: |
- github.event.pull_request.draft == false &&
- github.base_ref == 'main' &&
- github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
+ github.event.pull_request.draft == false &&
+ github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
runs-on: ubuntu-latest
+ concurrency:
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-check-doc
+ cancel-in-progress: true
steps:
- uses: actions/checkout@v2
with:
- path: './ColossalAI'
+ path: "./ColossalAI"
fetch-depth: 0
- uses: actions/checkout@v2
with:
- path: './ColossalAI-Documentation'
- repository: 'hpcaitech/ColossalAI-Documentation'
+ path: "./ColossalAI-Documentation"
+ repository: "hpcaitech/ColossalAI-Documentation"
- uses: actions/setup-python@v2
with:
- python-version: '3.8.14'
+ python-version: "3.8.14"
# we use the versions in the main branch as the guide for versions to display
      # checkout will give you the merged branch
      # therefore, we need to make the merged branch the main branch
+      # there is no local main branch, so it's safe to create main from the merged branch
+      # docer will rebase the remote main branch onto the merged branch, so we have to configure a git user
- name: Make the merged branch main
run: |
cd ColossalAI
- curBranch=$(git rev-parse --abbrev-ref HEAD)
- git checkout main
- git merge $curBranch # fast-forward master up to the merge
+ git checkout -b main
+ git branch -u origin/main
+ git config user.name 'github-actions'
+ git config user.email 'github-actions@github.com'
- name: Build docs
run: |
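
For clarity, the revised "Make the merged branch main" step boils down to the following sequence, sketched here with subprocess; the workflow runs the equivalent shell commands inside the checked-out `ColossalAI` directory:

```python
import subprocess

# The PR checkout has no local main branch, so one is created at the merged
# commit and set to track origin/main; a git identity is configured so that
# docer can rebase later in the build.
for cmd in (
    ["git", "checkout", "-b", "main"],
    ["git", "branch", "-u", "origin/main"],
    ["git", "config", "user.name", "github-actions"],
    ["git", "config", "user.email", "github-actions@github.com"],
):
    subprocess.run(cmd, check=True, cwd="ColossalAI")
```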
diff --git a/.github/workflows/doc_test_on_pr.yml b/.github/workflows/doc_test_on_pr.yml
index fbe669582c2088c17678a763ceea9bfe78738108..f1e7a2d0cab057573e04755795a7169b3718b8f2 100644
--- a/.github/workflows/doc_test_on_pr.yml
+++ b/.github/workflows/doc_test_on_pr.yml
@@ -1,21 +1,27 @@
name: Test Documentation on PR
on:
pull_request:
+ branches:
+ - "main"
+ - "develop"
+ - "feature/**"
# any change in the examples folder will trigger check for the corresponding example.
paths:
- - 'docs/source/**.md'
+ - "docs/source/**.md"
jobs:
  # This job detects changed doc files and outputs a matrix containing all the corresponding directory names.
detect-changed-doc:
if: |
- github.event.pull_request.draft == false &&
- github.base_ref == 'main' &&
- github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request'
+ github.event.pull_request.draft == false &&
+ github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request'
runs-on: ubuntu-latest
outputs:
any_changed: ${{ steps.changed-files.outputs.any_changed }}
changed_files: ${{ steps.changed-files.outputs.all_changed_files }}
+ concurrency:
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-detect-change
+ cancel-in-progress: true
name: Detect changed example files
steps:
- uses: actions/checkout@v3
@@ -26,10 +32,10 @@ jobs:
- name: Locate base commit
id: locate-base-sha
run: |
- curBranch=$(git rev-parse --abbrev-ref HEAD)
- commonCommit=$(git merge-base origin/main $curBranch)
- echo $commonCommit
- echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT
+ curBranch=$(git rev-parse --abbrev-ref HEAD)
+ commonCommit=$(git merge-base origin/main $curBranch)
+ echo $commonCommit
+ echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT
- name: Get all changed example files
id: changed-files
@@ -43,10 +49,9 @@ jobs:
check-changed-doc:
# Add this condition to avoid executing this job if the trigger event is workflow_dispatch.
if: |
- github.event.pull_request.draft == false &&
- github.base_ref == 'main' &&
- github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' &&
- needs.detect-changed-doc.outputs.any_changed == 'true'
+ github.event.pull_request.draft == false &&
+ github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' &&
+ needs.detect-changed-doc.outputs.any_changed == 'true'
name: Test the changed Doc
needs: detect-changed-doc
runs-on: [self-hosted, gpu]
@@ -57,12 +62,15 @@ jobs:
defaults:
run:
shell: bash
+ concurrency:
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-doctest
+ cancel-in-progress: true
steps:
- name: Checkout ColossalAI-Documentation
uses: actions/checkout@v2
with:
- path: './ColossalAI-Documentation'
- repository: 'hpcaitech/ColossalAI-Documentation'
+ path: "./ColossalAI-Documentation"
+ repository: "hpcaitech/ColossalAI-Documentation"
- name: Install Docer
run: |
@@ -81,12 +89,12 @@ jobs:
- name: Install ColossalAI
run: |
source activate pytorch
- pip install -v .
+ CUDA_EXT=1 pip install -v .
- name: Test the Doc
run: |
source activate pytorch
- for file in ${{ steps.changed-files.outputs.all_changed_files }}; do
+ for file in ${{ needs.detect-changed-doc.outputs.changed_files }}; do
echo "Testing $file now..."
docer test -p $file
done
diff --git a/.github/workflows/doc_test_on_schedule.yml b/.github/workflows/doc_test_on_schedule.yml
index 6b4f5d1f908c608a8ba266725c521369a9fb0047..027fbfd0aaeb315d1a013e247bd76a4d2d161ae5 100644
--- a/.github/workflows/doc_test_on_schedule.yml
+++ b/.github/workflows/doc_test_on_schedule.yml
@@ -32,7 +32,7 @@ jobs:
- name: Install ColossalAI
run: |
- pip install -v .
+ CUDA_EXT=1 pip install -v .
- name: Install Doc Test Requirements
run: |
diff --git a/.github/workflows/example_check_on_dispatch.yml b/.github/workflows/example_check_on_dispatch.yml
index 620d4771af55f53972888013c67347f3f4a392bd..9d3bd9a48235033ae07eb955888e5e3761bc4f3a 100644
--- a/.github/workflows/example_check_on_dispatch.yml
+++ b/.github/workflows/example_check_on_dispatch.yml
@@ -53,7 +53,7 @@ jobs:
uses: actions/checkout@v3
- name: Install Colossal-AI
run: |
- pip install -v .
+ CUDA_EXT=1 pip install -v .
- name: Test the example
run: |
dir=${{ matrix.directory }}
diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml
index b22664ee47ccbf8ea786b526964694933a51ab5f..5934704f4102e546d8140c04983e3db58e5da9f8 100644
--- a/.github/workflows/example_check_on_pr.yml
+++ b/.github/workflows/example_check_on_pr.yml
@@ -1,22 +1,28 @@
name: Test Example on PR
on:
pull_request:
+ branches:
+ - "main"
+ - "develop"
+ - "feature/**"
# any change in the examples folder will trigger check for the corresponding example.
paths:
- - 'examples/**'
+ - "examples/**"
jobs:
  # This job detects changed example files and outputs a matrix containing all the corresponding directory names.
detect-changed-example:
if: |
- github.event.pull_request.draft == false &&
- github.base_ref == 'main' &&
- github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request'
+ github.event.pull_request.draft == false &&
+ github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request'
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.setup-matrix.outputs.matrix }}
anyChanged: ${{ steps.setup-matrix.outputs.anyChanged }}
name: Detect changed example files
+ concurrency:
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-detect-change
+ cancel-in-progress: true
steps:
- uses: actions/checkout@v3
with:
@@ -26,10 +32,10 @@ jobs:
- name: Locate base commit
id: locate-base-sha
run: |
- curBranch=$(git rev-parse --abbrev-ref HEAD)
- commonCommit=$(git merge-base origin/main $curBranch)
- echo $commonCommit
- echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT
+ curBranch=$(git rev-parse --abbrev-ref HEAD)
+ commonCommit=$(git merge-base origin/main $curBranch)
+ echo $commonCommit
+ echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT
- name: Get all changed example files
id: changed-files
@@ -61,10 +67,9 @@ jobs:
check-changed-example:
# Add this condition to avoid executing this job if the trigger event is workflow_dispatch.
if: |
- github.event.pull_request.draft == false &&
- github.base_ref == 'main' &&
- github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' &&
- needs.detect-changed-example.outputs.anyChanged == 'true'
+ github.event.pull_request.draft == false &&
+ github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && github.event_name == 'pull_request' &&
+ needs.detect-changed-example.outputs.anyChanged == 'true'
name: Test the changed example
needs: detect-changed-example
runs-on: [self-hosted, gpu]
@@ -75,12 +80,15 @@ jobs:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
options: --gpus all --rm -v /data/scratch/examples-data:/data/
timeout-minutes: 10
+ concurrency:
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-example-${{ matrix.directory }}
+ cancel-in-progress: true
steps:
- uses: actions/checkout@v3
- name: Install Colossal-AI
run: |
- pip install -v .
+ CUDA_EXT=1 pip install -v .
- name: Test the example
run: |
diff --git a/.github/workflows/example_check_on_schedule.yml b/.github/workflows/example_check_on_schedule.yml
index bd52ca4321a2b9abd3182df53cf15b77e40766c7..5ed128c3ebc59b834d02e1af9faa832d91527fc4 100644
--- a/.github/workflows/example_check_on_schedule.yml
+++ b/.github/workflows/example_check_on_schedule.yml
@@ -42,7 +42,7 @@ jobs:
- name: Install Colossal-AI
run: |
- pip install -v .
+ CUDA_EXT=1 pip install -v .
- name: Traverse all files
run: |
diff --git a/.github/workflows/release_docker_after_merge.yml b/.github/workflows/release_docker_after_merge.yml
deleted file mode 100644
index 607c19b05472e024d08fdd60dcf95c5516a4420a..0000000000000000000000000000000000000000
--- a/.github/workflows/release_docker_after_merge.yml
+++ /dev/null
@@ -1,75 +0,0 @@
-name: Publish Docker Image to DockerHub after Merge
-
-on:
- workflow_dispatch:
- pull_request:
- paths:
- - 'version.txt'
- types:
- - closed
-
-jobs:
- release:
- name: Publish Docker Image to DockerHub
- if: ( github.event_name == 'workflow_dispatch' || github.event.pull_request.merged == true ) && github.repository == 'hpcaitech/ColossalAI'
- runs-on: [self-hosted, gpu]
- container:
- image: "hpcaitech/docker-in-docker:latest"
- options: --gpus all --rm -v /var/run/docker.sock:/var/run/docker.sock
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Build Docker
- id: build
- run: |
- version=$(cat version.txt)
- tag=hpcaitech/colossalai:$version
- docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 -t $tag ./docker
- echo "tag=${tag}" >> $GITHUB_OUTPUT
-
- - name: Log in to Docker Hub
- uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9
- with:
- username: ${{ secrets.DOCKER_USERNAME }}
- password: ${{ secrets.DOCKER_PASSWORD }}
-
- - name: Push Docker image
- id: docker-push
- run: |
- docker push ${{ steps.build.outputs.tag }}
-
- notify:
- name: Notify Lark via webhook
- needs: release
- runs-on: ubuntu-latest
- if: ${{ always() }}
- steps:
- - uses: actions/checkout@v2
-
- - uses: actions/setup-python@v2
- with:
- python-version: '3.8.14'
-
- - name: Install requests
- run: pip install requests
-
- - name: Notify Lark
- id: message-preparation
- run: |
- url=$SERVER_URL/$REPO/actions/runs/$RUN_ID
- if [ "$STATUS" == 'success' ]
- then
- msg="The Docker image for the latest release has been successfully built and pushed to DockerHub."
- else
- msg="Failed to build and push the Docker image for the latest release, please visit $url for details."
- fi
- echo $msg
- python .github/workflows/scripts/send_message_to_lark.py -m "$msg" -u $WEBHOOK_URL
- env:
- SERVER_URL: ${{github.server_url }}
- REPO: ${{ github.repository }}
- RUN_ID: ${{ github.run_id }}
- WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }}
- STATUS: ${{ needs.release.result }}
diff --git a/.github/workflows/release_docker_after_publish.yml b/.github/workflows/release_docker_after_publish.yml
new file mode 100644
index 0000000000000000000000000000000000000000..6c8df9730b0df2a4a5c6622d79c60c83da54879f
--- /dev/null
+++ b/.github/workflows/release_docker_after_publish.yml
@@ -0,0 +1,76 @@
+name: Publish Docker Image to DockerHub after Publish
+
+on:
+ workflow_dispatch:
+ release:
+ types: [published]
+
+jobs:
+ release:
+ name: Publish Docker Image to DockerHub
+ if: github.repository == 'hpcaitech/ColossalAI'
+ runs-on: [self-hosted, gpu]
+ container:
+ image: "hpcaitech/docker-in-docker:latest"
+ options: --gpus all --rm -v /var/run/docker.sock:/var/run/docker.sock
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+
+ - name: Build Docker
+ id: build
+ run: |
+ version=$(cat version.txt)
+ tag=hpcaitech/colossalai:$version
+ latest=hpcaitech/colossalai:latest
+ docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 --build-arg VERSION=v${version} -t $tag ./docker
+ docker tag $tag $latest
+ echo "tag=${tag}" >> $GITHUB_OUTPUT
+ echo "latest=${latest}" >> $GITHUB_OUTPUT
+
+ - name: Log in to Docker Hub
+ uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9
+ with:
+ username: ${{ secrets.DOCKER_USERNAME }}
+ password: ${{ secrets.DOCKER_PASSWORD }}
+
+ - name: Push Docker image
+ id: docker-push
+ run: |
+ docker push ${{ steps.build.outputs.tag }}
+ docker push ${{ steps.build.outputs.latest }}
+
+ notify:
+ name: Notify Lark via webhook
+ needs: release
+ runs-on: ubuntu-latest
+ if: ${{ always() }}
+ steps:
+ - uses: actions/checkout@v2
+
+ - uses: actions/setup-python@v2
+ with:
+ python-version: "3.8.14"
+
+ - name: Install requests
+ run: pip install requests
+
+ - name: Notify Lark
+ id: message-preparation
+ run: |
+ url=$SERVER_URL/$REPO/actions/runs/$RUN_ID
+ if [ "$STATUS" == 'success' ]
+ then
+ msg="The Docker image for the latest release has been successfully built and pushed to DockerHub."
+ else
+ msg="Failed to build and push the Docker image for the latest release, please visit $url for details."
+ fi
+ echo $msg
+ python .github/workflows/scripts/send_message_to_lark.py -m "$msg" -u $WEBHOOK_URL
+ env:
+ SERVER_URL: ${{github.server_url }}
+ REPO: ${{ github.repository }}
+ RUN_ID: ${{ github.run_id }}
+ WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }}
+ STATUS: ${{ needs.release.result }}
diff --git a/.github/workflows/report_test_coverage.yml b/.github/workflows/report_test_coverage.yml
index bbada74e685025d15b57f6a4cae855046f99e8ec..c9dc541b8a33fcd29bf6b75d9d70456444c8c421 100644
--- a/.github/workflows/report_test_coverage.yml
+++ b/.github/workflows/report_test_coverage.yml
@@ -9,8 +9,9 @@ on:
jobs:
report-test-coverage:
runs-on: ubuntu-latest
+ if: ${{ github.event.workflow_run.conclusion == 'success' }}
steps:
- - name: 'Download artifact'
+ - name: "Download artifact"
uses: actions/github-script@v6
with:
script: |
@@ -31,7 +32,7 @@ jobs:
let fs = require('fs');
fs.writeFileSync(`${process.env.GITHUB_WORKSPACE}/report.zip`, Buffer.from(download.data));
- - name: 'Unzip artifact'
+ - name: "Unzip artifact"
id: unzip
run: |
unzip report.zip
@@ -58,7 +59,7 @@ jobs:
echo "" >> coverage_report.txt
mv coverage_report.txt coverage.txt
- - name: 'Comment on PR'
+ - name: "Comment on PR"
if: steps.unzip.outputs.hasReport == 'true'
uses: actions/github-script@v6
with:
diff --git a/.github/workflows/run_chatgpt_examples.yml b/.github/workflows/run_chatgpt_examples.yml
index 9d9d3a007851276cbf5dbd69e5c8b776178e5de4..f9e9f400962ef0453a017d4f5cdefcbaf95ff356 100644
--- a/.github/workflows/run_chatgpt_examples.yml
+++ b/.github/workflows/run_chatgpt_examples.yml
@@ -4,11 +4,10 @@ on:
pull_request:
types: [synchronize, opened, reopened]
paths:
- - 'applications/Chat/coati/**'
- - 'applications/Chat/requirements.txt'
- - 'applications/Chat/setup.py'
- - 'applications/Chat/examples/**'
-
+ - "applications/Chat/coati/**"
+ - "applications/Chat/requirements.txt"
+ - "applications/Chat/setup.py"
+ - "applications/Chat/examples/**"
jobs:
tests:
@@ -20,7 +19,7 @@ jobs:
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
- options: --gpus all --rm -v /data/scratch/github_actions/chat:/data/scratch/github_actions/chat
+ options: --gpus all --rm -v /data/scratch/github_actions/chat:/data/scratch/github_actions/chat --shm-size=10.24gb
timeout-minutes: 30
defaults:
run:
@@ -29,28 +28,26 @@ jobs:
- name: Checkout ColossalAI
uses: actions/checkout@v2
- - name: Install ColossalAI and ChatGPT
+ - name: Install ChatGPT
run: |
- pip install -e .
cd applications/Chat
pip install -v .
pip install -r examples/requirements.txt
- name: Install Transformers
run: |
- cd applications/Chat
- git clone https://github.com/hpcaitech/transformers
- cd transformers
- pip install -v .
+ pip install transformers==4.30.2
- name: Execute Examples
run: |
cd applications/Chat
rm -rf ~/.cache/colossalai
- ./examples/test_ci.sh
+ ./tests/test_inference.sh
+ ./tests/test_benchmarks.sh
+ ./tests/test_train.sh
env:
NCCL_SHM_DISABLE: 1
MAX_JOBS: 8
SFT_DATASET: /data/scratch/github_actions/chat/data.json
- PROMPT_PATH: /data/scratch/github_actions/chat/prompts_en.jsonl
+ PROMPT_DATASET: /data/scratch/github_actions/chat/prompts_en.jsonl
PRETRAIN_DATASET: /data/scratch/github_actions/chat/alpaca_data.json
diff --git a/.github/workflows/run_chatgpt_unit_tests.yml b/.github/workflows/run_chatgpt_unit_tests.yml
index 47c80fc9a9fecafa332a0cb0e457052b8d929711..ec5c8ffa319f47d75094d403152c9162cac77c6f 100644
--- a/.github/workflows/run_chatgpt_unit_tests.yml
+++ b/.github/workflows/run_chatgpt_unit_tests.yml
@@ -30,9 +30,8 @@ jobs:
- name: Checkout ColossalAI
uses: actions/checkout@v2
- - name: Install ColossalAI and ChatGPT
+ - name: Install ChatGPT
run: |
- pip install -e .
cd applications/Chat
pip install -v .
pip install -r requirements-test.txt
diff --git a/.github/workflows/scripts/check_doc_i18n.py b/.github/workflows/scripts/check_doc_i18n.py
index 1aa7283e9e52f169d89f337d7942cf55f601257d..1e7f0c33a78598e12a67fa47de192c519d42b3c8 100644
--- a/.github/workflows/scripts/check_doc_i18n.py
+++ b/.github/workflows/scripts/check_doc_i18n.py
@@ -22,13 +22,13 @@ def compare_dirs(dir1, dir2):
# If the corresponding item doesn't exist in the second directory, the directories are different
if not os.path.exists(item_path2):
- print(f'Found mismatch: {item_path1}, {item_path2}')
+ print(f"Found mismatch: {item_path1}, {item_path2}")
return False
# If the corresponding item is a directory, we compare the two directories recursively
if os.path.isdir(item_path1) and os.path.isdir(item_path2):
if not compare_dirs(item_path1, item_path2):
- print(f'Found mismatch: {item_path1}, {item_path2}')
+ print(f"Found mismatch: {item_path1}, {item_path2}")
return False
# both are files
@@ -37,16 +37,16 @@ def compare_dirs(dir1, dir2):
# If the corresponding item is not a file or a directory, the directories are different
else:
- print(f'Found mismatch: {item_path1}, {item_path2}')
+ print(f"Found mismatch: {item_path1}, {item_path2}")
return False
# If all items are the same, the directories are the same
return True
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('-d', '--directory', help="The directory where the multi-language source files are kept.")
+ parser.add_argument("-d", "--directory", help="The directory where the multi-language source files are kept.")
args = parser.parse_args()
i18n_folders = os.listdir(args.directory)
@@ -56,7 +56,7 @@ if __name__ == '__main__':
for i in range(1, len(i18n_folders)):
dir1 = i18n_folders[0]
dir2 = i18n_folders[i]
- print(f'comparing {dir1} vs {dir2}')
+ print(f"comparing {dir1} vs {dir2}")
match = compare_dirs(i18n_folders[0], i18n_folders[i])
if not match:
diff --git a/.github/workflows/scripts/example_checks/check_dispatch_inputs.py b/.github/workflows/scripts/example_checks/check_dispatch_inputs.py
index 5bec96187e0cc5b0aa5ebd8e6a59f73ac8b6d88d..91778f692cc63360afc3a594a61e5a21c387a1a8 100644
--- a/.github/workflows/scripts/example_checks/check_dispatch_inputs.py
+++ b/.github/workflows/scripts/example_checks/check_dispatch_inputs.py
@@ -4,7 +4,7 @@ import os
def check_inputs(input_list):
for path in input_list:
- real_path = os.path.join('examples', path)
+ real_path = os.path.join("examples", path)
if not os.path.exists(real_path):
return False
return True
@@ -12,16 +12,16 @@ def check_inputs(input_list):
def main():
parser = argparse.ArgumentParser()
- parser.add_argument('-f', '--fileNameList', type=str, help="List of file names")
+ parser.add_argument("-f", "--fileNameList", type=str, help="List of file names")
args = parser.parse_args()
name_list = args.fileNameList.split(",")
is_correct = check_inputs(name_list)
if is_correct:
- print('success')
+ print("success")
else:
- print('failure')
+ print("failure")
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/.github/workflows/scripts/example_checks/check_example_weekly.py b/.github/workflows/scripts/example_checks/check_example_weekly.py
index 83eff644e3150dae8fa7ada808dd1e16b571e54a..95a3d24c9a78038095e53a3f1700a8fc8d4d536c 100644
--- a/.github/workflows/scripts/example_checks/check_example_weekly.py
+++ b/.github/workflows/scripts/example_checks/check_example_weekly.py
@@ -17,21 +17,21 @@ def show_files(path, all_files):
def join(input_list, sep=None):
- return (sep or ' ').join(input_list)
+ return (sep or " ").join(input_list)
def main():
- contents = show_files('examples/', [])
+ contents = show_files("examples/", [])
all_loc = []
for file_loc in contents:
- split_loc = file_loc.split('/')
+ split_loc = file_loc.split("/")
        # must have at least two sub-folder levels under the examples folder: examples/images/vit is acceptable; examples/images/README.md and examples/requirements.txt are not.
if len(split_loc) >= 4:
- re_loc = '/'.join(split_loc[1:3])
+ re_loc = "/".join(split_loc[1:3])
if re_loc not in all_loc:
all_loc.append(re_loc)
print(all_loc)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/.github/workflows/scripts/example_checks/detect_changed_example.py b/.github/workflows/scripts/example_checks/detect_changed_example.py
index c69d95a552e96bfe425fc7ad06c0de6a30b1d786..95f671dfb32b5ee007fa6287d3b17dae271fad15 100644
--- a/.github/workflows/scripts/example_checks/detect_changed_example.py
+++ b/.github/workflows/scripts/example_checks/detect_changed_example.py
@@ -3,7 +3,7 @@ import argparse
def main():
parser = argparse.ArgumentParser()
- parser.add_argument('-f', '--fileNameList', type=str, help="The list of changed files")
+ parser.add_argument("-f", "--fileNameList", type=str, help="The list of changed files")
args = parser.parse_args()
name_list = args.fileNameList.split(":")
folder_need_check = set()
@@ -15,10 +15,10 @@ def main():
# - application
# - file
if loc.split("/")[0] == "examples" and len(loc.split("/")) >= 4:
- folder_need_check.add('/'.join(loc.split("/")[1:3]))
+ folder_need_check.add("/".join(loc.split("/")[1:3]))
# Output the result using print. Then the shell can get the values.
print(list(folder_need_check))
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
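
As an illustration of the mapping above, applied to a hypothetical change list:

```python
# Hypothetical input in the ":"-separated format the workflow passes via -f.
changed = "examples/images/vit/train.py:examples/requirements.txt"

folder_need_check = set()
for loc in changed.split(":"):
    # keep only paths at least two folder levels below examples/
    if loc.split("/")[0] == "examples" and len(loc.split("/")) >= 4:
        folder_need_check.add("/".join(loc.split("/")[1:3]))

print(list(folder_need_check))  # ['images/vit']
```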
diff --git a/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py b/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py
index 16b8957c1d884aba897fd4734a8ec54efc299c5a..412b14c7b28337e3ab0129714b808b0fee28d687 100644
--- a/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py
+++ b/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py
@@ -1,5 +1,4 @@
import os
-from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Dict, List
@@ -10,8 +9,7 @@ import seaborn
from requests_toolbelt import MultipartEncoder
-@dataclass
-class Contributor:
+class Counter(dict):
"""
    A dict-based counter for GitHub contributor statistics.
@@ -19,8 +17,40 @@ class Contributor:
name (str): name of the contributor
num_commits_this_week (int): number of commits made within one week
"""
- name: str
- num_commits_this_week: int
+
+ def record(self, item: str):
+ if item in self:
+ self[item] += 1
+ else:
+ self[item] = 1
+
+ def to_sorted_list(self):
+ data = [(key, value) for key, value in self.items()]
+ data.sort(key=lambda x: x[1], reverse=True)
+ return data
+
+
+def get_utc_time_one_week_ago():
+ """
+ Get the UTC time one week ago.
+ """
+ now = datetime.utcnow()
+ start_datetime = now - timedelta(days=7)
+ return start_datetime
+
+
+def datetime2str(dt):
+ """
+ Convert datetime to string in the format of YYYY-MM-DDTHH:MM:SSZ
+ """
+ return dt.strftime("%Y-%m-%dT%H:%M:%SZ")
+
+
+def str2datetime(string):
+ """
+ Convert string in the format of YYYY-MM-DDTHH:MM:SSZ to datetime
+ """
+ return datetime.strptime(string, "%Y-%m-%dT%H:%M:%SZ")
def plot_bar_chart(x: List[Any], y: List[Any], xlabel: str, ylabel: str, title: str, output_path: str) -> None:
@@ -36,9 +66,30 @@ def plot_bar_chart(x: List[Any], y: List[Any], xlabel: str, ylabel: str, title:
plt.savefig(output_path, dpi=1200)
-def get_issue_pull_request_comments(github_token: str, since: str) -> Dict[str, int]:
+def get_organization_repositories(github_token, organization_name) -> List[str]:
+ """
+ Retrieve the public repositories under the organization.
+ """
+ url = f"https://api.github.com/orgs/{organization_name}/repos?type=public"
+
+ # prepare header
+ headers = {
+ "Authorization": f"Bearer {github_token}",
+ "Accept": "application/vnd.github+json",
+ "X-GitHub-Api-Version": "2022-11-28",
+ }
+
+ res = requests.get(url, headers=headers).json()
+ repo_list = []
+
+ for item in res:
+ repo_list.append(item["name"])
+ return repo_list
+
+
+def get_issue_pull_request_comments(github_token: str, org_name: str, repo_name: str, since: str) -> Dict[str, int]:
"""
- Retrive the issue/PR comments made by our members in the last 7 days.
+ Retrieve the issue/PR comments made by our members in the last 7 days.
Args:
github_token (str): GitHub access token for API calls
@@ -46,9 +97,9 @@ def get_issue_pull_request_comments(github_token: str, since: str) -> Dict[str,
"""
# prepare header
headers = {
- 'Authorization': f'Bearer {github_token}',
- 'Accept': 'application/vnd.github+json',
- 'X-GitHub-Api-Version': '2022-11-28'
+ "Authorization": f"Bearer {github_token}",
+ "Accept": "application/vnd.github+json",
+ "X-GitHub-Api-Version": "2022-11-28",
}
user_engagement_count = {}
@@ -56,28 +107,28 @@ def get_issue_pull_request_comments(github_token: str, since: str) -> Dict[str,
# do pagination to the API
page = 1
while True:
- comment_api = f'https://api.github.com/repos/hpcaitech/ColossalAI/issues/comments?since={since}&page={page}'
+ comment_api = f"https://api.github.com/repos/{org_name}/{repo_name}/issues/comments?since={since}&page={page}"
comment_response = requests.get(comment_api, headers=headers).json()
if len(comment_response) == 0:
break
else:
for item in comment_response:
- comment_author_relationship = item['author_association']
- if comment_author_relationship != 'MEMBER':
+ comment_author_relationship = item["author_association"]
+ if comment_author_relationship != "MEMBER":
# if the comment is not made by our member
# we don't count this comment towards user engagement
continue
- issue_id = item['issue_url'].split('/')[-1]
- issue_api = f'https://api.github.com/repos/hpcaitech/ColossalAI/issues/{issue_id}'
+ issue_id = item["issue_url"].split("/")[-1]
+ issue_api = f"https://api.github.com/repos/{org_name}/{repo_name}/issues/{issue_id}"
issue_response = requests.get(issue_api, headers=headers).json()
- issue_author_relationship = issue_response['author_association']
+ issue_author_relationship = issue_response["author_association"]
- if issue_author_relationship != 'MEMBER':
+ if issue_author_relationship != "MEMBER":
# this means that the issue/PR is not created by our own people
# any comments in this issue/PR by our member will be counted towards the leaderboard
- member_name = item['user']['login']
+ member_name = item["user"]["login"]
if member_name in user_engagement_count:
user_engagement_count[member_name] += 1
@@ -87,9 +138,9 @@ def get_issue_pull_request_comments(github_token: str, since: str) -> Dict[str,
return user_engagement_count
-def get_discussion_comments(github_token, since) -> Dict[str, int]:
+def get_discussion_comments(github_token: str, org_name: str, repo_name: str, since: str) -> Dict[str, int]:
"""
- Retrive the discussion comments made by our members in the last 7 days.
+ Retrieve the discussion comments made by our members in the last 7 days.
This is only available via the GitHub GraphQL API.
Args:
@@ -102,10 +153,10 @@ def get_discussion_comments(github_token, since) -> Dict[str, int]:
if cursor is None:
offset_str = ""
else:
- offset_str = f", after: \"{cursor}\""
+ offset_str = f', after: "{cursor}"'
query = f"""
{{
- repository(owner: "hpcaitech", name: "ColossalAI"){{
+ repository(owner: "{org_name}", name: "{repo_name}"){{
discussions(first: {num} {offset_str}){{
edges {{
cursor
@@ -131,10 +182,10 @@ def get_discussion_comments(github_token, since) -> Dict[str, int]:
if cursor is None:
offset_str = ""
else:
- offset_str = f", before: \"{cursor}\""
+ offset_str = f', before: "{cursor}"'
query = f"""
{{
- repository(owner: "hpcaitech", name: "ColossalAI"){{
+ repository(owner: "{org_name}", name: "{repo_name}"){{
discussion(number: {discussion_number}){{
title
comments(last: {num} {offset_str}){{
@@ -169,8 +220,8 @@ def get_discussion_comments(github_token, since) -> Dict[str, int]:
# a utility function to make call to Github GraphQL API
def _call_graphql_api(query):
headers = {"Authorization": f"Bearer {github_token}"}
- json_data = {'query': query}
- response = requests.post('https://api.github.com/graphql', json=json_data, headers=headers)
+ json_data = {"query": query}
+ response = requests.post("https://api.github.com/graphql", json=json_data, headers=headers)
data = response.json()
return data
@@ -183,21 +234,21 @@ def get_discussion_comments(github_token, since) -> Dict[str, int]:
data = _call_graphql_api(query)
found_discussion_out_of_time_range = False
- edges = data['data']['repository']['discussions']['edges']
+ edges = data["data"]["repository"]["discussions"]["edges"]
if len(edges) == 0:
break
else:
# keep the discussion whose author is not a member
for edge in edges:
# print the discussion title
- discussion = edge['node']
+ discussion = edge["node"]
+ discussion_updated_at = str2datetime(discussion["updatedAt"])
- discussion_updated_at = datetime.strptime(discussion['updatedAt'], "%Y-%m-%dT%H:%M:%SZ")
# check if the updatedAt is within the last 7 days
- # if yes, add it to dicussion_numbers
+ # if yes, add it to discussion_numbers
if discussion_updated_at > since:
- if discussion['authorAssociation'] != 'MEMBER':
- discussion_numbers.append(discussion['number'])
+ if discussion["authorAssociation"] != "MEMBER":
+ discussion_numbers.append(discussion["number"])
else:
found_discussion_out_of_time_range = True
@@ -205,54 +256,55 @@ def get_discussion_comments(github_token, since) -> Dict[str, int]:
break
else:
# update cursor
- cursor = edges[-1]['cursor']
+ cursor = edges[-1]["cursor"]
- # get the dicussion comments and replies made by our member
+ # get the discussion comments and replies made by our member
user_engagement_count = {}
- for dicussion_number in discussion_numbers:
+ for discussion_number in discussion_numbers:
cursor = None
num_per_request = 10
while True:
- query = _generate_comment_reply_count_for_discussion(dicussion_number, num_per_request, cursor)
+ query = _generate_comment_reply_count_for_discussion(discussion_number, num_per_request, cursor)
data = _call_graphql_api(query)
# get the comments
- edges = data['data']['repository']['discussion']['comments']['edges']
+ edges = data["data"]["repository"]["discussion"]["comments"]["edges"]
# update the cursor
if len(edges) == 0:
break
else:
# update cursor for pagination
- cursor = edges[-1]['cursor']
+ cursor = edges[-1]["cursor"]
for edge in edges:
- comment = edge['node']
- if comment['authorAssociation'] == 'MEMBER':
+ comment = edge["node"]
+ if comment["authorAssociation"] == "MEMBER":
# check if the updatedAt is within the last 7 days
# if yes, add it to user_engagement_count
- comment_updated_at = datetime.strptime(comment['updatedAt'], "%Y-%m-%dT%H:%M:%SZ")
+ comment_updated_at = datetime.strptime(comment["updatedAt"], "%Y-%m-%dT%H:%M:%SZ")
if comment_updated_at > since:
- member_name = comment['author']['login']
+ member_name = comment["author"]["login"]
if member_name in user_engagement_count:
user_engagement_count[member_name] += 1
else:
user_engagement_count[member_name] = 1
# get the replies
- reply_edges = comment['replies']['edges']
+ reply_edges = comment["replies"]["edges"]
if len(reply_edges) == 0:
continue
else:
for reply_edge in reply_edges:
- reply = reply_edge['node']
- if reply['authorAssociation'] == 'MEMBER':
+ reply = reply_edge["node"]
+ if reply["authorAssociation"] == "MEMBER":
# check if the updatedAt is within the last 7 days
- # if yes, add it to dicussion_numbers
- reply_updated_at = datetime.strptime(reply['updatedAt'], "%Y-%m-%dT%H:%M:%SZ")
+ # if yes, add it to discussion_numbers
+ reply_updated_at = datetime.strptime(reply["updatedAt"], "%Y-%m-%dT%H:%M:%SZ")
if reply_updated_at > since:
- member_name = reply['author']['login']
+ member_name = reply["author"]["login"]
if member_name in user_engagement_count:
user_engagement_count[member_name] += 1
else:
@@ -260,7 +312,9 @@ def get_discussion_comments(github_token, since) -> Dict[str, int]:
return user_engagement_count
-def generate_user_engagement_leaderboard_image(github_token: str, output_path: str) -> bool:
+def generate_user_engagement_leaderboard_image(
+ github_token: str, org_name: str, repo_list: List[str], output_path: str
+) -> bool:
"""
Generate the user engagement leaderboard image for stats within the last 7 days
@@ -270,22 +324,31 @@ def generate_user_engagement_leaderboard_image(github_token: str, output_path: s
"""
# request to the Github API to get the users who have replied the most in the last 7 days
- now = datetime.utcnow()
- start_datetime = now - timedelta(days=7)
- start_datetime_str = start_datetime.strftime("%Y-%m-%dT%H:%M:%SZ")
+ start_datetime = get_utc_time_one_week_ago()
+ start_datetime_str = datetime2str(start_datetime)
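+    # NOTE (assumption): `get_utc_time_one_week_ago`, `datetime2str` and `str2datetime`
+    # are small helpers defined earlier in this script; they produce/parse timestamps in
+    # the "%Y-%m-%dT%H:%M:%SZ" format expected by the GitHub API.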
# get the issue/PR comments and discussion comment count
- issue_pr_engagement_count = get_issue_pull_request_comments(github_token=github_token, since=start_datetime_str)
- discussion_engagement_count = get_discussion_comments(github_token=github_token, since=start_datetime)
total_engagement_count = {}
- # update the total engagement count
- total_engagement_count.update(issue_pr_engagement_count)
- for name, count in discussion_engagement_count.items():
- if name in total_engagement_count:
- total_engagement_count[name] += count
- else:
- total_engagement_count[name] = count
+ def _update_count(counter):
+ for name, count in counter.items():
+ if name in total_engagement_count:
+ total_engagement_count[name] += count
+ else:
+ total_engagement_count[name] = count
+
+ for repo_name in repo_list:
+ print(f"Fetching user engagement count for {repo_name}/{repo_name}")
+ issue_pr_engagement_count = get_issue_pull_request_comments(
+ github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime_str
+ )
+ discussion_engagement_count = get_discussion_comments(
+ github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime
+ )
+
+ # update the total engagement count
+ _update_count(issue_pr_engagement_count)
+ _update_count(discussion_engagement_count)
# prepare the data for plotting
x = []
@@ -302,20 +365,17 @@ def generate_user_engagement_leaderboard_image(github_token: str, output_path: s
x.append(count)
y.append(name)
- # use Shanghai time to display on the image
- start_datetime_str = datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y-%m-%dT%H:%M:%SZ")
-
# plot the leaderboard
xlabel = f"Number of Comments made (since {start_datetime_str})"
ylabel = "Member"
- title = 'Active User Engagement Leaderboard'
+ title = "Active User Engagement Leaderboard"
plot_bar_chart(x, y, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path)
return True
else:
return False
-def generate_contributor_leaderboard_image(github_token, output_path) -> bool:
+def generate_contributor_leaderboard_image(github_token: str, org_name: str, repo_list: List[str], output_path: str) -> bool:
"""
Generate the contributor leaderboard image for stats within the last 7 days
@@ -324,54 +384,81 @@ def generate_contributor_leaderboard_image(github_token, output_path) -> bool:
output_path (str): the path to save the image
"""
# request to the Github API to get the users who have contributed in the last 7 days
- URL = 'https://api.github.com/repos/hpcaitech/ColossalAI/stats/contributors'
headers = {
- 'Authorization': f'Bearer {github_token}',
- 'Accept': 'application/vnd.github+json',
- 'X-GitHub-Api-Version': '2022-11-28'
+ "Authorization": f"Bearer {github_token}",
+ "Accept": "application/vnd.github+json",
+ "X-GitHub-Api-Version": "2022-11-28",
}
- while True:
- response = requests.get(URL, headers=headers).json()
+ counter = Counter()
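+    # NOTE (assumption): `Counter` is the small helper class defined earlier in this
+    # script rather than `collections.Counter`: `record(name)` increments the per-author
+    # count and `to_sorted_list()` returns (name, count) pairs sorted by count,
+    # presumably in descending order to match the leaderboard.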
+ start_datetime = get_utc_time_one_week_ago()
- if len(response) != 0:
- # sometimes the Github API returns empty response for unknown reason
- # request again if the response is empty
- break
+ def _get_url(org_name, repo_name, page):
+ return f"https://api.github.com/repos/{org_name}/{repo_name}/pulls?per_page=50&page={page}&state=closed"
- contributor_list = []
+ def _iterate_by_page(org_name, repo_name):
+ page = 1
+ stop = False
- # get number of commits for each contributor
- start_timestamp = None
- for item in response:
- num_commits_this_week = item['weeks'][-1]['c']
- name = item['author']['login']
- contributor = Contributor(name=name, num_commits_this_week=num_commits_this_week)
- contributor_list.append(contributor)
+ while not stop:
+ print(f"Fetching pull request data for {org_name}/{repo_name} - page{page}")
+ url = _get_url(org_name, repo_name, page)
- # update start_timestamp
- start_timestamp = item['weeks'][-1]['w']
+ while True:
+ response = requests.get(url, headers=headers).json()
+
+                if isinstance(response, list):
+                    break
+                # the GitHub API occasionally returns a non-list payload for this endpoint;
+                # in that case, print a notice and request again
+                print("Unexpected response, requesting again...")
+
+ if len(response) == 0:
+                # an empty page means all closed PRs have been fetched; stop
+ stop = True
+ break
+
+ # count the pull request and author from response
+ for pr_data in response:
+ merged_at = pr_data["merged_at"]
+ author = pr_data["user"]["login"]
+
+ if merged_at is None:
+ continue
+
+ merge_datetime = str2datetime(merged_at)
+
+ if merge_datetime < start_datetime:
+                        # stop once we reach a pull request merged before start_datetime
+ stop = True
+ break
+ else:
+                        # record the PR author
+ counter.record(author)
+
+ # next page
+ page += 1
+
+ for repo_name in repo_list:
+ _iterate_by_page(org_name, repo_name)
    # convert the UTC start time to Beijing time for display
- start_datetime = datetime.fromtimestamp(start_timestamp, tz=pytz.timezone('Asia/Shanghai'))
- start_datetime_str = start_datetime.strftime("%Y-%m-%dT%H:%M:%SZ")
+ bj_start_datetime = datetime.fromtimestamp(start_datetime.timestamp(), tz=pytz.timezone("Asia/Shanghai"))
+ bj_start_datetime_str = datetime2str(bj_start_datetime)
- # sort by number of commits
- contributor_list.sort(key=lambda x: x.num_commits_this_week, reverse=True)
+ contribution_list = counter.to_sorted_list()
    # remove contributors who have zero commits
- contributor_list = [x for x in contributor_list if x.num_commits_this_week > 0]
-
- # prepare the data for plotting
- x = [x.num_commits_this_week for x in contributor_list]
- y = [x.name for x in contributor_list]
+ author_list = [x[0] for x in contribution_list]
+ num_commit_list = [x[1] for x in contribution_list]
# plot
- if len(x) > 0:
- xlabel = f"Number of Commits (since {start_datetime_str})"
+ if len(author_list) > 0:
+ xlabel = f"Number of Pull Requests (since {bj_start_datetime_str})"
ylabel = "Contributor"
- title = 'Active Contributor Leaderboard'
- plot_bar_chart(x, y, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path)
+ title = "Active Contributor Leaderboard"
+ plot_bar_chart(num_commit_list, author_list, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path)
return True
else:
return False
@@ -386,14 +473,14 @@ def upload_image_to_lark(lark_tenant_token: str, image_path: str) -> str:
image_path (str): the path to the image to be uploaded
"""
url = "https://open.feishu.cn/open-apis/im/v1/images"
- form = {'image_type': 'message', 'image': (open(image_path, 'rb'))} # 需要替换具体的path
+ form = {"image_type": "message", "image": (open(image_path, "rb"))} # 需要替换具体的path
multi_form = MultipartEncoder(form)
headers = {
- 'Authorization': f'Bearer {lark_tenant_token}', ## 获取tenant_access_token, 需要替换为实际的token
+ "Authorization": f"Bearer {lark_tenant_token}", ## 获取tenant_access_token, 需要替换为实际的token
}
- headers['Content-Type'] = multi_form.content_type
+ headers["Content-Type"] = multi_form.content_type
response = requests.request("POST", url, headers=headers, data=multi_form).json()
- return response['data']['image_key']
+ return response["data"]["image_key"]
def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str:
@@ -404,10 +491,10 @@ def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str:
app_id (str): Lark app id
app_secret (str): Lark app secret
"""
- url = 'https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal'
- data = {'app_id': app_id, 'app_secret': app_secret}
+ url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal"
+ data = {"app_id": app_id, "app_secret": app_secret}
response = requests.post(url, json=data).json()
- return response['tenant_access_token']
+ return response["tenant_access_token"]
def send_image_to_lark(image_key: str, webhook_url: str) -> None:
@@ -434,31 +521,37 @@ def send_message_to_lark(message: str, webhook_url: str):
requests.post(webhook_url, json=data)
-if __name__ == '__main__':
- GITHUB_TOKEN = os.environ['GITHUB_TOKEN']
- CONTRIBUTOR_IMAGE_PATH = 'contributor_leaderboard.png'
- USER_ENGAGEMENT_IMAGE_PATH = 'engagement_leaderboard.png'
+if __name__ == "__main__":
+ GITHUB_TOKEN = os.environ["GITHUB_TOKEN"]
+ CONTRIBUTOR_IMAGE_PATH = "contributor_leaderboard.png"
+ USER_ENGAGEMENT_IMAGE_PATH = "engagement_leaderboard.png"
+ ORG_NAME = "hpcaitech"
+
+ # get all open source repositories
+ REPO_LIST = get_organization_repositories(GITHUB_TOKEN, ORG_NAME)
# generate images
- contrib_success = generate_contributor_leaderboard_image(GITHUB_TOKEN, CONTRIBUTOR_IMAGE_PATH)
- engagement_success = generate_user_engagement_leaderboard_image(GITHUB_TOKEN, USER_ENGAGEMENT_IMAGE_PATH)
+ contrib_success = generate_contributor_leaderboard_image(GITHUB_TOKEN, ORG_NAME, REPO_LIST, CONTRIBUTOR_IMAGE_PATH)
+ engagement_success = generate_user_engagement_leaderboard_image(
+ GITHUB_TOKEN, ORG_NAME, REPO_LIST, USER_ENGAGEMENT_IMAGE_PATH
+ )
# upload images
- APP_ID = os.environ['LARK_APP_ID']
- APP_SECRET = os.environ['LARK_APP_SECRET']
+ APP_ID = os.environ["LARK_APP_ID"]
+ APP_SECRET = os.environ["LARK_APP_SECRET"]
LARK_TENANT_TOKEN = generate_lark_tenant_access_token(app_id=APP_ID, app_secret=APP_SECRET)
contributor_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, CONTRIBUTOR_IMAGE_PATH)
user_engagement_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, USER_ENGAGEMENT_IMAGE_PATH)
# send message to lark
- LARK_WEBHOOK_URL = os.environ['LARK_WEBHOOK_URL']
+ LARK_WEBHOOK_URL = os.environ["LARK_WEBHOOK_URL"]
message = """本周的社区榜单出炉啦!
1. 开发贡献者榜单
2. 用户互动榜单
注:
-- 开发贡献者测评标准为:本周由公司成员提交的commit次数
-- 用户互动榜单测评标准为:本周由公司成员在非成员创建的issue/PR/discussion中回复的次数
+- 开发贡献者测评标准为:本周由公司成员与社区在所有开源仓库提交的Pull Request次数
+- 用户互动榜单测评标准为:本周由公司成员在非成员在所有开源仓库创建的issue/PR/discussion中回复的次数
"""
send_message_to_lark(message, LARK_WEBHOOK_URL)
@@ -467,7 +560,7 @@ if __name__ == '__main__':
if contrib_success:
send_image_to_lark(contributor_image_key, LARK_WEBHOOK_URL)
else:
- send_message_to_lark("本周没有成员贡献commit,无榜单图片生成。", LARK_WEBHOOK_URL)
+ send_message_to_lark("本周没有成员贡献PR,无榜单图片生成。", LARK_WEBHOOK_URL)
# send user engagement image to lark
if engagement_success:
diff --git a/.github/workflows/scripts/generate_release_draft.py b/.github/workflows/scripts/generate_release_draft.py
index dc592e4c977b46b0f0156a9da08832b7624776df..7374481005ef1ba8de1596771a1271cc28425ffa 100644
--- a/.github/workflows/scripts/generate_release_draft.py
+++ b/.github/workflows/scripts/generate_release_draft.py
@@ -7,27 +7,27 @@ import re
import requests
-COMMIT_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/commits'
-TAGS_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/tags'
+COMMIT_API = "https://api.github.com/repos/hpcaitech/ColossalAI/commits"
+TAGS_API = "https://api.github.com/repos/hpcaitech/ColossalAI/tags"
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument('--out', type=str, help='output path for the release draft', required=True)
- parser.add_argument('--version', type=str, help='current version to release', required=True)
+ parser.add_argument("--out", type=str, help="output path for the release draft", required=True)
+ parser.add_argument("--version", type=str, help="current version to release", required=True)
return parser.parse_args()
def get_latest_tag_commit(headers=None):
res = requests.get(url=TAGS_API, headers=headers)
data = res.json()
- commit_hash = data[0]['commit']['sha']
- version = data[0]['name']
+ commit_hash = data[0]["commit"]["sha"]
+ version = data[0]["name"]
return commit_hash, version
def get_commit_info(commit_hash, headers=None):
- api = f'{COMMIT_API}/{commit_hash}'
+ api = f"{COMMIT_API}/{commit_hash}"
res = requests.get(url=api, headers=headers)
return res.json()
@@ -37,7 +37,7 @@ def get_all_commit_info(since, headers=None):
results = []
while True:
- api = f'{COMMIT_API}?since={since}&per_page=100&page={page}'
+ api = f"{COMMIT_API}?since={since}&per_page=100&page={page}"
resp = requests.get(url=api, headers=headers)
data = resp.json()
@@ -53,21 +53,21 @@ def get_all_commit_info(since, headers=None):
def collate_release_info(commit_info_list):
results = dict()
- pattern = pattern = r'\[.*\]'
+    pattern = r"\[.*\]"
for commit_info in commit_info_list:
- author = commit_info['commit']['author']['name']
+ author = commit_info["commit"]["author"]["name"]
try:
- author_url = commit_info['author']['url']
+ author_url = commit_info["author"]["url"]
except:
# author can be None
author_url = None
- msg = commit_info['commit']['message']
+ msg = commit_info["commit"]["message"]
match = re.search(pattern, msg)
if match:
- tag = match.group().lstrip('[').rstrip(']').capitalize()
+ tag = match.group().lstrip("[").rstrip("]").capitalize()
if tag not in results:
results[tag] = []
results[tag].append((msg, author, author_url))
@@ -89,42 +89,43 @@ def generate_release_post_markdown(current_version, last_version, release_info):
for msg, author, author_url in v:
# only keep the first line
- msg = msg.split('\n')[0]
+ msg = msg.split("\n")[0]
if author_url:
- item = f'{msg} by [{author}]({author_url})\n'
+ item = f"{msg} by [{author}]({author_url})\n"
else:
- item = f'{msg} by {author}\n'
- text.append(f'- {item}')
+ item = f"{msg} by {author}\n"
+ text.append(f"- {item}")
- text.append('\n')
+ text.append("\n")
# add full change log
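+    # GitHub compare links are `base...head`, so the previous tag comes first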
text.append(
- f'**Full Changelog**: https://github.com/hpcaitech/ColossalAI/compare/{current_version}...{last_version}')
+ f"**Full Changelog**: https://github.com/hpcaitech/ColossalAI/compare/{current_version}...{last_version}"
+ )
return text
-if __name__ == '__main__':
+if __name__ == "__main__":
args = parse_args()
- token = os.environ['GITHUB_API_TOKEN']
- headers = {'Authorization': token}
+ token = os.environ["GITHUB_API_TOKEN"]
+ headers = {"Authorization": token}
# get previous release tag
last_release_commit, last_version = get_latest_tag_commit(headers)
last_release_commit_info = get_commit_info(last_release_commit, headers=headers)
- last_release_date = last_release_commit_info['commit']['author']['date']
+ last_release_date = last_release_commit_info["commit"]["author"]["date"]
# get the commits since last release
commit_info = get_all_commit_info(since=last_release_date, headers=headers)
- commit_info = commit_info[:-1] # remove the release commit
+ commit_info = commit_info[:-1] # remove the release commit
# collate into markdown
release_info = collate_release_info(commit_info)
markdown_text = generate_release_post_markdown(args.version, last_version, release_info)
# write into a file
- with open(args.out, 'w') as f:
+ with open(args.out, "w") as f:
for line in markdown_text:
f.write(line)
diff --git a/.github/workflows/scripts/send_message_to_lark.py b/.github/workflows/scripts/send_message_to_lark.py
index a113327a786ed1310b6ef8c0ffc784bd6af2e344..bc005d93c3f5f5c8f431c726a850fee2bc3e9269 100644
--- a/.github/workflows/scripts/send_message_to_lark.py
+++ b/.github/workflows/scripts/send_message_to_lark.py
@@ -5,8 +5,8 @@ import requests
def parse_args():
parser = argparse.ArgumentParser()
- parser.add_argument('-m', '--message', type=str)
- parser.add_argument('-u', '--url', type=str)
+ parser.add_argument("-m", "--message", type=str)
+ parser.add_argument("-u", "--url", type=str)
return parser.parse_args()
@@ -15,6 +15,6 @@ def send_message_to_lark(message, webhook_url):
requests.post(webhook_url, json=data)
-if __name__ == '__main__':
+if __name__ == "__main__":
args = parse_args()
send_message_to_lark(args.message, args.url)
diff --git a/.gitignore b/.gitignore
index bf74a753894fcbbf722812700fdf72a38c9e896c..81113fa99dd570b071e08a9c8cbb681d33626bc9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -155,3 +155,7 @@ colossalai/version.py
# ignore coverage test file
coverage.lcov
coverage.xml
+
+# ignore testmon and coverage files
+.coverage
+.testmondata*
diff --git a/.isort.cfg b/.isort.cfg
index 090aa28e39f32da8c0161d5317c710b6c8781641..ccbf575fdbfacd185cf880431ad81462e0ae8fdf 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -3,3 +3,5 @@ line_length = 120
multi_line_output=3
include_trailing_comma = true
ignore_comments = true
+profile = black
+honor_noqa = true
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 725d266375ef42e69f6097f7d61924213d46b8ff..9871e1184462f2069071ea8db96495b20059d645 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,23 +1,31 @@
repos:
+ - repo: https://github.com/PyCQA/autoflake
+ rev: v2.2.1
+ hooks:
+ - id: autoflake
+ name: autoflake (python)
+ args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']
+
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
name: sort all imports (python)
- - repo: https://github.com/pre-commit/mirrors-yapf
- rev: v0.32.0
+ - repo: https://github.com/psf/black-pre-commit-mirror
+ rev: 23.9.1
hooks:
- - id: yapf
- name: yapf formatter
- args: ['--style=.style.yapf', '--parallel', '--in-place']
+ - id: black
+ name: black formatter
+ args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v13.0.1
hooks:
- id: clang-format
name: clang formatter
+ types_or: [c++, c]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
diff --git a/.style.yapf b/.style.yapf
deleted file mode 100644
index 05be0dc6a3a598ebd59ff00cec93ce8b809c78b5..0000000000000000000000000000000000000000
--- a/.style.yapf
+++ /dev/null
@@ -1,5 +0,0 @@
-[style]
-based_on_style = google
-spaces_before_comment = 4
-split_before_logical_operator = true
-column_limit = 120
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 00abcf650158c571411f7b78a0ef4c982365b746..a3dc020f74e9529d08a313fc7242b0e552275db6 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -30,6 +30,12 @@ pip install -e .
### Unit Tests
We use [PyTest](https://docs.pytest.org/en/latest/) to execute tests. You can install pytest by `pip install pytest`. As some of the tests require initialization of the distributed backend, GPUs are needed to execute these tests.
+To set up the environment for unit testing, first change your current directory to the root directory of your local ColossalAI repository, then run
+```bash
+pip install -r requirements/requirements-test.txt
+```
+If you encounter an error saying "Could not find a version that satisfies the requirement fbgemm-gpu==0.2.0", please downgrade your Python version to 3.8 or 3.9 and try again.
+
If you only want to run CPU tests, you can run
```bash
@@ -138,4 +144,4 @@ You can now create a pull request on the GitHub webpage of your repository. The
Do write clearly the description of your pull request and [link the pull request to your target issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue). This will automatically close the issue when the pull request is approved.
-In case of code conflict, you should rebase your branch and resolve the conflicts manually.
\ No newline at end of file
+In case of code conflict, you should rebase your branch and resolve the conflicts manually.
diff --git a/LICENSE b/LICENSE
index c7a5bb16880e6a2a6364a092fe94dda29399b9e3..59d456c5b8a1a452ce5f478f6416d66572bc13c5 100644
--- a/LICENSE
+++ b/LICENSE
@@ -396,3 +396,84 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
+
+ ---------------- LICENSE FOR VLLM TEAM ----------------
+
+ from VLLM TEAM:
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ https://github.com/vllm-project/vllm/blob/main/LICENSE
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+ ---------------- LICENSE FOR LIGHTLLM TEAM ----------------
+
+ from LIGHTLLM TEAM:
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ https://github.com/ModelTC/lightllm/blob/main/LICENSE
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ---------------- LICENSE FOR AutoGPTQ ----------------
+
+ From AutoGPTQ:
+
+ MIT License
+
+ Copyright (c) 2023 潘其威(William)
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+ ---------------- LICENSE FOR exllama ----------------
+
+ From exllama:
+
+ MIT License
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
diff --git a/README.md b/README.md
index 79f733122cb3a14ff27b72889caa710ac5e8a5f6..b2efb79104890d58128cce73bb7c334035af6e55 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@
[](https://colossalai.readthedocs.io/en/latest/?badge=latest)
[](https://www.codefactor.io/repository/github/hpcaitech/colossalai)
[](https://huggingface.co/hpcai-tech)
- [](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w)
+ [](https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack)
[](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png)
@@ -25,14 +25,15 @@
## Latest News
+* [2023/09] [One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific LLM Solution](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution)
+* [2023/09] [70 Billion Parameter LLaMA2 Model Training Accelerated by 195%](https://www.hpc-ai.tech/blog/70b-llama2-training)
+* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)
+* [2023/07] [65B Model Pretraining Accelerated by 38%, Best Practices for Building LLaMA-Like Base Models Open-Source](https://www.hpc-ai.tech/blog/large-model-pretraining)
* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b)
* [2023/03] [Intel and Colossal-AI Partner to Deliver Cost-Efficient Open-Source Solution for Protein Folding Structure Prediction](https://www.hpc-ai.tech/blog/intel-habana)
* [2023/03] [AWS and Google Fund Colossal-AI with Startup Cloud Programs](https://www.hpc-ai.tech/blog/aws-and-google-fund-colossal-ai-with-startup-cloud-programs)
* [2023/02] [Open Source Solution Replicates ChatGPT Training Process! Ready to go with only 1.6GB GPU Memory](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt)
* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://medium.com/pytorch/latest-colossal-ai-boasts-novel-automatic-parallelism-and-offers-savings-up-to-46x-for-stable-1453b48f3f02)
-* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper)
-* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding)
-* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the)
## Table of Contents
]
@@ -41,6 +42,7 @@
-
Colossal-AI for Real World Applications
+  - Colossal-LLaMA-2: One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific LLM Solution
- ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline
- AIGC: Acceleration of Stable Diffusion
- Biomedicine: Acceleration of AlphaFold Protein Structure
@@ -49,6 +51,7 @@
-
Parallel Training Demo
+ - LLaMA 1/2
- GPT-3
- GPT-2
- BERT
@@ -124,15 +127,55 @@ distributed training and inference in a few lines.
## Colossal-AI in the Real World
+### Colossal-LLaMA-2
+
+- One half-day of training with a few hundred dollars yields results similar to mainstream large models; an open-source, commercial-free, domain-specific LLM solution.
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2)
+[[blog]](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution)
+[[model weights]](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base)
+
+| | Backbone | Tokens Consumed | | MMLU | CMMLU | AGIEval | GAOKAO | CEval |
+| :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :-----: | :----: | :----: | :------------------------------: |
+| | | - | | 5-shot | 5-shot | 5-shot | 0-shot | 5-shot |
+| Baichuan-7B | - | 1.2T | | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 |
+| Baichuan-13B-Base | - | 1.4T | | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 |
+| Baichuan2-7B-Base | - | 2.6T | | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 |
+| Baichuan2-13B-Base | - | 2.6T | | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 |
+| ChatGLM-6B | - | 1.0T | | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 |
+| ChatGLM2-6B | - | 1.4T | | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 |
+| InternLM-7B | - | 1.6T | | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 |
+| Qwen-7B | - | 2.2T | | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 |
+| | | | | | | | | |
+| Llama-2-7B | - | 2.0T | | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - |
+| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | | 37.43 | 29.92 | 32.00 | 27.57 | - |
+| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | | 38.56 | 31.52 | 30.99 | 25.95 | - |
+| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 |
+| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | | 43.73 | 42.04 | 37.64 | 30.61 | - |
+| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | | 48.41 | 38.31 | 38.45 | 27.72 | - |
+| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | | 49.96 | 41.10 | 39.83 | 33.00 | - |
+| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | | 50.25 | 40.99 | 40.04 | 30.54 | - |
+| | | | | | | | | |
+| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | | 53.06 | 49.89 | 51.48 | 58.82 | 50.2 |
+
### ColossalChat
-[ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat): An open-source solution for cloning [ChatGPT](https://openai.com/blog/chatgpt/) with a complete RLHF pipeline. [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat) [[blog]](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b) [[demo]](https://chat.colossalai.org)
+[ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat): An open-source solution for cloning [ChatGPT](https://openai.com/blog/chatgpt/) with a complete RLHF pipeline.
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat)
+[[blog]](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b)
+[[demo]](https://www.youtube.com/watch?v=HcTiHzApHm0)
+[[tutorial]](https://www.youtube.com/watch?v=-qFBZFmOJfg)
+
+
+
+
+
+- Up to 10 times faster for RLHF PPO Stage3 Training
@@ -205,6 +248,23 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
(back to top)
## Parallel Training Demo
+### LLaMA2
+
+
+
+
+- 70 billion parameter LLaMA2 model training accelerated by 195%
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/llama2)
+[[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training)
+
+### LLaMA1
+
+
+
+
+- 65-billion-parameter large model pretraining accelerated by 38%
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/example/llama/examples/language/llama)
+[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)
### GPT-3
@@ -352,6 +412,22 @@ If you want to install and enable CUDA kernel fusion (compulsory installation wh
CUDA_EXT=1 pip install .
```
+For users with CUDA 10.2, you can still build ColossalAI from source. However, you need to manually download the cub library and copy it to the corresponding directory.
+
+```bash
+# clone the repository
+git clone https://github.com/hpcaitech/ColossalAI.git
+cd ColossalAI
+
+# download the cub library
+wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
+unzip 1.8.0.zip
+cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
+
+# install
+CUDA_EXT=1 pip install .
+```
+
(back to top)
## Use Docker
@@ -426,6 +502,7 @@ To cite this project, you can use the following BibTeX citation.
}
```
-Colossal-AI has been accepted as official tutorial by top conferences [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), etc.
+Colossal-AI has been accepted as an official tutorial by top conferences [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/),
+[PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/), etc.
(back to top)
diff --git a/applications/Chat/.gitignore b/applications/Chat/.gitignore
index 2b9b4f345d0fae7bc3872d3c723d2698d201b8b8..5fa068105e261616c1c50f46ffaefdc7aa629b3c 100644
--- a/applications/Chat/.gitignore
+++ b/applications/Chat/.gitignore
@@ -145,4 +145,4 @@ docs/.build
# wandb log
example/wandb/
-examples/awesome-chatgpt-prompts/
\ No newline at end of file
+examples/awesome-chatgpt-prompts/
diff --git a/applications/Chat/README.md b/applications/Chat/README.md
index 9ba831973b6c912a2d8435c382e4c741dd24b6e4..d5be04ab9f44bf34aa87be1e22d36d64227ddd39 100644
--- a/applications/Chat/README.md
+++ b/applications/Chat/README.md
@@ -4,7 +4,6 @@
ColossalChat
-
## Table of Contents
- [Table of Contents](#table-of-contents)
@@ -34,7 +33,9 @@
- [Authors](#authors)
- [Citations](#citations)
- [Licenses](#licenses)
+
---
+
## What is ColossalChat and Coati ?
[ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat) is the project to implement LLM with RLHF, powered by the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) project.
@@ -42,6 +43,7 @@
Coati stands for `ColossalAI Talking Intelligence`. It is the name for the module implemented in this project and is also the name of the large language model developed by the ColossalChat project.
The Coati package provides a unified large language model framework that has implemented the following functions
+
- Supports comprehensive large-model training acceleration capabilities for ColossalAI, without requiring knowledge of complex distributed training algorithms
- Supervised datasets collection
- Supervised instructions fine-tuning
@@ -56,29 +58,42 @@ The Coati package provides a unified large language model framework that has imp
- Image source: https://openai.com/blog/chatgpt
+Image source: https://openai.com/blog/chatgpt
+
**As Colossal-AI is undergoing some major updates, this project will be actively maintained to stay in line with the Colossal-AI project.**
-
More details can be found in the latest news.
-* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b)
-* [2023/02] [Open Source Solution Replicates ChatGPT Training Process! Ready to go with only 1.6GB GPU Memory](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt)
+
+- [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b)
+- [2023/02] [Open Source Solution Replicates ChatGPT Training Process! Ready to go with only 1.6GB GPU Memory](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt)
## Online demo
-You can experience the performance of Coati7B on this page.
-[chat.colossalai.org](https://chat.colossalai.org/)
+
+
+[ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat): An open-source solution for cloning [ChatGPT](https://openai.com/blog/chatgpt/) with a complete RLHF pipeline.
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat)
+[[blog]](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b)
+[[demo]](https://www.youtube.com/watch?v=HcTiHzApHm0)
+[[tutorial]](https://www.youtube.com/watch?v=-qFBZFmOJfg)
-Due to resource constraints, we will only provide this service from 29th Mar 2023 to 5 April 2023. However, we have provided the inference code in the [inference](./inference/) folder. The WebUI will be open-sourced soon as well.
+
+
+
+
+> DeepSpeedChat performance is taken from its blog post of April 12, 2023; ColossalChat performance can be reproduced on an AWS p4d.24xlarge node with 8 A100-40G GPUs with the following command: `torchrun --standalone --nproc_per_node 8 benchmark_opt_lora_dummy.py --num_collect_steps 1 --use_kernels --strategy colossalai_zero2 --experience_batch_size 64 --train_batch_size 32`
-> Warning: Due to model and dataset size limitations, Coati is just a baby model, Coati7B may output incorrect information and lack the ability for multi-turn dialogue. There is still significant room for improvement.
## Install
### Install the environment
-```shell
+```bash
conda create -n coati
conda activate coati
git clone https://github.com/hpcaitech/ColossalAI.git
@@ -87,22 +102,20 @@ pip install .
```
### Install the Transformers
-Given Hugging Face hasn't officially supported the LLaMA models, We fork a branch of Transformers that can be compatible with our code
-```shell
-git clone https://github.com/hpcaitech/transformers
-cd transformers
-pip install .
+```bash
+pip install transformers==4.30.2
```
## How to use?
### Supervised datasets collection
-we collected 104K bilingual datasets of Chinese and English, and you can find the datasets in this repo
-[InstructionWild](https://github.com/XueFuzhao/InstructionWild)
+We collected a bilingual dataset of 104K Chinese and English entries; you can find it in the repo
+[InstructionWild](https://github.com/XueFuzhao/InstructionWild) and in this [file](https://github.com/XueFuzhao/InstructionWild/blob/main/data/README.md).
Here is how we collected the data
+
@@ -112,12 +125,28 @@ Here is how we collected the data
Stage1 is supervised instruction fine-tuning, which uses the datasets mentioned earlier to fine-tune the model.
You can run `examples/train_sft.sh` to start supervised instruction fine-tuning.
+[[Stage1 tutorial video]](https://www.youtube.com/watch?v=-qFBZFmOJfg)
+
+**Note**: the supervised dataset uses the following format:
+
+```json
+[
+ {
+ "instruction": "Provide a list of the top 10 most popular mobile games in Asia",
+ "input": "",
+ "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
+ "id": 0
+ },
+ ...
+]
+```
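+
+For a quick sanity check, such a file can be loaded with plain `json` (a minimal sketch; only the field names shown above are assumed):
+
+```python
+import json
+
+# load the supervised dataset and peek at each record
+with open("data.json") as f:
+    samples = json.load(f)
+
+for sample in samples:
+    prompt = sample["instruction"] + "\n" + sample["input"]
+    target = sample["output"]
+    print(prompt[:80], "->", target[:80])
+```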
### RLHF Training Stage2 - Training reward model
Stage2 trains a reward model: different outputs for the same prompt are manually ranked, and these rankings supervise the training of the reward model.
You can run `examples/train_rm.sh` to start reward model training.
+[[Stage2 tutorial video]](https://www.youtube.com/watch?v=gMx2CApKhuo)
### RLHF Training Stage3 - Training model with reinforcement learning by human feedback
@@ -128,6 +157,39 @@ Stage3 uses reinforcement learning algorithm, which is the most complex part of
You can run `examples/train_prompts.sh` to start training PPO with human feedback.
+[[Stage3 tutorial video]](https://www.youtube.com/watch?v=Z8wwSHxPL9g)
+
+**Note**: the required datasets use the following formats:
+
+- `pretrain dataset`
+
+ ```json
+ [
+ {
+ "instruction": "Provide a list of the top 10 most popular mobile games in Asia",
+ "input": "",
+ "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
+ "id": 0
+ },
+ ...
+ ]
+ ```
+
+- `prompt dataset`
+
+ ```json
+ [
+ {
+ "instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"",
+ "id": 0
+ },
+ {
+ "instruction": "Write a descriptive paragraph about a memorable vacation you went on",
+ "id": 1
+ },
+ ...
+ ]
+ ```
For more details, see [`examples/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples).
@@ -135,9 +197,9 @@ For more details, see [`examples/`](https://github.com/hpcaitech/ColossalAI/tree
We provide an online inference server and a benchmark. We aim to run inference on a single GPU, so quantization is essential when using large models.
-We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inference. You can
-Online inference server scripts can help you deploy your own services.
+We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inference.
+Online inference server scripts can help you deploy your own services.
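+
+To give a feel for what 8-bit RTN does, here is a minimal sketch of symmetric per-tensor round-to-nearest quantization (an illustration only, not the kernel used by the inference scripts):
+
+```python
+import torch
+
+def quantize_rtn_int8(w: torch.Tensor):
+    # pick a scale so the largest-magnitude weight maps to 127
+    scale = w.abs().max().clamp(min=1e-8) / 127.0
+    q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
+    return q, scale
+
+def dequantize_int8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+    # recover an approximation of the original weights
+    return q.to(torch.float32) * scale
+```
+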
For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
## Coati7B examples
@@ -147,6 +209,7 @@ For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tre
E-mail

+
coding
@@ -180,6 +243,7 @@ For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tre
### Open QA
+
Game

@@ -213,6 +277,7 @@ For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tre
You can find more examples in this [repo](https://github.com/XueFuzhao/InstructionWild/blob/main/comparison.md).
### Limitation
+
Limitation for LLaMA-finetuned models
- Both Alpaca and ColossalChat are based on LLaMA. It is hard to compensate for the missing knowledge in the pre-training stage.
- Lack of counting ability: Cannot count the number of items in a list.
@@ -236,7 +301,7 @@ You can find more examples in this [repo](https://github.com/XueFuzhao/Instructi
We have integrated the Transformers save and load pipeline, allowing users to freely call Hugging Face's language models and save them in the HF format.
-```
+```python
from coati.models.llama import LlamaLM
from coati.trainer import SFTTrainer
@@ -245,20 +310,20 @@ tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
(model, optim) = strategy.prepare((model, optim))
trainer = SFTTrainer(model=model,
- strategy=strategy,
- optim=optim,
- train_dataloader=train_dataloader,
- eval_dataloader=eval_dataloader,
- batch_size=args.batch_size,
- max_epochs=args.max_epochs,
- accumulation_steps = args.accumulation_steps
-)
+ strategy=strategy,
+ optim=optim,
+ train_dataloader=train_dataloader,
+ eval_dataloader=eval_dataloader,
+ batch_size=args.batch_size,
+ max_epochs=args.max_epochs,
+ accumulation_steps=args.accumulation_steps
+ )
trainer.fit()
# this saves in pytorch format
strategy.save_model(model, args.save_path, only_rank0=True)
-# this saves in HF format. ColossalAI strategy with stage-3 doesn't support this method
+# this saves in HF format
strategy.save_pretrained(model, args.save_path, only_rank0=True, tokenizer=tokenizer)
```
@@ -269,12 +334,13 @@ strategy.save_pretrained(model, args.save_path, only_rank0=True, tokenizer=token
Here are some examples that can allow you to train a 7B model on a single or multiple consumer-grade GPUs.
If you only have a single 24G GPU, you can use the following script. `batch_size`, `lora_rank` and `grad_checkpoint` are the most important parameters to successfully train the model.
-```
+
+```bash
+# [INFO]: MAX GPU MEMORY ALLOCATED: 19148.9345703125 MB
torchrun --standalone --nproc_per_node=1 train_sft.py \
--pretrain "/path/to/LLaMa-7B/" \
--model 'llama' \
- --strategy naive \
- --log_interval 10 \
+ --strategy ddp \
--save_path /path/to/Coati-7B \
--dataset /path/to/data.json \
--batch_size 1 \
@@ -287,12 +353,12 @@ torchrun --standalone --nproc_per_node=1 train_sft.py \
```
`colossalai_gemini` strategy can enable a single 24G GPU to train the whole model without using LoRA if you have sufficient CPU memory. You can use the following script.
-```
+
+```bash
torchrun --standalone --nproc_per_node=1 train_sft.py \
--pretrain "/path/to/LLaMa-7B/" \
--model 'llama' \
--strategy colossalai_gemini \
- --log_interval 10 \
--save_path /path/to/Coati-7B \
--dataset /path/to/data.json \
--batch_size 1 \
@@ -304,12 +370,12 @@ torchrun --standalone --nproc_per_node=1 train_sft.py \
```
If you have 4x32 GB GPUs, you can even train the whole 7B model using our `colossalai_zero2_cpu` strategy! The script is given as follows.
-```
+
+```bash
torchrun --standalone --nproc_per_node=4 train_sft.py \
--pretrain "/path/to/LLaMa-7B/" \
--model 'llama' \
--strategy colossalai_zero2_cpu \
- --log_interval 10 \
--save_path /path/to/Coati-7B \
--dataset /path/to/data.json \
--batch_size 1 \
@@ -319,8 +385,8 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
--max_epochs 1 \
--grad_checkpoint
```
-
+
## The Plan
@@ -335,31 +401,33 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
- [ ] support chain-of-thought by [langchain](https://github.com/hwchase17/langchain)
### Real-time progress
-You will find our progress in github project broad
-[Coati](https://github.com/orgs/hpcaitech/projects/17/views/1)
+You can follow our progress on the GitHub [project board](https://github.com/orgs/hpcaitech/projects/17/views/1).
## Invitation to open-source contribution
+
Referring to the successful attempts of [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion), any and all developers and partners with computing power, datasets, or models are welcome to join and build the Colossal-AI community, making efforts towards the era of big AI models from the starting point of replicating ChatGPT!
You may contact us or participate in the following ways:
+
1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks!
2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub following the guidelines in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md).
3. Join the Colossal-AI community on
-[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w),
-and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas.
+ [Slack](https://github.com/hpcaitech/public_assets/tree/main/colossalai/contact/slack),
+ and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas.
4. Send your official proposal to email contact@hpcaitech.com
Thanks so much to all of our amazing contributors!
## Quick Preview
+
-- An open-source low cost solution for cloning [ChatGPT](https://openai.com/blog/chatgpt/) with a complete RLHF pipeline. [[demo]](https://chat.colossalai.org)
+- An open-source low-cost solution for cloning [ChatGPT](https://openai.com/blog/chatgpt/) with a complete RLHF pipeline. [[demo]](https://chat.colossalai.org)
@@ -386,18 +454,21 @@ Thanks so much to all of our amazing contributors!
| Better Cases | 38 ⚔ **41** | **45** ⚔ 33 |
| Win Rate | 48% ⚔ **52%** | **58%** ⚔ 42% |
| Average Score | 7.06 ⚔ **7.13** | **7.31** ⚔ 6.82 |
+
- Our Coati-7B model performs better than Alpaca-7B when using GPT-4 to evaluate model performance. The Coati-7B model we evaluated is an older version trained a few weeks ago; a new version is around the corner.
## Authors
Coati is developed by ColossalAI Team:
+
- [Fazzie](https://fazzie-key.cool/about/index.html)
- [FrankLeeeee](https://github.com/FrankLeeeee)
- [BlueRum](https://github.com/ht-zhou)
- [ver217](https://github.com/ver217)
- [ofey404](https://github.com/ofey404)
+- [Wenhao Chen](https://github.com/CWHer)
-The Phd student from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project.
+The PhD students from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project.
- [Zangwei Zheng](https://github.com/zhengzangw)
- [Xue Fuzhao](https://github.com/XueFuzhao)
diff --git a/applications/Chat/benchmarks/README.md b/applications/Chat/benchmarks/README.md
index bc8ad8ba98165ebfefd842e6deea27120bdb547e..c13f3485863b9dde3044d0ab0fe4f2061544030b 100644
--- a/applications/Chat/benchmarks/README.md
+++ b/applications/Chat/benchmarks/README.md
@@ -27,9 +27,12 @@ We also provide various training strategies:
We only support `torchrun` to launch now. E.g.
-```shell
+```bash
# run OPT-125M with no lora (lora_rank=0) on single-node single-GPU with min batch size
-torchrun --standalone --nproc_per_node 1 benchmark_opt_lora_dummy.py --model 125m --critic_model 125m --strategy ddp --experience_batch_size 1 --train_batch_size 1 --lora_rank 0
+torchrun --standalone --nproc_per_node 1 benchmark_opt_lora_dummy.py \
+ --model 125m --critic_model 125m --strategy ddp \
+ --experience_batch_size 1 --train_batch_size 1 --lora_rank 0
# run Actor (OPT-1.3B) and Critic (OPT-350M) with lora_rank=4 on single-node 4-GPU
-torchrun --standalone --nproc_per_node 4 benchmark_opt_lora_dummy.py --model 1.3b --critic_model 350m --strategy colossalai_zero2 --lora_rank 4
+torchrun --standalone --nproc_per_node 4 benchmark_opt_lora_dummy.py \
+ --model 1.3b --critic_model 350m --strategy colossalai_zero2 --lora_rank 4
```
diff --git a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
index 7a47624f74d87f188994cf7ee59e5f2d1f2b0b8b..0d0e2a7d34f54642cb712835b60f823a8bff6c8e 100644
--- a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
+++ b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
@@ -8,7 +8,7 @@ from coati.models.base import RewardModel
from coati.models.opt import OPTActor, OPTCritic
from coati.trainer import PPOTrainer
from coati.trainer.callbacks import PerformanceEvaluator
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
+from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
from torch.optim import Adam
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
@@ -19,7 +19,7 @@ from colossalai.nn.optimizer import HybridAdam
def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
numel = sum(p.numel() for p in model.parameters())
- if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init:
+ if isinstance(strategy, GeminiStrategy) and strategy.shard_init:
numel *= dist.get_world_size()
return numel
@@ -27,7 +27,7 @@ def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
def preprocess_batch(samples) -> dict:
input_ids = torch.stack(samples)
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
- return {'input_ids': input_ids, 'attention_mask': attention_mask}
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
def print_rank_0(*args, **kwargs) -> None:
@@ -39,32 +39,32 @@ def print_model_numel(model_dict: dict) -> None:
B = 1024**3
M = 1024**2
K = 1024
- outputs = ''
+ outputs = ""
for name, numel in model_dict.items():
- outputs += f'{name}: '
+ outputs += f"{name}: "
if numel >= B:
- outputs += f'{numel / B:.2f} B\n'
+ outputs += f"{numel / B:.2f} B\n"
elif numel >= M:
- outputs += f'{numel / M:.2f} M\n'
+ outputs += f"{numel / M:.2f} M\n"
elif numel >= K:
- outputs += f'{numel / K:.2f} K\n'
+ outputs += f"{numel / K:.2f} K\n"
else:
- outputs += f'{numel}\n'
+ outputs += f"{numel}\n"
print_rank_0(outputs)
def get_gpt_config(model_name: str) -> OPTConfig:
model_map = {
- '125m': OPTConfig.from_pretrained('facebook/opt-125m'),
- '350m': OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
- '700m': OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
- '1.3b': OPTConfig.from_pretrained('facebook/opt-1.3b'),
- '2.7b': OPTConfig.from_pretrained('facebook/opt-2.7b'),
- '3.5b': OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
- '5.5b': OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
- '6.7b': OPTConfig.from_pretrained('facebook/opt-6.7b'),
- '10b': OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
- '13b': OPTConfig.from_pretrained('facebook/opt-13b'),
+ "125m": OPTConfig.from_pretrained("facebook/opt-125m"),
+ "350m": OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
+ "700m": OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
+ "1.3b": OPTConfig.from_pretrained("facebook/opt-1.3b"),
+ "2.7b": OPTConfig.from_pretrained("facebook/opt-2.7b"),
+ "3.5b": OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
+ "5.5b": OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
+ "6.7b": OPTConfig.from_pretrained("facebook/opt-6.7b"),
+ "10b": OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
+ "13b": OPTConfig.from_pretrained("facebook/opt-13b"),
}
try:
return model_map[model_name]
@@ -73,20 +73,20 @@ def get_gpt_config(model_name: str) -> OPTConfig:
def main(args):
- if args.strategy == 'ddp':
+ if args.strategy == "ddp":
strategy = DDPStrategy()
- elif args.strategy == 'colossalai_gemini':
- strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
- elif args.strategy == 'colossalai_gemini_cpu':
- strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
- elif args.strategy == 'colossalai_zero2':
- strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
- elif args.strategy == 'colossalai_zero2_cpu':
- strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
- elif args.strategy == 'colossalai_zero1':
- strategy = ColossalAIStrategy(stage=1, placement_policy='cuda')
- elif args.strategy == 'colossalai_zero1_cpu':
- strategy = ColossalAIStrategy(stage=1, placement_policy='cpu')
+ elif args.strategy == "colossalai_gemini":
+ strategy = GeminiStrategy(placement_policy="static",initial_scale=2**5)
+ elif args.strategy == "colossalai_gemini_cpu":
+ strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
+ elif args.strategy == "colossalai_zero2":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
+ elif args.strategy == "colossalai_zero2_cpu":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
+ elif args.strategy == "colossalai_zero1":
+ strategy = LowLevelZeroStrategy(stage=1, placement_policy="cuda")
+ elif args.strategy == "colossalai_zero1_cpu":
+ strategy = LowLevelZeroStrategy(stage=1, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
@@ -103,92 +103,106 @@ def main(args):
if args.use_kernels:
from coati.kernels import convert_to_xformer_model
- actor, critic, initial_model, reward_model = map(convert_to_xformer_model,
- (actor, critic, initial_model, reward_model))
+
+ actor, critic, initial_model, reward_model = map(
+ convert_to_xformer_model, (actor, critic, initial_model, reward_model)
+ )
actor_numel = get_model_numel(actor, strategy)
critic_numel = get_model_numel(critic, strategy)
initial_model_numel = get_model_numel(initial_model, strategy)
reward_model_numel = get_model_numel(reward_model, strategy)
- print_model_numel({
- 'Actor': actor_numel,
- 'Critic': critic_numel,
- 'Initial model': initial_model_numel,
- 'Reward model': reward_model_numel
- })
- performance_evaluator = PerformanceEvaluator(actor_numel,
- critic_numel,
- initial_model_numel,
- reward_model_numel,
- enable_grad_checkpoint=False,
- ignore_episodes=1)
-
- if args.strategy.startswith('colossalai'):
+ print_model_numel(
+ {
+ "Actor": actor_numel,
+ "Critic": critic_numel,
+ "Initial model": initial_model_numel,
+ "Reward model": reward_model_numel,
+ }
+ )
+ performance_evaluator = PerformanceEvaluator(
+ actor_numel,
+ critic_numel,
+ initial_model_numel,
+ reward_model_numel,
+ enable_grad_checkpoint=False,
+ ignore_episodes=1,
+ )
+
+ if args.strategy.startswith("colossalai"):
actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
else:
actor_optim = Adam(actor.parameters(), lr=5e-6)
critic_optim = Adam(critic.parameters(), lr=5e-6)
- tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.padding_side = "left"
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
- trainer = PPOTrainer(strategy,
- actor,
- critic,
- reward_model,
- initial_model,
- actor_optim,
- critic_optim,
- ptx_coef=0,
- max_epochs=args.max_epochs,
- train_batch_size=args.train_batch_size,
- offload_inference_models=args.offload_inference_models,
- max_length=512,
- do_sample=True,
- temperature=1.0,
- top_k=50,
- use_cache=True,
- pad_token_id=tokenizer.pad_token_id,
- eos_token_id=tokenizer.eos_token_id,
- callbacks=[performance_evaluator])
-
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
- dataloader = DataLoader(random_prompts,
- batch_size=args.experience_batch_size,
- shuffle=True,
- collate_fn=preprocess_batch)
-
- trainer.fit(dataloader,
- None,
- num_episodes=args.num_episodes,
- max_timesteps=args.max_timesteps,
- update_timesteps=args.update_timesteps)
-
- print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')
-
-
-if __name__ == '__main__':
+ dataloader = DataLoader(
+ random_prompts, batch_size=args.experience_batch_size, shuffle=True, collate_fn=preprocess_batch
+ )
+
+ trainer = PPOTrainer(
+ strategy,
+ actor,
+ critic,
+ reward_model,
+ initial_model,
+ actor_optim,
+ critic_optim,
+ tokenizer=tokenizer,
+ ptx_coef=0,
+ train_batch_size=args.train_batch_size,
+ offload_inference_models=args.offload_inference_models,
+ max_length=512,
+ do_sample=True,
+ temperature=1.0,
+ top_k=50,
+ use_cache=True,
+ callbacks=[performance_evaluator],
+ )
+
+ trainer.fit(
+ prompt_dataloader=dataloader,
+ pretrain_dataloader=None,
+ num_episodes=args.num_episodes,
+ num_update_steps=args.num_update_steps,
+ num_collect_steps=args.num_collect_steps,
+ )
+
+ print_rank_0(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB")
+
+
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--model', default='125m')
- parser.add_argument('--critic_model', default='125m')
- parser.add_argument('--strategy',
- choices=[
- 'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
- 'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu'
- ],
- default='ddp')
- parser.add_argument('--num_episodes', type=int, default=3)
- parser.add_argument('--max_timesteps', type=int, default=8)
- parser.add_argument('--update_timesteps', type=int, default=8)
- parser.add_argument('--max_epochs', type=int, default=1)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0)
- parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
- parser.add_argument('--offload_inference_models', action='store_true', default=False)
- parser.add_argument('--use_kernels', action='store_true', default=False)
+ parser.add_argument("--model", default="125m")
+ parser.add_argument("--critic_model", default="125m")
+ parser.add_argument(
+ "--strategy",
+ choices=[
+ "ddp",
+ "colossalai_gemini",
+ "colossalai_gemini_cpu",
+ "colossalai_zero2",
+ "colossalai_zero2_cpu",
+ "colossalai_zero1",
+ "colossalai_zero1_cpu",
+ ],
+ default="ddp",
+ )
+ parser.add_argument("--num_episodes", type=int, default=3)
+ parser.add_argument("--num_collect_steps", type=int, default=8)
+ parser.add_argument("--num_update_steps", type=int, default=1)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--lora_rank", type=int, default=0)
+ parser.add_argument("--cuda_mem_frac", type=float, default=1.0)
+ parser.add_argument("--offload_inference_models", action="store_true", default=False)
+ parser.add_argument("--use_kernels", action="store_true", default=False)
args = parser.parse_args()
main(args)
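
For readers migrating their own scripts, the strategy renaming above boils down to a simple mapping: the old `ColossalAIStrategy` with `stage=3` becomes `GeminiStrategy`, while stages 1 and 2 become `LowLevelZeroStrategy`. A minimal sketch, using only the constructor arguments visible in this diff:

```python
from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy

# old: ColossalAIStrategy(stage=3, placement_policy="cuda", initial_scale=2**5)
gemini = GeminiStrategy(placement_policy="static", initial_scale=2**5)

# old: ColossalAIStrategy(stage=2, placement_policy="cuda")  (stage=1 is analogous)
zero2 = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
```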
diff --git a/applications/Chat/benchmarks/ray/1mmt_dummy.py b/applications/Chat/benchmarks/ray/1mmt_dummy.py
new file mode 100644
index 0000000000000000000000000000000000000000..98ace3869450901f72948d4a5acbbc969c11a18e
--- /dev/null
+++ b/applications/Chat/benchmarks/ray/1mmt_dummy.py
@@ -0,0 +1,192 @@
+import argparse
+import os
+import socket
+from functools import partial
+
+import ray
+import torch
+from coati.quant import llama_load_quant, low_resource_init
+from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
+from coati.ray.experience_maker_holder import ExperienceMakerHolder
+from coati.ray.utils import (
+ get_actor_from_args,
+ get_critic_from_args,
+ get_receivers_per_sender,
+ get_reward_model_from_args,
+ get_strategy_from_args,
+)
+from torch.utils.data import DataLoader
+from transformers import AutoConfig, AutoTokenizer
+from transformers.modeling_utils import no_init_weights
+
+
+def get_free_port():
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("", 0))
+ return s.getsockname()[1]
+
+
+def get_local_ip():
+ with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+ s.connect(("8.8.8.8", 80))
+ return s.getsockname()[0]
+
+
+def main(args):
+ master_addr = str(get_local_ip())
+ # trainer_env_info
+ trainer_port = str(get_free_port())
+ env_info_trainers = [
+ {
+ "local_rank": "0",
+ "rank": str(rank),
+ "world_size": str(args.num_trainers),
+ "master_port": trainer_port,
+ "master_addr": master_addr,
+ }
+ for rank in range(args.num_trainers)
+ ]
+
+ # maker_env_info
+ maker_port = str(get_free_port())
+ env_info_maker = {
+ "local_rank": "0",
+ "rank": "0",
+ "world_size": "1",
+ "master_port": maker_port,
+ "master_addr": master_addr,
+ }
+
+ # configure tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
+ tokenizer.pad_token = tokenizer.eos_token
+
+ def model_fn():
+ actor_cfg = AutoConfig.from_pretrained(args.pretrain)
+ critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
+ actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
+ critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
+ reward_model = (
+ get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
+ )
+ if args.initial_model_quant_ckpt is not None and args.model == "llama":
+ # quantize initial model
+ with low_resource_init(), no_init_weights():
+ initial_model = get_actor_from_args(args.model, config=actor_cfg)
+ initial_model.model = (
+ llama_load_quant(
+ initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
+ )
+ .cuda()
+ .requires_grad_(False)
+ )
+ else:
+ initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
+ return actor, critic, reward_model, initial_model
+
+ # configure Experience Maker
+ experience_holder_ref = ExperienceMakerHolder.options(name="maker0", num_gpus=1, max_concurrency=2).remote(
+ detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)],
+ strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
+ model_fn=model_fn,
+ env_info=env_info_maker,
+ kl_coef=0.1,
+ debug=args.debug,
+ # sync_models_from_trainers=True,
+ # generation kwargs:
+ max_length=512,
+ do_sample=True,
+ temperature=1.0,
+ top_k=50,
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ eval_performance=True,
+ use_cache=True,
+ )
+
+ def trainer_model_fn():
+ actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
+ critic = (
+ get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain))
+ .half()
+ .cuda()
+ )
+ return actor, critic
+
+ # configure Trainer
+ trainer_refs = [
+ DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
+ experience_maker_holder_name_list=[
+ f"maker{x}" for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True)
+ ],
+ strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
+ model_fn=trainer_model_fn,
+ env_info=env_info_trainer,
+ train_batch_size=args.train_batch_size,
+ buffer_limit=16,
+ eval_performance=True,
+ debug=args.debug,
+ )
+ for i, env_info_trainer in enumerate(env_info_trainers)
+ ]
+
+ dataset_size = args.experience_batch_size * 4
+
+ def data_gen_fn():
+ input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
+ attn_mask = torch.ones_like(input_ids)
+ return {"input_ids": input_ids, "attention_mask": attn_mask}
+
+ def build_dataloader(size):
+ dataset = [data_gen_fn() for _ in range(size)]
+ dataloader = DataLoader(dataset, batch_size=args.experience_batch_size)
+ return dataloader
+
+    # uncomment this call if sync_models_from_trainers is True
+ # ray.get([
+ # trainer_ref.sync_models_to_remote_makers.remote()
+ # for trainer_ref in trainer_refs
+ # ])
+
+ wait_tasks = []
+
+ wait_tasks.append(
+ experience_holder_ref.workingloop.remote(
+ partial(build_dataloader, dataset_size), num_steps=args.experience_steps
+ )
+ )
+
+ total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size)
+ for trainer_ref in trainer_refs:
+ wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
+
+ ray.get(wait_tasks)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--num_trainers", type=int, default=1)
+ parser.add_argument(
+ "--trainer_strategy",
+ choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
+ default="ddp",
+ )
+ parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--critic_pretrain", type=str, default=None)
+ parser.add_argument("--experience_steps", type=int, default=4)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--train_epochs", type=int, default=1)
+ parser.add_argument("--update_steps", type=int, default=2)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+
+ parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
+ parser.add_argument("--quant_bits", type=int, default=4)
+ parser.add_argument("--quant_group_size", type=int, default=128)
+ parser.add_argument("--debug", action="store_true")
+ args = parser.parse_args()
+ ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
+ main(args)
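
The number of trainer steps is derived from the experience throughput: each experience step yields `experience_batch_size` samples, which the trainers consume in batches of `train_batch_size`. With the default arguments above, the arithmetic works out as follows:

```python
# defaults from the argument parser above
experience_batch_size, experience_steps = 8, 4
num_trainers, train_batch_size = 1, 8
total_steps = experience_batch_size * experience_steps // (num_trainers * train_batch_size)
assert total_steps == 4  # the single trainer consumes all 32 generated samples in batches of 8
```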
diff --git a/applications/Chat/benchmarks/ray/mmmt_dummy.py b/applications/Chat/benchmarks/ray/mmmt_dummy.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8860f2979ee0fdf2f677d2e4a80d6350c7d80b8
--- /dev/null
+++ b/applications/Chat/benchmarks/ray/mmmt_dummy.py
@@ -0,0 +1,209 @@
+import argparse
+import os
+import socket
+from functools import partial
+
+import ray
+import torch
+from coati.quant import llama_load_quant, low_resource_init
+from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
+from coati.ray.experience_maker_holder import ExperienceMakerHolder
+from coati.ray.utils import (
+ get_actor_from_args,
+ get_critic_from_args,
+ get_receivers_per_sender,
+ get_reward_model_from_args,
+ get_strategy_from_args,
+)
+from torch.utils.data import DataLoader
+from transformers import AutoConfig, AutoTokenizer
+from transformers.modeling_utils import no_init_weights
+
+
+def get_free_port():
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("", 0))
+ return s.getsockname()[1]
+
+
+def get_local_ip():
+ with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+ s.connect(("8.8.8.8", 80))
+ return s.getsockname()[0]
+
+
+def main(args):
+ master_addr = str(get_local_ip())
+ # trainer_env_info
+ trainer_port = str(get_free_port())
+ env_info_trainers = [
+ {
+ "local_rank": "0",
+ "rank": str(rank),
+ "world_size": str(args.num_trainers),
+ "master_port": trainer_port,
+ "master_addr": master_addr,
+ }
+ for rank in range(args.num_trainers)
+ ]
+
+ # maker_env_info
+ maker_port = str(get_free_port())
+ env_info_makers = [
+ {
+ "local_rank": "0",
+ "rank": str(rank),
+ "world_size": str(args.num_makers),
+ "master_port": maker_port,
+ "master_addr": master_addr,
+ }
+ for rank in range(args.num_makers)
+ ]
+
+ # configure tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
+ tokenizer.pad_token = tokenizer.eos_token
+
+ def model_fn():
+ actor_cfg = AutoConfig.from_pretrained(args.pretrain)
+ critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
+ actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
+ critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
+ reward_model = (
+ get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
+ )
+ if args.initial_model_quant_ckpt is not None and args.model == "llama":
+ # quantize initial model
+ with low_resource_init(), no_init_weights():
+ initial_model = get_actor_from_args(args.model, config=actor_cfg)
+ initial_model.model = (
+ llama_load_quant(
+ initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
+ )
+ .cuda()
+ .requires_grad_(False)
+ )
+ else:
+ initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
+ return actor, critic, reward_model, initial_model
+
+ # configure Experience Maker
+ experience_holder_refs = [
+ ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
+ detached_trainer_name_list=[
+ f"trainer{x}"
+ for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
+ ],
+ strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
+ model_fn=model_fn,
+ env_info=env_info_maker,
+ kl_coef=0.1,
+ debug=args.debug,
+ # sync_models_from_trainers=True,
+ # generation kwargs:
+ max_length=512,
+ do_sample=True,
+ temperature=1.0,
+ top_k=50,
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ eval_performance=True,
+ use_cache=True,
+ )
+ for i, env_info_maker in enumerate(env_info_makers)
+ ]
+
+ def trainer_model_fn():
+ actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
+ critic = (
+ get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain))
+ .half()
+ .cuda()
+ )
+ return actor, critic
+
+ # configure Trainer
+ trainer_refs = [
+ DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
+ experience_maker_holder_name_list=[
+ f"maker{x}"
+ for x in get_receivers_per_sender(i, args.num_trainers, args.num_makers, allow_idle_sender=True)
+ ],
+ strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
+ model_fn=trainer_model_fn,
+ env_info=env_info_trainer,
+ train_batch_size=args.train_batch_size,
+ buffer_limit=16,
+ eval_performance=True,
+ debug=args.debug,
+ )
+ for i, env_info_trainer in enumerate(env_info_trainers)
+ ]
+
+ dataset_size = args.experience_batch_size * 4
+
+ def data_gen_fn():
+ input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
+ attn_mask = torch.ones_like(input_ids)
+ return {"input_ids": input_ids, "attention_mask": attn_mask}
+
+ def build_dataloader(size):
+ dataset = [data_gen_fn() for _ in range(size)]
+ dataloader = DataLoader(dataset, batch_size=args.experience_batch_size)
+ return dataloader
+
+    # uncomment this call if sync_models_from_trainers is True
+ # ray.get([
+ # trainer_ref.sync_models_to_remote_makers.remote()
+ # for trainer_ref in trainer_refs
+ # ])
+
+ wait_tasks = []
+
+ for experience_holder_ref in experience_holder_refs:
+ wait_tasks.append(
+ experience_holder_ref.workingloop.remote(
+ partial(build_dataloader, dataset_size), num_steps=args.experience_steps
+ )
+ )
+
+ total_steps = (
+ args.experience_batch_size
+ * args.experience_steps
+ * args.num_makers
+ // (args.num_trainers * args.train_batch_size)
+ )
+ for trainer_ref in trainer_refs:
+ wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
+
+ ray.get(wait_tasks)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--num_makers", type=int, default=1)
+ parser.add_argument("--num_trainers", type=int, default=1)
+ parser.add_argument(
+ "--trainer_strategy",
+ choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
+ default="ddp",
+ )
+ parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--critic_pretrain", type=str, default=None)
+ parser.add_argument("--experience_steps", type=int, default=4)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--train_epochs", type=int, default=1)
+ parser.add_argument("--update_steps", type=int, default=2)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+
+ parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
+ parser.add_argument("--quant_bits", type=int, default=4)
+ parser.add_argument("--quant_group_size", type=int, default=128)
+ parser.add_argument("--debug", action="store_true")
+ args = parser.parse_args()
+ ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
+ main(args)
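
The multi-maker variant scales the same bookkeeping by `num_makers`, since every maker produces its own stream of experience. Note that both ray benchmark scripts read the `RAY_NAMESPACE` environment variable at startup, so it must be set before launch. A worked example:

```python
# e.g. two makers feeding two trainers, other arguments left at their defaults
experience_batch_size, experience_steps = 8, 4
num_makers, num_trainers, train_batch_size = 2, 2, 8
total_steps = experience_batch_size * experience_steps * num_makers // (num_trainers * train_batch_size)
assert total_steps == 4  # 64 generated samples, split evenly across the two trainers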
diff --git a/applications/Chat/coati/dataset/__init__.py b/applications/Chat/coati/dataset/__init__.py
index f650668e90b0feddbf6b20f1e0fc084a61ac4fc5..599b5760977599aa6c9d5119c647ce277cb929d8 100644
--- a/applications/Chat/coati/dataset/__init__.py
+++ b/applications/Chat/coati/dataset/__init__.py
@@ -1,9 +1,13 @@
from .prompt_dataset import PromptDataset
from .reward_dataset import HhRlhfDataset, RmStaticDataset
-from .sft_dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
+from .sft_dataset import SFTDataset, SupervisedDataset
from .utils import is_rank_0
__all__ = [
- 'RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset', 'SupervisedDataset',
- 'DataCollatorForSupervisedDataset', 'PromptDataset'
+ "RmStaticDataset",
+ "HhRlhfDataset",
+ "SFTDataset",
+ "SupervisedDataset",
+ "PromptDataset",
+ "is_rank_0",
]
diff --git a/applications/Chat/coati/dataset/conversation.py b/applications/Chat/coati/dataset/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2180d96b0d344f4f45e19c1ac85aca3e9bb8691
--- /dev/null
+++ b/applications/Chat/coati/dataset/conversation.py
@@ -0,0 +1,89 @@
+# Copyright 2023 lm-sys@FastChat
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+from enum import Enum, auto
+from typing import List
+
+
+class SeparatorStyle(Enum):
+ ADD_EOS_TOKEN = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ sep_style: SeparatorStyle = SeparatorStyle.ADD_EOS_TOKEN
+ sep: str = ""
+
+ skip_next: bool = False
+
+ def get_prompt(self):
+ if self.sep_style == SeparatorStyle.ADD_EOS_TOKEN:
+ ret = self.system
+ for role, message in self.messages:
+ if message:
+ ret += role + ": " + message + self.sep
+ else:
+ ret += role + ": "
+ return ret
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def to_gradio_chatbot(self):
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
+ if i % 2 == 0:
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ )
+
+ def dict(self):
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep": self.sep,
+ }
+
+
+conv = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ roles=("Human", "Assistant"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.ADD_EOS_TOKEN,
+ sep="",
+)
+
+default_conversation = conv
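
A minimal usage sketch of the `Conversation` helper defined above. With the default `sep=""` consecutive turns are joined directly; the eos token implied by `SeparatorStyle.ADD_EOS_TOKEN` is expected to come from the tokenizer:

```python
from coati.dataset.conversation import default_conversation

conv = default_conversation.copy()  # copy() turns the default message tuple into a mutable list
conv.append_message(conv.roles[0], "What is ColossalAI?")
conv.append_message(conv.roles[1], None)  # empty slot so the prompt ends with "Assistant: "
print(conv.get_prompt())
# A chat between a curious human and an artificial intelligence assistant. ...
# Human: What is ColossalAI?Assistant:
```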
diff --git a/applications/Chat/coati/dataset/prompt_dataset.py b/applications/Chat/coati/dataset/prompt_dataset.py
index f8ab2346c4b79c31036c0f5904a3344d6729149e..17120e6064b5da4ed0c507f7db01b871c27b3cc4 100644
--- a/applications/Chat/coati/dataset/prompt_dataset.py
+++ b/applications/Chat/coati/dataset/prompt_dataset.py
@@ -1,51 +1,45 @@
-import copy
-import random
from collections import defaultdict
-from dataclasses import dataclass, field
-from typing import Callable, Dict, Sequence
+from typing import Dict
import torch
-import torch.distributed as dist
import transformers
from torch.utils.data import Dataset
-from tqdm import tqdm
from colossalai.logging import get_dist_logger
-from .utils import is_rank_0, jload
-
-logger = get_dist_logger()
+from .utils import jload
class PromptDataset(Dataset):
"""Dataset for supervised fine-tuning."""
- def __init__(self,
- data_path: str,
- tokenizer: transformers.PreTrainedTokenizer,
- max_datasets_size: int = None,
- max_length: int = 96):
+ def __init__(
+ self,
+ data_path: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ max_datasets_size: int = None,
+ max_length: int = 96,
+ ):
super(PromptDataset, self).__init__()
self.keyed_prompt = defaultdict(list)
- logger.info("Loading data...")
+ self.logger = get_dist_logger()
+ self.logger.info("Loading data...")
list_data_dict = jload(data_path)
- logger.info(f"Loaded {len(list_data_dict)} examples.")
+ self.logger.info(f"Loaded {len(list_data_dict)} examples.")
if max_datasets_size is not None:
- logger.info(f"Limiting dataset to {max_datasets_size} examples.")
+ self.logger.info(f"Limiting dataset to {max_datasets_size} examples.")
list_data_dict = list_data_dict[:max_datasets_size]
- for data_dict in list_data_dict:
- token = tokenizer(data_dict["instruction"],
- return_tensors='pt',
- max_length=max_length,
- padding='max_length',
- truncation=True)
- for k, tensor in token.items():
- self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind())
+ instructions = [data_dict["instruction"] for data_dict in list_data_dict]
+ tokens = tokenizer(
+ instructions, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True
+ )
+ for k, tensor in tokens.items():
+ self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind()
def __len__(self):
- return len(self.keyed_prompt)
+ return len(self.keyed_prompt["input_ids"])
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return {k: v[i] for k, v in self.keyed_prompt.items()}
diff --git a/applications/Chat/coati/dataset/reward_dataset.py b/applications/Chat/coati/dataset/reward_dataset.py
index faa1c94d27286a7c69a62c7da699a3377fe14cd9..3afcd7b692380259a77eb2ec5938b2b5ea9b1c83 100644
--- a/applications/Chat/coati/dataset/reward_dataset.py
+++ b/applications/Chat/coati/dataset/reward_dataset.py
@@ -6,7 +6,7 @@ from tqdm import tqdm
from .utils import is_rank_0
-# Dahaos/rm-static
+# Dahoas/rm-static
class RmStaticDataset(Dataset):
"""
Dataset for reward model
@@ -20,44 +20,31 @@ class RmStaticDataset(Dataset):
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
- self.chosen = []
- self.reject = []
- if special_token is None:
- self.end_token = tokenizer.eos_token
- else:
- self.end_token = special_token
- for data in tqdm(dataset, disable=not is_rank_0()):
- prompt = data['prompt']
-
- chosen = prompt + data['chosen'] + self.end_token
- chosen_token = tokenizer(chosen,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.chosen.append({
- "input_ids": chosen_token['input_ids'],
- "attention_mask": chosen_token['attention_mask']
- })
-
- reject = prompt + data['rejected'] + self.end_token
- reject_token = tokenizer(reject,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.reject.append({
- "input_ids": reject_token['input_ids'],
- "attention_mask": reject_token['attention_mask']
- })
+ self.end_token = tokenizer.eos_token if special_token is None else special_token
+
+ chosen = [data["prompt"] + data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
+ chosen_token = tokenizer(
+ chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
+
+ reject = [data["prompt"] + data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
+ reject_token = tokenizer(
+ reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
def __len__(self):
- length = len(self.chosen)
+ length = self.chosen["input_ids"].shape[0]
return length
def __getitem__(self, idx):
- return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
- "input_ids"], self.reject[idx]["attention_mask"]
+ return (
+ self.chosen["input_ids"][idx],
+ self.chosen["attention_mask"][idx],
+ self.reject["input_ids"][idx],
+ self.reject["attention_mask"][idx],
+ )
# Anthropic/hh-rlhf
@@ -74,39 +61,28 @@ class HhRlhfDataset(Dataset):
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
- self.chosen = []
- self.reject = []
- if special_token is None:
- self.end_token = tokenizer.eos_token
- else:
- self.end_token = special_token
- for data in tqdm(dataset, disable=not is_rank_0()):
- chosen = data['chosen'] + self.end_token
- chosen_token = tokenizer(chosen,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.chosen.append({
- "input_ids": chosen_token['input_ids'],
- "attention_mask": chosen_token['attention_mask']
- })
-
- reject = data['rejected'] + self.end_token
- reject_token = tokenizer(reject,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.reject.append({
- "input_ids": reject_token['input_ids'],
- "attention_mask": reject_token['attention_mask']
- })
+ self.end_token = tokenizer.eos_token if special_token is None else special_token
+
+ chosen = [data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
+ chosen_token = tokenizer(
+ chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
+
+ reject = [data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
+ reject_token = tokenizer(
+ reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
def __len__(self):
- length = len(self.chosen)
+ length = self.chosen["input_ids"].shape[0]
return length
def __getitem__(self, idx):
- return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
- "input_ids"], self.reject[idx]["attention_mask"]
+ return (
+ self.chosen["input_ids"][idx],
+ self.chosen["attention_mask"][idx],
+ self.reject["input_ids"][idx],
+ self.reject["attention_mask"][idx],
+ )
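
Since both reward datasets now store batched tensors and return plain 4-tuples, they can be consumed directly by a `DataLoader` without a custom collator. A sketch, assuming `data` is a pairwise preference split:

```python
from torch.utils.data import DataLoader

# `data` is assumed to be an iterable of dicts with "prompt", "chosen" and "rejected" fields,
# e.g. the Dahoas/rm-static train split
dataset = RmStaticDataset(data, tokenizer, max_length=512)
loader = DataLoader(dataset, batch_size=4)
for chosen_ids, chosen_mask, reject_ids, reject_mask in loader:
    ...  # score both halves with the reward model and apply a pairwise loss
```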
diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py
index 3e2453468bbc4d5be00e9c2964803299bf176004..c0e257f54a0782711faa2142d0e5f29a2c2a9568 100644
--- a/applications/Chat/coati/dataset/sft_dataset.py
+++ b/applications/Chat/coati/dataset/sft_dataset.py
@@ -13,15 +13,13 @@
# limitations under the License.
import copy
-import random
-from dataclasses import dataclass, field
-from typing import Callable, Dict, Sequence
+from typing import Dict, Optional, Sequence, Tuple
import torch
-import torch.distributed as dist
-import transformers
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from torch.utils.data import Dataset
from tqdm import tqdm
+from transformers import PreTrainedTokenizer
from colossalai.logging import get_dist_logger
@@ -31,16 +29,89 @@ logger = get_dist_logger()
IGNORE_INDEX = -100
PROMPT_DICT = {
- "prompt_input":
- ("Below is an instruction that describes a task, paired with an input that provides further context. "
- "Write a response that appropriately completes the request.\n\n"
- "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
- "prompt_no_input": ("Below is an instruction that describes a task. "
- "Write a response that appropriately completes the request.\n\n"
- "### Instruction:\n{instruction}\n\n### Response:"),
+ "prompt_input": (
+ "Below is an instruction that describes a task, paired with an input that provides further context. "
+ "Write a response that appropriately completes the request.\n\n"
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+ ),
+ "prompt_no_input": (
+ "Below is an instruction that describes a task. "
+ "Write a response that appropriately completes the request.\n\n"
+ "### Instruction:\n{instruction}\n\n### Response:"
+ ),
}
+def _preprocess(
+ sources: Sequence[str],
+ targets: Sequence[str],
+ tokenizer: PreTrainedTokenizer,
+ max_length: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Preprocess the data by tokenizing."""
+ sequences = [s + t for s, t in zip(sources, targets)]
+ sequences_token = tokenizer(
+ sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ sources_token = tokenizer(
+ sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+
+ assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently"
+ labels = copy.deepcopy(sequences_token["input_ids"])
+ for i in range(labels.shape[0]):
+ source_len = sources_token["attention_mask"][i].sum().item()
+ pad_len = max_length - sequences_token["attention_mask"][i].sum().item()
+ if tokenizer.padding_side == "right":
+ # |prompt|completion|eos|pad|
+ labels[i][:source_len] = IGNORE_INDEX
+ labels[i][-pad_len:] = IGNORE_INDEX
+ elif tokenizer.padding_side == "left":
+ # |pad|prompt|completion|eos|
+ labels[i][: pad_len + source_len] = IGNORE_INDEX
+ else:
+            raise RuntimeError(f"Unsupported padding side: {tokenizer.padding_side}")
+
+ return sequences_token["input_ids"], labels, sequences_token["attention_mask"]
+
+
+def _preprocess_chatglm(
+ sources: Sequence[str],
+ targets: Sequence[str],
+ tokenizer: PreTrainedTokenizer,
+ max_length: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Preprocess the data by tokenizing.
+    The attention mask is returned as None; ChatGLM computes it from the input ids.
+ """
+
+ labels = []
+ input_ids = []
+ for source, target in zip(sources, targets):
+ source_id = tokenizer.encode(text=source, add_special_tokens=False)
+ target_id = tokenizer.encode(text=target, add_special_tokens=False)
+ input_id = tokenizer.build_inputs_with_special_tokens(source_id, target_id)
+ # truncate
+ sp_token_list = [tokenizer.gmask_token_id, tokenizer.bos_token_id]
+ truncate_length = max(0, len(input_id) - max_length)
+ input_id = input_id[truncate_length:]
+ if truncate_length == len(source_id) + 1:
+ input_id = sp_token_list + input_id[1:]
+ elif truncate_length > len(source_id) + 1:
+ input_id = sp_token_list + input_id[2:]
+
+ context_length = input_id.index(tokenizer.bos_token_id)
+ mask_position = context_length - 1
+ label = [IGNORE_INDEX] * context_length + input_id[mask_position + 1 :]
+
+ pad_len = max_length - len(input_id)
+ input_id = input_id + [tokenizer.pad_token_id] * pad_len
+ input_ids.append(input_id)
+ labels.append(label + [IGNORE_INDEX] * pad_len)
+ return torch.tensor(input_ids), torch.tensor(labels), None
+
+
class SFTDataset(Dataset):
"""
Dataset for sft model
@@ -51,73 +122,45 @@ class SFTDataset(Dataset):
max_length: max length of input
"""
- def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
+ def __init__(self, dataset: Dict, tokenizer: PreTrainedTokenizer, max_length: int = 512) -> None:
super().__init__()
self.input_ids = []
- for data in tqdm(dataset, disable=not is_rank_0()):
- prompt = data['prompt'] + data['completion'] + tokenizer.eos_token
- prompt_token = tokenizer(prompt,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
+ sources = [data["prompt"] for data in dataset]
+ targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())]
- self.input_ids.append(prompt_token['input_ids'][0])
- self.labels = copy.deepcopy(self.input_ids)
+ logger.info("Tokenizing inputs... This may take some time...")
+ if isinstance(tokenizer, ChatGLMTokenizer):
+ self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
+ sources, targets, tokenizer, max_length
+ )
+ else:
+ self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
+
+ logger.info("Loaded dataset.")
def __len__(self):
- length = len(self.input_ids)
+ length = self.input_ids.shape[0]
return length
def __getitem__(self, idx):
- return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
-
-
-def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict:
- """Tokenize a list of strings."""
- tokenized_list = [
- tokenizer(
- text,
- return_tensors="pt",
- padding="longest",
- max_length=max_length,
- truncation=True,
- ) for text in strings
- ]
- input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
- input_ids_lens = labels_lens = [
- tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
- ]
- return dict(
- input_ids=input_ids,
- labels=labels,
- input_ids_lens=input_ids_lens,
- labels_lens=labels_lens,
- )
-
-
-def preprocess(
- sources: Sequence[str],
- targets: Sequence[str],
- tokenizer: transformers.PreTrainedTokenizer,
- max_length: int,
-) -> Dict:
- """Preprocess the data by tokenizing."""
- examples = [s + t for s, t in zip(sources, targets)]
- examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)]
- input_ids = examples_tokenized["input_ids"]
- labels = copy.deepcopy(input_ids)
- for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
- label[:source_len] = IGNORE_INDEX
- return dict(input_ids=input_ids, labels=labels)
+ if self.attention_mask is not None:
+ return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
+ else:
+ return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
- def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512):
- super(SupervisedDataset, self).__init__()
+ def __init__(
+ self,
+ data_path: str,
+ tokenizer: PreTrainedTokenizer,
+ max_datasets_size: Optional[int] = None,
+ max_length: int = 512,
+ ):
+ super().__init__()
logger.info("Loading data...")
list_data_dict = jload(data_path)
logger.info(f"Loaded {len(list_data_dict)} examples.")
@@ -129,38 +172,27 @@ class SupervisedDataset(Dataset):
logger.info("Formatting inputs...")
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
sources = [
- prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
+ prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example)
for example in list_data_dict
]
- targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
+ targets = [example["output"] + tokenizer.eos_token for example in list_data_dict]
logger.info("Tokenizing inputs... This may take some time...")
- data_dict = preprocess(sources, targets, tokenizer, max_length)
+ if isinstance(tokenizer, ChatGLMTokenizer):
+ self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
+ sources, targets, tokenizer, max_length
+ )
+ else:
+ self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
- self.input_ids = data_dict["input_ids"]
- self.labels = data_dict["labels"]
+ logger.info("Loaded dataset.")
def __len__(self):
- return len(self.input_ids)
-
- def __getitem__(self, i) -> Dict[str, torch.Tensor]:
- return dict(input_ids=self.input_ids[i], labels=self.labels[i])
-
-
-@dataclass
-class DataCollatorForSupervisedDataset(object):
- """Collate examples for supervised fine-tuning."""
-
- tokenizer: transformers.PreTrainedTokenizer
+ length = self.input_ids.shape[0]
+ return length
- def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
- input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
- input_ids = torch.nn.utils.rnn.pad_sequence(input_ids,
- batch_first=True,
- padding_value=self.tokenizer.pad_token_id)
- labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
- return dict(
- input_ids=input_ids,
- labels=labels,
- attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
- )
+ def __getitem__(self, idx):
+ if self.attention_mask is not None:
+ return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
+ else:
+ return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
diff --git a/applications/Chat/coati/experience_buffer/__init__.py b/applications/Chat/coati/experience_buffer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2a48d0a3b205a771ab6a3afddfd92e9324ba844
--- /dev/null
+++ b/applications/Chat/coati/experience_buffer/__init__.py
@@ -0,0 +1,4 @@
+from .base import ExperienceBuffer
+from .naive import NaiveExperienceBuffer
+
+__all__ = ["ExperienceBuffer", "NaiveExperienceBuffer"]
diff --git a/applications/Chat/coati/experience_buffer/base.py b/applications/Chat/coati/experience_buffer/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..7047785308f31f5835c7d76bc7d5d04265523e13
--- /dev/null
+++ b/applications/Chat/coati/experience_buffer/base.py
@@ -0,0 +1,43 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+from coati.experience_maker.base import Experience
+
+
+class ExperienceBuffer(ABC):
+ """Experience buffer base class. It stores experience.
+
+ Args:
+ sample_batch_size (int): Batch size when sampling.
+ limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
+ """
+
+ def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
+ super().__init__()
+ self.sample_batch_size = sample_batch_size
+ # limit <= 0 means unlimited
+ self.limit = limit
+
+ @abstractmethod
+ def append(self, experience: Experience) -> None:
+ pass
+
+ @abstractmethod
+ def clear(self) -> None:
+ pass
+
+ @abstractmethod
+ def sample(self) -> Experience:
+ pass
+
+ @abstractmethod
+ def __len__(self) -> int:
+ pass
+
+ @abstractmethod
+ def __getitem__(self, idx: int) -> Any:
+ pass
+
+ @abstractmethod
+ def collate_fn(self, batch: Any) -> Experience:
+ pass
diff --git a/applications/Chat/coati/experience_buffer/naive.py b/applications/Chat/coati/experience_buffer/naive.py
new file mode 100644
index 0000000000000000000000000000000000000000..d47b67dbe7131df8ff4ba5202cfd8dbcf70eb1ad
--- /dev/null
+++ b/applications/Chat/coati/experience_buffer/naive.py
@@ -0,0 +1,60 @@
+import random
+import warnings
+from typing import List
+
+import torch
+from coati.experience_maker.base import Experience
+
+from .base import ExperienceBuffer
+from .utils import BufferItem, make_experience_batch, split_experience_batch
+
+
+class NaiveExperienceBuffer(ExperienceBuffer):
+ """Naive experience buffer class. It stores experience.
+
+ Args:
+ sample_batch_size (int): Batch size when sampling.
+ limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
+        cpu_offload (bool, optional): Whether to offload appended experience to CPU; samples are moved back to GPU. Defaults to True.
+ """
+
+ def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None:
+ super().__init__(sample_batch_size, limit)
+ self.cpu_offload = cpu_offload
+ self.target_device = torch.device(f"cuda:{torch.cuda.current_device()}")
+ # TODO(ver217): add prefetch
+ self.items: List[BufferItem] = []
+
+ @torch.no_grad()
+ def append(self, experience: Experience) -> None:
+ if self.cpu_offload:
+ experience.to_device(torch.device("cpu"))
+ items = split_experience_batch(experience)
+ self.items.extend(items)
+
+ if self.limit > 0:
+ samples_to_remove = len(self.items) - self.limit
+ if samples_to_remove > 0:
+ warnings.warn(f"Experience buffer is full. Removing {samples_to_remove} samples.")
+ self.items = self.items[samples_to_remove:]
+
+ def clear(self) -> None:
+ self.items.clear()
+
+ @torch.no_grad()
+ def sample(self) -> Experience:
+ items = random.sample(self.items, self.sample_batch_size)
+ experience = make_experience_batch(items)
+ if self.cpu_offload:
+ experience.to_device(self.target_device)
+ return experience
+
+ def __len__(self) -> int:
+ return len(self.items)
+
+ def __getitem__(self, idx: int) -> BufferItem:
+ return self.items[idx]
+
+ def collate_fn(self, batch) -> Experience:
+ experience = make_experience_batch(batch)
+ return experience
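
A usage sketch of the buffer; the `experience` object is assumed to come from an experience maker (see `coati.experience_maker`):

```python
from coati.experience_buffer import NaiveExperienceBuffer

buffer = NaiveExperienceBuffer(sample_batch_size=4, limit=16, cpu_offload=True)
buffer.append(experience)  # offloaded to CPU and split into per-sample BufferItems
batch = buffer.sample()    # 4 random items, re-collated and moved back to the current GPU
```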
diff --git a/applications/Chat/coati/experience_buffer/utils.py b/applications/Chat/coati/experience_buffer/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..baedbebd184f7ee5dcf335ded07ac6b50b3cec63
--- /dev/null
+++ b/applications/Chat/coati/experience_buffer/utils.py
@@ -0,0 +1,74 @@
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+from coati.experience_maker.base import Experience
+
+
+@dataclass
+class BufferItem:
+ """BufferItem is an item of experience data.
+
+ Shapes of each tensor:
+ sequences: (S)
+ action_log_probs: (A)
+ values: (1)
+ reward: (1)
+ advantages: (1)
+ attention_mask: (S)
+ action_mask: (A)
+
+ "A" is the number of actions.
+ """
+
+ sequences: torch.Tensor
+ action_log_probs: torch.Tensor
+ values: torch.Tensor
+ reward: torch.Tensor
+ advantages: torch.Tensor
+ attention_mask: Optional[torch.LongTensor]
+ action_mask: Optional[torch.BoolTensor]
+
+
+def split_experience_batch(experience: Experience) -> List[BufferItem]:
+ batch_size = experience.sequences.size(0)
+ batch_kwargs = [{} for _ in range(batch_size)]
+ keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
+ for key in keys:
+ value = getattr(experience, key)
+ if isinstance(value, torch.Tensor):
+ vals = torch.unbind(value)
+ else:
+ # None
+ vals = [value for _ in range(batch_size)]
+ assert batch_size == len(vals)
+ for i, v in enumerate(vals):
+ batch_kwargs[i][key] = v
+ items = [BufferItem(**kwargs) for kwargs in batch_kwargs]
+ return items
+
+
+def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> torch.Tensor:
+ assert side in ("left", "right")
+ max_len = max(seq.size(0) for seq in sequences)
+ padded_sequences = []
+ for seq in sequences:
+ pad_len = max_len - seq.size(0)
+ padding = (pad_len, 0) if side == "left" else (0, pad_len)
+ padded_sequences.append(F.pad(seq, padding))
+ return torch.stack(padded_sequences, dim=0)
+
+
+def make_experience_batch(items: List[BufferItem]) -> Experience:
+ kwargs = {}
+    to_pad_keys = {"action_log_probs", "action_mask"}
+ keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
+ for key in keys:
+ vals = [getattr(item, key) for item in items]
+ if key in to_pad_keys:
+ batch_data = _zero_pad_sequences(vals)
+ else:
+ batch_data = torch.stack(vals, dim=0)
+ kwargs[key] = batch_data
+ return Experience(**kwargs)
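
Note that only `action_log_probs` and `action_mask` are re-padded when collating; `sequences` are stacked directly with `torch.stack`, so items batched together must already share a sequence length. A toy example of the left-padding helper:

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
padded = _zero_pad_sequences([a, b], side="left")
# tensor([[1, 2, 3],
#         [0, 4, 5]])
```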
diff --git a/applications/Chat/coati/experience_maker/__init__.py b/applications/Chat/coati/experience_maker/__init__.py
index 39ca7576b22761807b5804cc15d82e15fdde42cd..06452292e77c4a6e32ec5ed2a7c996e2b7fa2b64 100644
--- a/applications/Chat/coati/experience_maker/__init__.py
+++ b/applications/Chat/coati/experience_maker/__init__.py
@@ -1,4 +1,4 @@
from .base import Experience, ExperienceMaker
from .naive import NaiveExperienceMaker
-__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker']
+__all__ = ["Experience", "ExperienceMaker", "NaiveExperienceMaker"]
diff --git a/applications/Chat/coati/experience_maker/base.py b/applications/Chat/coati/experience_maker/base.py
index ff75852576c848625d8786298a46f59e6a395598..0731f6e0f97f3cb6efb7a25621763ce7350d6ab6 100644
--- a/applications/Chat/coati/experience_maker/base.py
+++ b/applications/Chat/coati/experience_maker/base.py
@@ -3,14 +3,13 @@ from dataclasses import dataclass
from typing import Optional
import torch
-import torch.nn as nn
-from coati.models.base import Actor
+from coati.models.base import Actor, Critic, RewardModel
@dataclass
class Experience:
"""Experience is a batch of data.
- These data should have the the sequence length and number of actions.
+ These data should have the sequence length and number of actions.
Left padding for sequences is applied.
Shapes of each tensor:
@@ -24,6 +23,7 @@ class Experience:
"A" is the number of actions.
"""
+
sequences: torch.Tensor
action_log_probs: torch.Tensor
values: torch.Tensor
@@ -58,20 +58,13 @@ class Experience:
class ExperienceMaker(ABC):
-
- def __init__(self,
- actor: Actor,
- critic: nn.Module,
- reward_model: nn.Module,
- initial_model: Actor,
- kl_coef: float = 0.1) -> None:
+ def __init__(self, actor: Actor, critic: Critic, reward_model: RewardModel, initial_model: Actor) -> None:
super().__init__()
self.actor = actor
self.critic = critic
self.reward_model = reward_model
self.initial_model = initial_model
- self.kl_coef = kl_coef
@abstractmethod
- def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
+ def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience:
pass
diff --git a/applications/Chat/coati/experience_maker/naive.py b/applications/Chat/coati/experience_maker/naive.py
index 94546eeb28e787df620d15ec631b1776dc6bf282..941e1994b148ffa45b57287ee560e138f0d90923 100644
--- a/applications/Chat/coati/experience_maker/naive.py
+++ b/applications/Chat/coati/experience_maker/naive.py
@@ -1,5 +1,9 @@
import torch
-from coati.models.utils import compute_reward, normalize
+import torch.nn.functional as F
+from coati.models.base import Actor, Critic, RewardModel
+from coati.models.generation import generate
+from coati.models.utils import calc_action_log_probs, compute_reward
+from transformers import PreTrainedTokenizer
from .base import Experience, ExperienceMaker
@@ -9,6 +13,19 @@ class NaiveExperienceMaker(ExperienceMaker):
Naive experience maker.
"""
+ def __init__(
+ self,
+ actor: Actor,
+ critic: Critic,
+ reward_model: RewardModel,
+ initial_model: Actor,
+ tokenizer: PreTrainedTokenizer,
+ kl_coef: float = 0.1,
+ ) -> None:
+ super().__init__(actor, critic, reward_model, initial_model)
+ self.tokenizer = tokenizer
+ self.kl_coef = kl_coef
+
@torch.no_grad()
def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
self.actor.eval()
@@ -16,14 +33,33 @@ class NaiveExperienceMaker(ExperienceMaker):
self.initial_model.eval()
self.reward_model.eval()
- sequences, attention_mask, action_mask = self.actor.generate(input_ids,
- return_action_mask=True,
- **generate_kwargs)
+ # generate sequences
+ sequences = generate(self.actor, input_ids, self.tokenizer, **generate_kwargs)
+
+ # calculate auxiliary tensors
+ attention_mask = None
+ pad_token_id = self.tokenizer.pad_token_id
+ if pad_token_id is not None:
+ attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
+
+ input_len = input_ids.size(1)
+ eos_token_id = self.tokenizer.eos_token_id
+ if eos_token_id is None:
+ action_mask = torch.ones_like(sequences, dtype=torch.bool)
+ else:
+ # left padding may be applied, only mask action
+ action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
+ action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
+ action_mask[:, :input_len] = False
+ action_mask = action_mask[:, 1:]
+ action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
num_actions = action_mask.size(1)
- action_log_probs = self.actor(sequences, num_actions, attention_mask)
- base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask)
- value = self.critic(sequences, action_mask, attention_mask)
+ actor_output = self.actor(sequences, attention_mask)["logits"]
+ action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
+ base_model_output = self.initial_model(sequences, attention_mask)["logits"]
+ base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
+ value = self.critic(sequences, attention_mask)
r = self.reward_model(sequences, attention_mask)
reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
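
The inline action-mask construction above is subtle; a toy example showing that it keeps the generated tokens up to and including the first eos and drops trailing padding:

```python
import torch
import torch.nn.functional as F

eos = 2
input_len = 2
# prompt [11, 12], generated [21, 22, eos, pad]
sequences = torch.tensor([[11, 12, 21, 22, 2, 0]])

action_mask = (sequences[:, input_len:] == eos).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
action_mask = action_mask[:, -(sequences.size(1) - input_len):]
assert action_mask.tolist() == [[True, True, True, False]]  # eos kept, pad dropped
```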
diff --git a/applications/Chat/coati/kernels/__init__.py b/applications/Chat/coati/kernels/__init__.py
index 230eedf7ecba9531dc53d9fafde34a15af806cb3..96d40c7c4709ef7ca56c3040b8d65121f4e4a85d 100644
--- a/applications/Chat/coati/kernels/__init__.py
+++ b/applications/Chat/coati/kernels/__init__.py
@@ -1,6 +1,6 @@
from .wrapper import convert_to_xformer_model, recover_from_xformer_model
__all__ = [
- 'convert_to_xformer_model',
- 'recover_from_xformer_model',
+ "convert_to_xformer_model",
+ "recover_from_xformer_model",
]
diff --git a/applications/Chat/coati/kernels/opt_attn.py b/applications/Chat/coati/kernels/opt_attn.py
index c10f341e94a3bd32511f6da0c50dbf075136935d..d1eb139187f392071d6cd15b7c3bc69bdd02a3bb 100644
--- a/applications/Chat/coati/kernels/opt_attn.py
+++ b/applications/Chat/coati/kernels/opt_attn.py
@@ -21,11 +21,12 @@ class XOPTAttention(OPTAttention):
output_attentions: bool = False,
) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]:
if not self.training:
- return super().forward(hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask,
- output_attentions)
+ return super().forward(
+ hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, output_attentions
+ )
"""Input shape: Batch x Time x Channel"""
- assert layer_head_mask is None, 'Xformers attention does not support layer_head_mask'
- assert not output_attentions, 'Xformers attention does not support output_attentions'
+ assert layer_head_mask is None, "Xformers attention does not support layer_head_mask"
+ assert not output_attentions, "Xformers attention does not support output_attentions"
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
@@ -69,15 +70,17 @@ class XOPTAttention(OPTAttention):
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
- attn_output = xops.memory_efficient_attention(query_states,
- key_states,
- value_states,
- attn_bias=xops.LowerTriangularMask(),
- p=self.dropout if self.training else 0.0,
- scale=self.scaling)
+ attn_output = xops.memory_efficient_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_bias=xops.LowerTriangularMask(),
+ p=self.dropout if self.training else 0.0,
+ scale=self.scaling,
+ )
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned aross GPUs when using tensor-parallelism.
+ # partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
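
For intuition, `xops.memory_efficient_attention` with a `LowerTriangularMask` is numerically equivalent to ordinary causal softmax attention; a rough PyTorch reference for shape-checking (a sketch only, not the fused xformers kernel), using the same `(batch, seq, heads, head_dim)` layout:

```python
import torch
import torch.nn.functional as F


def naive_causal_attention(q, k, v, scale: float, dropout_p: float = 0.0):
    # q, k, v: (batch, seq, heads, head_dim), the layout memory_efficient_attention expects.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> (batch, heads, seq, head_dim)
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    causal = torch.ones(scores.shape[-2:], dtype=torch.bool, device=q.device).triu(1)
    scores = scores.masked_fill(causal, float("-inf"))
    probs = F.dropout(scores.softmax(dim=-1), p=dropout_p, training=dropout_p > 0)
    return torch.matmul(probs, v).transpose(1, 2)  # back to (batch, seq, heads, head_dim)
```
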
diff --git a/applications/Chat/coati/models/__init__.py b/applications/Chat/coati/models/__init__.py
index 709bc5ac0948ec13ca251b77a944b96e6543e97a..ad4a525b4af29adbf678b492d596c6410ea9e964 100644
--- a/applications/Chat/coati/models/__init__.py
+++ b/applications/Chat/coati/models/__init__.py
@@ -1,8 +1,15 @@
from .base import Actor, Critic, RewardModel
from .lora import LoRAModule, convert_to_lora_module
-from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
+from .loss import LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
__all__ = [
- 'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss',
- 'LoRAModule', 'convert_to_lora_module'
+ "Actor",
+ "Critic",
+ "RewardModel",
+ "PolicyLoss",
+ "ValueLoss",
+ "LogSigLoss",
+ "LogExpLoss",
+ "LoRAModule",
+ "convert_to_lora_module",
]
diff --git a/applications/Chat/coati/models/base/__init__.py b/applications/Chat/coati/models/base/__init__.py
index fe4152f2b760bf38924f521ee1d70a89074c8004..5c9905bb2224bb265324dfe37a30a637838b62f4 100644
--- a/applications/Chat/coati/models/base/__init__.py
+++ b/applications/Chat/coati/models/base/__init__.py
@@ -1,3 +1,5 @@
+from typing import Union
+
import torch.nn as nn
from .actor import Actor
@@ -5,10 +7,10 @@ from .critic import Critic
from .reward_model import RewardModel
-def get_base_model(model: nn.Module) -> nn.Module:
+def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
"""Get the base model of our wrapper classes.
- For Actor, it's base model is ``actor.model`` and it's usually a ``transformers.PreTrainedModel``.
- For Critic and RewardModel, it's base model is itself.
+ For Actor, Critic and RewardModel, the base model is ``model.model``,
+ which is usually a ``transformers.PreTrainedModel``.
Args:
model (nn.Module): model to get base model from
@@ -16,9 +18,10 @@ def get_base_model(model: nn.Module) -> nn.Module:
Returns:
nn.Module: the base model
"""
- if isinstance(model, Actor):
- return model.get_base_model()
- return model
+ assert isinstance(
+ model, (Actor, Critic, RewardModel)
+ ), f"Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first."
+ return model.model
-__all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model']
+__all__ = ["Actor", "Critic", "RewardModel", "get_base_model"]
diff --git a/applications/Chat/coati/models/base/actor.py b/applications/Chat/coati/models/base/actor.py
index 71fbf7bbae7d90875a202b5e4e56e431efc1468e..8b2b81ed071ce18c0a2a0cf060fa22badb164885 100644
--- a/applications/Chat/coati/models/base/actor.py
+++ b/applications/Chat/coati/models/base/actor.py
@@ -1,12 +1,9 @@
-from typing import Optional, Tuple, Union
+from typing import Optional
import torch
import torch.nn as nn
-import torch.nn.functional as F
-from ..generation import generate
from ..lora import LoRAModule
-from ..utils import log_probs_from_logits
class Actor(LoRAModule):
@@ -19,47 +16,18 @@ class Actor(LoRAModule):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
+ def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none") -> None:
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
self.model = model
self.convert_to_lora()
- @torch.no_grad()
- def generate(
+ def forward(
self,
- input_ids: torch.Tensor,
- return_action_mask: bool = True,
- **kwargs
- ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
- sequences = generate(self.model, input_ids, **kwargs)
- attention_mask = None
- pad_token_id = kwargs.get('pad_token_id', None)
- if pad_token_id is not None:
- attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
- if not return_action_mask:
- return sequences, attention_mask, None
- input_len = input_ids.size(1)
- eos_token_id = kwargs.get('eos_token_id', None)
- if eos_token_id is None:
- action_mask = torch.ones_like(sequences, dtype=torch.bool)
- else:
- # left padding may be applied, only mask action
- action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
- action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
- action_mask[:, :input_len] = False
- action_mask = action_mask[:, 1:]
- return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
-
- def forward(self,
- sequences: torch.LongTensor,
- num_actions: int,
- attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
- """Returns action log probs
- """
- output = self.model(sequences, attention_mask=attention_mask)
- logits = output['logits']
- log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
- return log_probs[:, -num_actions:]
-
- def get_base_model(self):
- return self.model
+ input_ids: torch.LongTensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **model_kwargs,
+ ) -> torch.Tensor:
+ """Returns model output."""
+ output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
+ return output
+
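
`Actor.forward` is now a thin passthrough that returns the raw model output; generation, masking and log-prob extraction live in the trainer (first hunk above). A sketch of the new call pattern, assuming `actor` is an `Actor` instance and `calc_action_log_probs` is the helper sketched earlier:

```python
# Old: action_log_probs = actor(sequences, num_actions, attention_mask)
output = actor(sequences, attention_mask)  # extra **model_kwargs are forwarded too
action_log_probs = calc_action_log_probs(output["logits"], sequences, num_actions)
```
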
diff --git a/applications/Chat/coati/models/base/critic.py b/applications/Chat/coati/models/base/critic.py
index e68a743a7762094c255df002565d61bdc611479f..8672365f5783240a63244fae19da4ab0d9ba20b9 100644
--- a/applications/Chat/coati/models/base/critic.py
+++ b/applications/Chat/coati/models/base/critic.py
@@ -1,10 +1,7 @@
-from typing import Optional
-
import torch
import torch.nn as nn
from ..lora import LoRAModule
-from ..utils import masked_mean
class Critic(LoRAModule):
@@ -19,36 +16,19 @@ class Critic(LoRAModule):
"""
def __init__(
- self,
- model: nn.Module,
- value_head: nn.Module,
- lora_rank: int = 0,
- lora_train_bias: str = 'none',
- use_action_mask: bool = False,
+ self, model: nn.Module, value_head: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none"
) -> None:
-
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
self.model = model
self.value_head = value_head
- self.use_action_mask = use_action_mask
self.convert_to_lora()
- def forward(self,
- sequences: torch.LongTensor,
- action_mask: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
outputs = self.model(sequences, attention_mask=attention_mask)
- last_hidden_states = outputs['last_hidden_state']
-
- values = self.value_head(last_hidden_states).squeeze(-1)
-
- if action_mask is not None and self.use_action_mask:
- num_actions = action_mask.size(1)
- prompt_mask = attention_mask[:, :-num_actions]
- values = values[:, :-num_actions]
- value = masked_mean(values, prompt_mask, dim=1)
- return value
-
- values = values[:, :-1]
- value = values.mean(dim=1)
- return value
+ last_hidden_states = outputs["last_hidden_state"]
+ sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[
+ 0
+ ]
+ sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths]
+ values = self.value_head(sequence_hidden_states).squeeze(1) # ensure shape is (B, )
+ return values
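
The critic (and the reward model below, which uses the same pattern) now scores the hidden state of the last attended token rather than averaging over positions. A small check of why `torch.max(attention_mask * arange, dim=1)` yields that index, assuming right padding:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],   # length-3 sequence, right-padded
                               [1, 1, 1, 1, 1]])  # full-length sequence
positions = torch.arange(attention_mask.size(1))
last_idx = torch.max(attention_mask * positions, dim=1)[0]
print(last_idx)  # tensor([2, 4]) -> index of the last non-pad token per row
# The value head is applied only at these positions, yielding one scalar per sequence.
```
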
diff --git a/applications/Chat/coati/models/base/reward_model.py b/applications/Chat/coati/models/base/reward_model.py
index ce8c0a1d35687da2cb61b55c59be6c91c0f9c394..e9545d1cddafae62d9999b9262bd969a56c8cf31 100644
--- a/applications/Chat/coati/models/base/reward_model.py
+++ b/applications/Chat/coati/models/base/reward_model.py
@@ -17,11 +17,13 @@ class RewardModel(LoRAModule):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- model: nn.Module,
- value_head: Optional[nn.Module] = None,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
+ def __init__(
+ self,
+ model: nn.Module,
+ value_head: Optional[nn.Module] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
self.model = model
self.convert_to_lora()
@@ -33,9 +35,12 @@ class RewardModel(LoRAModule):
else:
self.value_head = nn.Linear(model.config.n_embd, 1)
- def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
outputs = self.model(sequences, attention_mask=attention_mask)
- last_hidden_states = outputs['last_hidden_state']
- values = self.value_head(last_hidden_states)[:, :-1]
- value = values.mean(dim=1).squeeze(1) # ensure shape is (B)
- return value
+ last_hidden_states = outputs["last_hidden_state"]
+ sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[
+ 0
+ ]
+ sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths]
+ values = self.value_head(sequence_hidden_states).squeeze(1) # ensure shape is (B, )
+ return values
diff --git a/applications/Chat/coati/models/bloom/__init__.py b/applications/Chat/coati/models/bloom/__init__.py
index d0e7f7b1ef94e2a8434fa36efe6b5a2e4586f5b8..7af199a67d3b282d48b325795fdc8212e5e26381 100644
--- a/applications/Chat/coati/models/bloom/__init__.py
+++ b/applications/Chat/coati/models/bloom/__init__.py
@@ -2,4 +2,4 @@ from .bloom_actor import BLOOMActor
from .bloom_critic import BLOOMCritic
from .bloom_rm import BLOOMRM
-__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM']
+__all__ = ["BLOOMActor", "BLOOMCritic", "BLOOMRM"]
diff --git a/applications/Chat/coati/models/bloom/bloom_actor.py b/applications/Chat/coati/models/bloom/bloom_actor.py
index d7577f0964934955992f8d952c4e209a68a58829..73855a2245e7e2c4dfbb42c74be1e14a152a6262 100644
--- a/applications/Chat/coati/models/bloom/bloom_actor.py
+++ b/applications/Chat/coati/models/bloom/bloom_actor.py
@@ -1,7 +1,6 @@
from typing import Optional
-import torch
-from transformers import BloomConfig, BloomForCausalLM, BloomModel
+from transformers import BloomConfig, BloomForCausalLM
from ..base import Actor
@@ -18,12 +17,14 @@ class BLOOMActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: str = None,
- config: Optional[BloomConfig] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
+ def __init__(
+ self,
+ pretrained: str = None,
+ config: Optional[BloomConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = BloomForCausalLM.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/bloom/bloom_critic.py b/applications/Chat/coati/models/bloom/bloom_critic.py
index a32fb2e102f9b05bb5c62d97f215343ce466b840..b2d838f7ffc50bbac62c2950ad4d0c701beeb491 100644
--- a/applications/Chat/coati/models/bloom/bloom_critic.py
+++ b/applications/Chat/coati/models/bloom/bloom_critic.py
@@ -1,8 +1,7 @@
from typing import Optional
-import torch
import torch.nn as nn
-from transformers import BloomConfig, BloomForCausalLM, BloomModel
+from transformers import BloomConfig, BloomModel
from ..base import Critic
@@ -14,25 +13,24 @@ class BLOOMCritic(Critic):
Args:
pretrained (str): Pretrained model name or path.
config (BloomConfig): Model config.
- checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: str = None,
- config: Optional[BloomConfig] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none',
- **kwargs) -> None:
+ def __init__(
+ self,
+ pretrained: str = None,
+ config: Optional[BloomConfig] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ **kwargs,
+ ) -> None:
if pretrained is not None:
model = BloomModel.from_pretrained(pretrained)
elif config is not None:
model = BloomModel(config)
else:
model = BloomModel(BloomConfig())
- if checkpoint:
- model.gradient_checkpointing_enable()
+
value_head = nn.Linear(model.config.hidden_size, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
diff --git a/applications/Chat/coati/models/bloom/bloom_rm.py b/applications/Chat/coati/models/bloom/bloom_rm.py
index 22cfab441abb6666812a5875e02da77b28357e94..c09457ddc8c73190a9a9bef4a903588520c881ec 100644
--- a/applications/Chat/coati/models/bloom/bloom_rm.py
+++ b/applications/Chat/coati/models/bloom/bloom_rm.py
@@ -1,7 +1,7 @@
from typing import Optional
import torch.nn as nn
-from transformers import BloomConfig, BloomForCausalLM, BloomModel
+from transformers import BloomConfig, BloomModel
from ..base import RewardModel
@@ -13,25 +13,24 @@ class BLOOMRM(RewardModel):
Args:
pretrained (str): Pretrained model name or path.
config (BloomConfig): Model config.
- checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: str = None,
- config: Optional[BloomConfig] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
+ def __init__(
+ self,
+ pretrained: str = None,
+ config: Optional[BloomConfig] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = BloomModel.from_pretrained(pretrained)
elif config is not None:
model = BloomModel(config)
else:
model = BloomModel(BloomConfig())
- if checkpoint:
- model.gradient_checkpointing_enable()
+
value_head = nn.Linear(model.config.hidden_size, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
super().__init__(model, value_head, lora_rank, lora_train_bias)
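
The reward head keeps a deliberately small initialization, presumably so initial reward magnitudes stay near zero at the start of training. A quick scale check, assuming roughly unit-variance hidden states:

```python
import torch

hidden_size = 1024
value_head = torch.nn.Linear(hidden_size, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (hidden_size + 1))

h = torch.randn(4096, hidden_size)  # stand-in for last hidden states
print(value_head(h).std())          # ~ sqrt(hidden_size) / (hidden_size + 1) ~ 0.03
```
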
diff --git a/applications/Chat/coati/models/chatglm/__init__.py b/applications/Chat/coati/models/chatglm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5956f5a8e91b8095de6b32edd4005051f73dc7bc
--- /dev/null
+++ b/applications/Chat/coati/models/chatglm/__init__.py
@@ -0,0 +1,3 @@
+from .chatglm_actor import ChatGLMActor
+
+__all__ = ["ChatGLMActor"]
diff --git a/applications/Chat/coati/models/chatglm/chatglm_actor.py b/applications/Chat/coati/models/chatglm/chatglm_actor.py
new file mode 100644
index 0000000000000000000000000000000000000000..00a61561ee47d0a60dd981db2b95603e4d20eda3
--- /dev/null
+++ b/applications/Chat/coati/models/chatglm/chatglm_actor.py
@@ -0,0 +1,31 @@
+from typing import Optional
+
+from ..base import Actor
+from .configuration_chatglm import ChatGLMConfig
+from .modeling_chatglm import ChatGLMForConditionalGeneration
+
+
+class ChatGLMActor(Actor):
+ """
+ ChatGLM Actor model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (ChatGLMConfig): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+
+ LoRA is not supported for now.
+ """
+
+ def __init__(
+ self, pretrained: str = None, config: Optional[ChatGLMConfig] = None, checkpoint: bool = False
+ ) -> None:
+ if pretrained is not None:
+ model = ChatGLMForConditionalGeneration.from_pretrained(pretrained)
+ elif config is not None:
+ model = ChatGLMForConditionalGeneration(config)
+ else:
+ model = ChatGLMForConditionalGeneration(ChatGLMConfig())
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+ super().__init__(model, lora_rank=0, lora_train_bias="none")
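
A minimal instantiation sketch (the checkpoint name is illustrative, and the tiny config is only for smoke tests):

```python
from coati.models.chatglm import ChatGLMActor
from coati.models.chatglm.configuration_chatglm import ChatGLMConfig

# From a pretrained checkpoint:
actor = ChatGLMActor(pretrained="THUDM/chatglm-6b")

# Or from a small config, with gradient checkpointing enabled:
tiny_cfg = ChatGLMConfig(num_layers=2, hidden_size=128, num_attention_heads=4, inner_hidden_size=512)
actor = ChatGLMActor(config=tiny_cfg, checkpoint=True)
```
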
diff --git a/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py b/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..221ef044b470566ed2956f799086adc98ee62e09
--- /dev/null
+++ b/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py
@@ -0,0 +1,442 @@
+"""
+This code is copied from https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py
+"""
+"""Tokenization classes for ChatGLM."""
+import os
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import sentencepiece as spm
+from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.tokenization_utils_base import BatchEncoding, EncodedInput
+from transformers.utils import PaddingStrategy, logging
+
+logger = logging.get_logger(__name__)
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "THUDM/chatglm-6b": 2048,
+}
+
+
+class TextTokenizer:
+ def __init__(self, model_path):
+ self.sp = spm.SentencePieceProcessor()
+ self.sp.Load(model_path)
+ self.num_tokens = self.sp.vocab_size()
+
+ def encode(self, text):
+ return self.sp.EncodeAsIds(text)
+
+ def decode(self, ids: List[int]):
+ return self.sp.DecodeIds(ids)
+
+ def tokenize(self, text):
+ return self.sp.EncodeAsPieces(text)
+
+ def convert_tokens_to_string(self, tokens):
+ return self.sp.DecodePieces(tokens)
+
+ def convert_tokens_to_ids(self, tokens):
+ return [self.sp.PieceToId(token) for token in tokens]
+
+ def convert_token_to_id(self, token):
+ return self.sp.PieceToId(token)
+
+ def convert_id_to_token(self, idx):
+ return self.sp.IdToPiece(idx)
+
+ def __len__(self):
+ return self.num_tokens
+
+
+class SPTokenizer:
+ def __init__(
+ self,
+ vocab_file,
+ num_image_tokens=20000,
+ max_blank_length=80,
+ byte_fallback=True,
+ ):
+ assert vocab_file is not None
+ self.vocab_file = vocab_file
+ self.num_image_tokens = num_image_tokens
+ self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "<unused_0>", "<sop>", "<eop>", "<ENC>", "<dBLOCK>"]
+ self.max_blank_length = max_blank_length
+ self.byte_fallback = byte_fallback
+ self.text_tokenizer = TextTokenizer(vocab_file)
+
+ def _get_text_tokenizer(self):
+ return self.text_tokenizer
+
+ @staticmethod
+ def get_blank_token(length: int):
+ assert length >= 2
+ return f"<|blank_{length}|>"
+
+ @staticmethod
+ def get_tab_token():
+ return f"<|tab|>"
+
+ @property
+ def num_text_tokens(self):
+ return self.text_tokenizer.num_tokens
+
+ @property
+ def num_tokens(self):
+ return self.num_image_tokens + self.num_text_tokens
+
+ @staticmethod
+ def _encode_whitespaces(text: str, max_len: int = 80):
+ text = text.replace("\t", SPTokenizer.get_tab_token())
+ for i in range(max_len, 1, -1):
+ text = text.replace(" " * i, SPTokenizer.get_blank_token(i))
+ return text
+
+ def _preprocess(self, text: str, linebreak=True, whitespaces=True):
+ if linebreak:
+ text = text.replace("\n", "")
+ if whitespaces:
+ text = self._encode_whitespaces(text, max_len=self.max_blank_length)
+ return text
+
+ def encode(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[int]:
+ """
+ @param text: Text to encode.
+ @param linebreak: Whether to encode newline (\n) in text.
+ @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
+ @param add_dummy_prefix: Whether to add dummy blank space in the beginning.
+ """
+ text = self._preprocess(text, linebreak, whitespaces)
+ if not add_dummy_prefix:
+ text = "" + text
+ tmp = self._get_text_tokenizer().encode(text)
+ tokens = [x + self.num_image_tokens for x in tmp]
+ return tokens if add_dummy_prefix else tokens[2:]
+
+ def postprocess(self, text):
+ text = text.replace("", "\n")
+ text = text.replace(SPTokenizer.get_tab_token(), "\t")
+ for i in range(2, self.max_blank_length + 1):
+ text = text.replace(self.get_blank_token(i), " " * i)
+ return text
+
+ def decode(self, text_ids: List[int]) -> str:
+ ids = [int(_id) - self.num_image_tokens for _id in text_ids]
+ ids = [_id for _id in ids if _id >= 0]
+ text = self._get_text_tokenizer().decode(ids)
+ text = self.postprocess(text)
+ return text
+
+ def decode_tokens(self, tokens: List[str]) -> str:
+ text = self._get_text_tokenizer().convert_tokens_to_string(tokens)
+ text = self.postprocess(text)
+ return text
+
+ def tokenize(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[str]:
+ """
+ @param text: Text to encode.
+ @param linebreak: Whether to encode newline (\n) in text.
+ @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
+ @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text.
+ @param add_dummy_prefix: Whether to add dummy blank space in the beginning.
+ """
+ text = self._preprocess(text, linebreak, whitespaces)
+ if not add_dummy_prefix:
+ text = "" + text
+ tokens = self._get_text_tokenizer().tokenize(text)
+ return tokens if add_dummy_prefix else tokens[2:]
+
+ def __getitem__(self, x: Union[int, str]):
+ if isinstance(x, int):
+ if x < self.num_image_tokens:
+ return "".format(x)
+ else:
+ return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens)
+ elif isinstance(x, str):
+ if x.startswith("") and x[7:-1].isdigit():
+ return int(x[7:-1])
+ else:
+ return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens
+ else:
+ raise ValueError("The key should be str or int.")
+
+
+class ChatGLMTokenizer(PreTrainedTokenizer):
+ """
+ Construct a ChatGLM tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ """
+
+ vocab_files_names = {"vocab_file": "ice_text.model"}
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ model_input_names = ["input_ids", "attention_mask", "position_ids"]
+
+ def __init__(
+ self,
+ vocab_file,
+ do_lower_case=False,
+ remove_space=False,
+ bos_token="",
+ eos_token="",
+ end_token="",
+ mask_token="[MASK]",
+ gmask_token="[gMASK]",
+ padding_side="left",
+ pad_token="",
+ unk_token="",
+ num_image_tokens=20000,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ do_lower_case=do_lower_case,
+ remove_space=remove_space,
+ padding_side=padding_side,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ end_token=end_token,
+ mask_token=mask_token,
+ gmask_token=gmask_token,
+ pad_token=pad_token,
+ unk_token=unk_token,
+ num_image_tokens=num_image_tokens,
+ **kwargs,
+ )
+
+ self.do_lower_case = do_lower_case
+ self.remove_space = remove_space
+ self.vocab_file = vocab_file
+
+ self.bos_token = bos_token
+ self.eos_token = eos_token
+ self.end_token = end_token
+ self.mask_token = mask_token
+ self.gmask_token = gmask_token
+
+ self.sp_tokenizer = SPTokenizer(vocab_file, num_image_tokens=num_image_tokens)
+
+ """ Initialisation """
+
+ @property
+ def gmask_token_id(self) -> Optional[int]:
+ if self.gmask_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.gmask_token)
+
+ @property
+ def end_token_id(self) -> Optional[int]:
+ """
+ `Optional[int]`: Id of the end of context token in the vocabulary. Returns `None` if the token has not been
+ set.
+ """
+ if self.end_token is None:
+ return None
+ return self.convert_tokens_to_ids(self.end_token)
+
+ @property
+ def vocab_size(self):
+ """Returns vocab size"""
+ return self.sp_tokenizer.num_tokens
+
+ def get_vocab(self):
+ """Returns vocab as a dict"""
+ vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def preprocess_text(self, inputs):
+ if self.remove_space:
+ outputs = " ".join(inputs.strip().split())
+ else:
+ outputs = inputs
+
+ if self.do_lower_case:
+ outputs = outputs.lower()
+
+ return outputs
+
+ def _tokenize(self, text, **kwargs):
+ """Returns a tokenized string."""
+ text = self.preprocess_text(text)
+
+ seq = self.sp_tokenizer.tokenize(text)
+
+ return seq
+
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
+ return self.sp_tokenizer.decode_tokens(tokens)
+
+ def _decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
+ if isinstance(token_ids, int):
+ token_ids = [token_ids]
+ if len(token_ids) == 0:
+ return ""
+ if self.pad_token_id in token_ids: # remove pad
+ token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
+ return super()._decode(token_ids, **kwargs)
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.sp_tokenizer[token]
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.sp_tokenizer[index]
+
+ def save_vocabulary(self, save_directory, filename_prefix=None):
+ """
+ Save the vocabulary and special tokens file to a directory.
+
+ Args:
+ save_directory (`str`):
+ The directory in which to save the vocabulary.
+ filename_prefix (`str`, *optional*):
+ An optional prefix to add to the names of the saved files.
+
+ Returns:
+ `Tuple(str)`: Paths to the files saved.
+ """
+ if os.path.isdir(save_directory):
+ vocab_file = os.path.join(save_directory, self.vocab_files_names["vocab_file"])
+ else:
+ vocab_file = save_directory
+
+ with open(self.vocab_file, "rb") as fin:
+ proto_str = fin.read()
+
+ with open(vocab_file, "wb") as writer:
+ writer.write(proto_str)
+
+ return (vocab_file,)
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequences by concatenating and
+ adding special tokens. A ChatGLM sequence has the following format:
+
+ - single sequence: `X [gMASK] <sop>`
+ - pair of sequences: `A [gMASK] <sop> B`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ gmask_id = self.sp_tokenizer[self.gmask_token]
+ token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]]
+ if token_ids_1 is not None:
+ token_ids_0 = token_ids_0 + token_ids_1
+ return token_ids_0
+
+ def _pad(
+ self,
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
+ ) -> dict:
+ """
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
+
+ Args:
+ encoded_inputs:
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
+ max_length: maximum length of the returned list and optionally padding length (see below).
+ Will truncate by taking into account the special tokens.
+ padding_strategy: PaddingStrategy to use for padding.
+
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
+ The tokenizer padding sides are defined in self.padding_side:
+
+ - 'left': pads on the left of the sequences
+ - 'right': pads on the right of the sequences
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
+ `>= 7.5` (Volta).
+ return_attention_mask:
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+ """
+ # Load from model defaults
+ bos_token_id = self.sp_tokenizer[self.bos_token]
+ mask_token_id = self.sp_tokenizer[self.mask_token]
+ gmask_token_id = self.sp_tokenizer[self.gmask_token]
+ assert self.padding_side == "left"
+
+ required_input = encoded_inputs[self.model_input_names[0]]
+ seq_length = len(required_input)
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = len(required_input)
+
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+ # Initialize attention mask if not present.
+ if max_length is not None:
+ if "attention_mask" not in encoded_inputs:
+ if bos_token_id in required_input:
+ context_length = required_input.index(bos_token_id)
+ else:
+ context_length = seq_length
+ attention_mask = np.ones((1, seq_length, seq_length))
+ attention_mask = np.tril(attention_mask)
+ attention_mask[:, :, :context_length] = 1
+ attention_mask = np.bool_(attention_mask < 0.5)
+ encoded_inputs["attention_mask"] = attention_mask
+
+ if "position_ids" not in encoded_inputs:
+ if bos_token_id in required_input:
+ context_length = required_input.index(bos_token_id)
+ else:
+ context_length = seq_length
+ position_ids = np.arange(seq_length, dtype=np.int64)
+ mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
+ if mask_token in required_input:
+ mask_position = required_input.index(mask_token)
+ position_ids[context_length:] = mask_position
+ block_position_ids = np.concatenate(
+ [
+ np.zeros(context_length, dtype=np.int64),
+ np.arange(1, seq_length - context_length + 1, dtype=np.int64),
+ ]
+ )
+ encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
+
+ if needs_to_be_padded:
+ difference = max_length - len(required_input)
+
+ if "attention_mask" in encoded_inputs:
+ encoded_inputs["attention_mask"] = np.pad(
+ encoded_inputs["attention_mask"],
+ pad_width=[(0, 0), (difference, 0), (difference, 0)],
+ mode="constant",
+ constant_values=True,
+ )
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
+ "token_type_ids"
+ ]
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
+ if "position_ids" in encoded_inputs:
+ encoded_inputs["position_ids"] = np.pad(
+ encoded_inputs["position_ids"], pad_width=[(0, 0), (difference, 0)]
+ )
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+
+ return encoded_inputs
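
The whitespace handling is reversible: tabs and runs of two or more spaces become dedicated tokens before SentencePiece runs, and `postprocess` maps them back on decode. A small round-trip sketch using the static helper above:

```python
from coati.models.chatglm.chatglm_tokenizer import SPTokenizer

text = "def f():\n\treturn  1"  # contains a tab and a double space
pre = SPTokenizer._encode_whitespaces(text.replace("\n", "<n>"))
print(pre)  # def f():<n><|tab|>return<|blank_2|>1
# postprocess() applies the inverse mapping on decode, restoring "\n", "\t" and the spaces.
```
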
diff --git a/applications/Chat/coati/models/chatglm/configuration_chatglm.py b/applications/Chat/coati/models/chatglm/configuration_chatglm.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6d2ccd187153d293aee713b9665f3be00707d6f
--- /dev/null
+++ b/applications/Chat/coati/models/chatglm/configuration_chatglm.py
@@ -0,0 +1,101 @@
+"""
+This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/configuration_chatglm.py
+"""
+
+""" ChatGLM model configuration """
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class ChatGLMConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`~ChatGLMModel`].
+ It is used to instantiate a ChatGLM model according to the specified arguments, defining the model
+ architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
+ the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used
+ to control the model outputs. Read the documentation from [`PretrainedConfig`]
+ for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 130528):
+ Vocabulary size of the ChatGLM-6B model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`~ChatGLMModel`] or
+ [`~TFChatGLMModel`].
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 28):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ inner_hidden_size (`int`, *optional*, defaults to 16384):
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ max_sequence_length (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with.
+ Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
+ layernorm_epsilon (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether the model should return the last key/values attentions (not used by all models).
+ Example:
+
+ ```python
+ >>> from configuration_chatglm import ChatGLMConfig
+ >>> from modeling_chatglm import ChatGLMModel
+
+ >>> # Initializing a ChatGLM-6B THUDM/ChatGLM-6B style configuration
+ >>> configuration = ChatGLMConfig()
+
+ >>> # Initializing a model from the THUDM/ChatGLM-6B style configuration
+ >>> model = ChatGLMModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "chatglm"
+
+ def __init__(
+ self,
+ vocab_size=130528,
+ hidden_size=4096,
+ num_layers=28,
+ num_attention_heads=32,
+ layernorm_epsilon=1e-5,
+ use_cache=True,
+ bos_token_id=130004,
+ eos_token_id=130005,
+ mask_token_id=130000,
+ gmask_token_id=130001,
+ pad_token_id=3,
+ max_sequence_length=2048,
+ inner_hidden_size=16384,
+ position_encoding_2d=True,
+ quantization_bit=0,
+ pre_seq_len=None,
+ prefix_projection=False,
+ **kwargs,
+ ):
+ self.num_layers = num_layers
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.max_sequence_length = max_sequence_length
+ self.layernorm_epsilon = layernorm_epsilon
+ self.inner_hidden_size = inner_hidden_size
+ self.use_cache = use_cache
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.mask_token_id = mask_token_id
+ self.gmask_token_id = gmask_token_id
+ self.position_encoding_2d = position_encoding_2d
+ self.quantization_bit = quantization_bit
+ self.pre_seq_len = pre_seq_len
+ self.prefix_projection = prefix_projection
+
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
diff --git a/applications/Chat/coati/models/chatglm/modeling_chatglm.py b/applications/Chat/coati/models/chatglm/modeling_chatglm.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1d15c68ffd8bbc9686741da9f11461c09f02507
--- /dev/null
+++ b/applications/Chat/coati/models/chatglm/modeling_chatglm.py
@@ -0,0 +1,1477 @@
+"""
+This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/modeling_chatglm.py
+"""
+
+""" PyTorch ChatGLM model. """
+
+import copy
+import math
+import os
+import re
+import sys
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss, LayerNorm
+from torch.nn.utils import skip_init
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+)
+
+from .configuration_chatglm import ChatGLMConfig
+
+# flags required to enable jit fusion kernels
+
+if sys.platform != "darwin":
+ torch._C._jit_set_profiling_mode(False)
+ torch._C._jit_set_profiling_executor(False)
+ torch._C._jit_override_can_fuse_on_cpu(True)
+ torch._C._jit_override_can_fuse_on_gpu(True)
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B"
+_CONFIG_FOR_DOC = "ChatGLM6BConfig"
+
+CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "THUDM/chatglm-6b",
+ # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm
+]
+
+
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
+ scores.zero_()
+ scores[..., 5] = 5e4
+ return scores
+
+
+def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
+ """Load tf checkpoints in a pytorch model."""
+ try:
+ import re
+
+ import numpy as np
+ import tensorflow as tf
+ except ImportError:
+ logger.error(
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
+ "https://www.tensorflow.org/install/ for installation instructions."
+ )
+ raise
+ tf_path = os.path.abspath(tf_checkpoint_path)
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
+ # Load weights from TF model
+ init_vars = tf.train.list_variables(tf_path)
+ names = []
+ arrays = []
+ for name, shape in init_vars:
+ logger.info(f"Loading TF weight {name} with shape {shape}")
+ array = tf.train.load_variable(tf_path, name)
+ names.append(name)
+ arrays.append(array)
+
+ for name, array in zip(names, arrays):
+ name = name.split("/")
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
+ # which are not required for using pretrained model
+ if any(
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+ for n in name
+ ):
+ logger.info(f"Skipping {'/'.join(name)}")
+ continue
+ pointer = model
+ for m_name in name:
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
+ scope_names = re.split(r"_(\d+)", m_name)
+ else:
+ scope_names = [m_name]
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
+ pointer = getattr(pointer, "bias")
+ elif scope_names[0] == "output_weights":
+ pointer = getattr(pointer, "weight")
+ elif scope_names[0] == "squad":
+ pointer = getattr(pointer, "classifier")
+ else:
+ try:
+ pointer = getattr(pointer, scope_names[0])
+ except AttributeError:
+ logger.info(f"Skipping {'/'.join(name)}")
+ continue
+ if len(scope_names) >= 2:
+ num = int(scope_names[1])
+ pointer = pointer[num]
+ if m_name[-11:] == "_embeddings":
+ pointer = getattr(pointer, "weight")
+ elif m_name == "kernel":
+ array = np.transpose(array)
+ try:
+ assert (
+ pointer.shape == array.shape
+ ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
+ except AssertionError as e:
+ e.args += (pointer.shape, array.shape)
+ raise
+ logger.info(f"Initialize PyTorch weight {name}")
+ pointer.data = torch.from_numpy(array)
+ return model
+
+
+class PrefixEncoder(torch.nn.Module):
+ """
+ The torch.nn model to encode the prefix
+ Input shape: (batch-size, prefix-length)
+ Output shape: (batch-size, prefix-length, 2*layers*hidden)
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.prefix_projection = config.prefix_projection
+ if self.prefix_projection:
+ # Use a two-layer MLP to encode the prefix
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
+ self.trans = torch.nn.Sequential(
+ torch.nn.Linear(config.hidden_size, config.hidden_size),
+ torch.nn.Tanh(),
+ torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2),
+ )
+ else:
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)
+
+ def forward(self, prefix: torch.Tensor):
+ if self.prefix_projection:
+ prefix_tokens = self.embedding(prefix)
+ past_key_values = self.trans(prefix_tokens)
+ else:
+ past_key_values = self.embedding(prefix)
+ return past_key_values
+
+
+@torch.jit.script
+def gelu_impl(x):
+ """OpenAI's gelu implementation."""
+ return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
+
+
+def gelu(x):
+ return gelu_impl(x)
+
+
+class RotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
+ super().__init__()
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+ inv_freq = inv_freq.half()
+ self.learnable = learnable
+ if learnable:
+ self.inv_freq = torch.nn.Parameter(inv_freq)
+ self.max_seq_len_cached = None
+ else:
+ self.register_buffer("inv_freq", inv_freq)
+ self.max_seq_len_cached = None
+ self.cos_cached = None
+ self.sin_cached = None
+ self.precision = precision
+
+ def _load_from_state_dict(
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ ):
+ pass
+
+ def forward(self, x, seq_dim=1, seq_len=None):
+ if seq_len is None:
+ seq_len = x.shape[seq_dim]
+ if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
+ self.max_seq_len_cached = None if self.learnable else seq_len
+ t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+ if self.precision == torch.bfloat16:
+ emb = emb.float()
+
+ # [sx, 1 (b * np), hn]
+ cos_cached = emb.cos()[:, None, :]
+ sin_cached = emb.sin()[:, None, :]
+ if self.precision == torch.bfloat16:
+ cos_cached = cos_cached.bfloat16()
+ sin_cached = sin_cached.bfloat16()
+ if self.learnable:
+ return cos_cached, sin_cached
+ self.cos_cached, self.sin_cached = cos_cached, sin_cached
+ return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
+
+ def _apply(self, fn):
+ if self.cos_cached is not None:
+ self.cos_cached = fn(self.cos_cached)
+ if self.sin_cached is not None:
+ self.sin_cached = fn(self.sin_cached)
+ return super()._apply(fn)
+
+
+def rotate_half(x):
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
+
+
+@torch.jit.script
+def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
+ # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
+ cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), F.embedding(
+ position_id, sin.squeeze(1)
+ ).unsqueeze(2)
+ q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+ return q, k
+
+
+def attention_fn(
+ self,
+ query_layer,
+ key_layer,
+ value_layer,
+ attention_mask,
+ hidden_size_per_partition,
+ layer_id,
+ layer_past=None,
+ scaling_attention_score=True,
+ use_cache=False,
+):
+ if layer_past is not None:
+ past_key, past_value = layer_past[0], layer_past[1]
+ key_layer = torch.cat((past_key, key_layer), dim=0)
+ value_layer = torch.cat((past_value, value_layer), dim=0)
+
+ # seqlen, batch, num_attention_heads, hidden_size_per_attention_head
+ seq_len, b, nh, hidden_size = key_layer.shape
+
+ if use_cache:
+ present = (key_layer, value_layer)
+ else:
+ present = None
+
+ query_key_layer_scaling_coeff = float(layer_id + 1)
+ if scaling_attention_score:
+ query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff)
+
+ # ===================================
+ # Raw attention scores. [b, np, s, s]
+ # ===================================
+
+ # [b, np, sq, sk]
+ output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
+
+ # [sq, b, np, hn] -> [sq, b * np, hn]
+ query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
+ # [sk, b, np, hn] -> [sk, b * np, hn]
+ key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
+
+ matmul_result = torch.zeros(
+ 1,
+ 1,
+ 1,
+ dtype=query_layer.dtype,
+ device=query_layer.device,
+ )
+
+ matmul_result = torch.baddbmm(
+ matmul_result,
+ query_layer.transpose(0, 1), # [b * np, sq, hn]
+ key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
+ beta=0.0,
+ alpha=1.0,
+ )
+
+ # change view to [b, np, sq, sk]
+ attention_scores = matmul_result.view(*output_size)
+
+ if self.scale_mask_softmax:
+ self.scale_mask_softmax.scale = query_key_layer_scaling_coeff
+ attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous())
+ else:
+ if not (attention_mask == 0).all():
+ # if auto-regressive, skip
+ attention_scores.masked_fill_(attention_mask, -10000.0)
+ dtype = attention_scores.dtype
+ attention_scores = attention_scores.float()
+ attention_scores = attention_scores * query_key_layer_scaling_coeff
+
+ attention_probs = F.softmax(attention_scores, dim=-1)
+
+ attention_probs = attention_probs.type(dtype)
+
+ # =========================
+ # Context layer. [sq, b, hp]
+ # =========================
+
+ # value_layer -> context layer.
+ # [sk, b, np, hn] --> [b, np, sq, hn]
+
+ # context layer shape: [b, np, sq, hn]
+ output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
+
+ # change view [sk, b * np, hn]
+ value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
+
+ # change view [b * np, sq, sk]
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
+
+ # matmul: [b * np, sq, hn]
+ context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
+
+ # change view [b, np, sq, hn]
+ context_layer = context_layer.view(*output_size)
+
+ # [b, np, sq, hn] --> [sq, b, np, hn]
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+ # [sq, b, np, hn] --> [sq, b, hp]
+ new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, present, attention_probs)
+
+ return outputs
+
+
+def default_init(cls, *args, **kwargs):
+ return cls(*args, **kwargs)
+
+
+class SelfAttention(torch.nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_attention_heads,
+ layer_id,
+ hidden_size_per_attention_head=None,
+ bias=True,
+ params_dtype=torch.float,
+ position_encoding_2d=True,
+ empty_init=True,
+ ):
+ if empty_init:
+ init_method = skip_init
+ else:
+ init_method = default_init
+ super(SelfAttention, self).__init__()
+
+ self.layer_id = layer_id
+ self.hidden_size = hidden_size
+ self.hidden_size_per_partition = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.num_attention_heads_per_partition = num_attention_heads
+ self.position_encoding_2d = position_encoding_2d
+ self.rotary_emb = RotaryEmbedding(
+ self.hidden_size // (self.num_attention_heads * 2)
+ if position_encoding_2d
+ else self.hidden_size // self.num_attention_heads,
+ base=10000,
+ precision=torch.half,
+ learnable=False,
+ )
+
+ self.scale_mask_softmax = None
+
+ if hidden_size_per_attention_head is None:
+ self.hidden_size_per_attention_head = hidden_size // num_attention_heads
+ else:
+ self.hidden_size_per_attention_head = hidden_size_per_attention_head
+
+ self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
+
+ # Strided linear layer.
+ self.query_key_value = init_method(
+ torch.nn.Linear,
+ hidden_size,
+ 3 * self.inner_hidden_size,
+ bias=bias,
+ dtype=params_dtype,
+ )
+
+ self.dense = init_method(
+ torch.nn.Linear,
+ self.inner_hidden_size,
+ hidden_size,
+ bias=bias,
+ dtype=params_dtype,
+ )
+
+ @staticmethod
+ def attention_mask_func(attention_scores, attention_mask):
+ attention_scores.masked_fill_(attention_mask, -10000.0)
+ return attention_scores
+
+ def split_tensor_along_last_dim(self, tensor, num_partitions, contiguous_split_chunks=False):
+ """Split a tensor along its last dimension.
+ Arguments:
+ tensor: input tensor.
+ num_partitions: number of partitions to split the tensor
+ contiguous_split_chunks: If True, make each chunk contiguous
+ in memory.
+ """
+ # Get the size and dimension.
+ last_dim = tensor.dim() - 1
+ last_dim_size = tensor.size()[last_dim] // num_partitions
+ # Split.
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+ # Note: torch.split does not create contiguous tensors by default.
+ if contiguous_split_chunks:
+ return tuple(chunk.contiguous() for chunk in tensor_list)
+
+ return tensor_list
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_ids,
+ attention_mask: torch.Tensor,
+ layer_id,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ """
+ hidden_states: [seq_len, batch, hidden_size]
+ attention_mask: [(1, 1), seq_len, seq_len]
+ """
+
+ # [seq_len, batch, 3 * hidden_size]
+ mixed_raw_layer = self.query_key_value(hidden_states)
+
+ # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, 3 * hidden_size_per_attention_head]
+ new_tensor_shape = mixed_raw_layer.size()[:-1] + (
+ self.num_attention_heads_per_partition,
+ 3 * self.hidden_size_per_attention_head,
+ )
+ mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape)
+
+ # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
+ (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3)
+
+ if self.position_encoding_2d:
+ q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
+ k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
+ cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
+ position_ids, block_position_ids = (
+ position_ids[:, 0, :].transpose(0, 1).contiguous(),
+ position_ids[:, 1, :].transpose(0, 1).contiguous(),
+ )
+ q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
+ q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)
+ query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
+ key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))
+ else:
+ position_ids = position_ids.transpose(0, 1)
+ cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1)
+ # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
+ query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids)
+
+ # [seq_len, batch, hidden_size]
+ context_layer, present, attention_probs = attention_fn(
+ self=self,
+ query_layer=query_layer,
+ key_layer=key_layer,
+ value_layer=value_layer,
+ attention_mask=attention_mask,
+ hidden_size_per_partition=self.hidden_size_per_partition,
+ layer_id=layer_id,
+ layer_past=layer_past,
+ use_cache=use_cache,
+ )
+
+ output = self.dense(context_layer)
+
+ outputs = (output, present)
+
+ if output_attentions:
+ outputs += (attention_probs,)
+
+ return outputs # output, present, attention_probs
+
+
+class GEGLU(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.activation_fn = F.gelu
+
+ def forward(self, x):
+ # dim=-1 breaks in jit for pt<1.10
+ x1, x2 = x.chunk(2, dim=(x.ndim - 1))
+ return x1 * self.activation_fn(x2)
+
+
+class GLU(torch.nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ inner_hidden_size=None,
+ layer_id=None,
+ bias=True,
+ activation_func=gelu,
+ params_dtype=torch.float,
+ empty_init=True,
+ ):
+ super(GLU, self).__init__()
+ if empty_init:
+ init_method = skip_init
+ else:
+ init_method = default_init
+ self.layer_id = layer_id
+ self.activation_func = activation_func
+
+ # Project to 4h.
+ self.hidden_size = hidden_size
+ if inner_hidden_size is None:
+ inner_hidden_size = 4 * hidden_size
+ self.inner_hidden_size = inner_hidden_size
+ self.dense_h_to_4h = init_method(
+ torch.nn.Linear,
+ self.hidden_size,
+ self.inner_hidden_size,
+ bias=bias,
+ dtype=params_dtype,
+ )
+ # Project back to h.
+ self.dense_4h_to_h = init_method(
+ torch.nn.Linear,
+ self.inner_hidden_size,
+ self.hidden_size,
+ bias=bias,
+ dtype=params_dtype,
+ )
+
+ def forward(self, hidden_states):
+ """
+ hidden_states: [seq_len, batch, hidden_size]
+ """
+
+ # [seq_len, batch, inner_hidden_size]
+ intermediate_parallel = self.dense_h_to_4h(hidden_states)
+
+ intermediate_parallel = self.activation_func(intermediate_parallel)
+
+ output = self.dense_4h_to_h(intermediate_parallel)
+
+ return output
+
+
+class GLMBlock(torch.nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_attention_heads,
+ layernorm_epsilon,
+ layer_id,
+ inner_hidden_size=None,
+ hidden_size_per_attention_head=None,
+ layernorm=LayerNorm,
+ use_bias=True,
+ params_dtype=torch.float,
+ num_layers=28,
+ position_encoding_2d=True,
+ empty_init=True,
+ ):
+ super(GLMBlock, self).__init__()
+ # Set output layer initialization if not provided.
+
+ self.layer_id = layer_id
+
+ # Layernorm on the input data.
+ self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
+
+ self.position_encoding_2d = position_encoding_2d
+
+ # Self attention.
+ self.attention = SelfAttention(
+ hidden_size,
+ num_attention_heads,
+ layer_id,
+ hidden_size_per_attention_head=hidden_size_per_attention_head,
+ bias=use_bias,
+ params_dtype=params_dtype,
+ position_encoding_2d=self.position_encoding_2d,
+ empty_init=empty_init,
+ )
+
+ # Layernorm on the output of the attention block.
+ self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
+
+ self.num_layers = num_layers
+
+ # GLU
+ self.mlp = GLU(
+ hidden_size,
+ inner_hidden_size=inner_hidden_size,
+ bias=use_bias,
+ layer_id=layer_id,
+ params_dtype=params_dtype,
+ empty_init=empty_init,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_ids,
+ attention_mask: torch.Tensor,
+ layer_id,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ """
+ hidden_states: [seq_len, batch, hidden_size]
+ attention_mask: [(1, 1), seq_len, seq_len]
+ """
+
+ # Layer norm at the beginning of the transformer layer.
+ # [seq_len, batch, hidden_size]
+ attention_input = self.input_layernorm(hidden_states)
+
+ # Self attention.
+ attention_outputs = self.attention(
+ attention_input,
+ position_ids,
+ attention_mask=attention_mask,
+ layer_id=layer_id,
+ layer_past=layer_past,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ attention_output = attention_outputs[0]
+
+ outputs = attention_outputs[1:]
+
+ # Residual connection.
+ alpha = (2 * self.num_layers) ** 0.5
+ hidden_states = attention_input * alpha + attention_output
+
+ mlp_input = self.post_attention_layernorm(hidden_states)
+
+ # MLP.
+ mlp_output = self.mlp(mlp_input)
+
+ # Second residual connection.
+ output = mlp_input * alpha + mlp_output
+
+ if use_cache:
+ outputs = (output,) + outputs
+ else:
+ outputs = (output,) + outputs[1:]
+
+ return outputs # hidden_states, present, attentions
+
+
+class ChatGLMPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and
+ a simple interface for downloading and loading pretrained models.
+ """
+
+ is_parallelizable = False
+ supports_gradient_checkpointing = True
+ config_class = ChatGLMConfig
+ base_model_prefix = "transformer"
+ _no_split_modules = ["GLMBlock"]
+
+ def __init__(self, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+
+ def _init_weights(self, module: nn.Module):
+ """Initialize the weights."""
+ return
+
+ def get_masks(self, input_ids, device):
+ batch_size, seq_length = input_ids.shape
+ context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
+ attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
+ attention_mask.tril_()
+ for i, context_length in enumerate(context_lengths):
+ attention_mask[i, :, :context_length] = 1
+ attention_mask.unsqueeze_(1)
+ attention_mask = (attention_mask < 0.5).bool()
+
+ return attention_mask
+
+ def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None):
+ batch_size, seq_length = input_ids.shape
+ if use_gmasks is None:
+ use_gmasks = [False] * batch_size
+ context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
+ if self.position_encoding_2d:
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
+ for i, context_length in enumerate(context_lengths):
+ position_ids[i, context_length:] = mask_positions[i]
+ block_position_ids = [
+ torch.cat(
+ (
+ torch.zeros(context_length, dtype=torch.long, device=device),
+ torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1,
+ )
+ )
+ for context_length in context_lengths
+ ]
+ block_position_ids = torch.stack(block_position_ids, dim=0)
+ position_ids = torch.stack((position_ids, block_position_ids), dim=1)
+ else:
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
+ for i, context_length in enumerate(context_lengths):
+ if not use_gmasks[i]:
+ position_ids[i, context_length:] = mask_positions[i]
+
+ return position_ids
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, ChatGLMModel):
+ module.gradient_checkpointing = value
+
+
+CHATGLM_6B_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
+ usage and behavior.
+
+ Parameters:
+ config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the configuration.
+ Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+CHATGLM_6B_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`ChatGLM6BTokenizer`].
+ See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings.
+ Selected in the range `[0, config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert *input_ids* indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.",
+ CHATGLM_6B_START_DOCSTRING,
+)
+class ChatGLMModel(ChatGLMPreTrainedModel):
+ """
+
+ The model can behave as an encoder (with only self-attention) as well
+ as a decoder, in which case a layer of cross-attention is added between
+ the self-attention layers, following the architecture described in [Attention is
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani,
+ Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+    To behave as a decoder the model needs to be initialized with the
+    `is_decoder` argument of the configuration set to `True`.
+    To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder`
+    argument and `add_cross_attention` set to `True`; an
+    `encoder_hidden_states` is then expected as an input to the forward pass.
+ """
+
+ def __init__(self, config: ChatGLMConfig, empty_init=True):
+ super().__init__(config)
+ if empty_init:
+ init_method = skip_init
+ else:
+ init_method = default_init
+ # recording parameters
+ self.max_sequence_length = config.max_sequence_length
+ self.hidden_size = config.hidden_size
+ self.params_dtype = torch.half
+ self.num_attention_heads = config.num_attention_heads
+ self.vocab_size = config.vocab_size
+ self.num_layers = config.num_layers
+ self.layernorm_epsilon = config.layernorm_epsilon
+ self.inner_hidden_size = config.inner_hidden_size
+ self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
+ self.position_encoding_2d = config.position_encoding_2d
+ self.pre_seq_len = config.pre_seq_len
+ self.prefix_projection = config.prefix_projection
+
+ self.word_embeddings = init_method(
+ torch.nn.Embedding, num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, dtype=self.params_dtype
+ )
+ self.gradient_checkpointing = False
+
+ def get_layer(layer_id):
+ return GLMBlock(
+ self.hidden_size,
+ self.num_attention_heads,
+ self.layernorm_epsilon,
+ layer_id,
+ inner_hidden_size=self.inner_hidden_size,
+ hidden_size_per_attention_head=self.hidden_size_per_attention_head,
+ layernorm=LayerNorm,
+ use_bias=True,
+ params_dtype=self.params_dtype,
+ position_encoding_2d=self.position_encoding_2d,
+ empty_init=empty_init,
+ )
+
+ self.layers = torch.nn.ModuleList([get_layer(layer_id) for layer_id in range(self.num_layers)])
+
+ # Final layer norm before output.
+ self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon)
+
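+        # NOTE: setting pre_seq_len enables P-Tuning v2: the backbone below is frozen
+        # and only the prefix encoder remains trainable.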
+ if self.pre_seq_len is not None:
+ for param in self.parameters():
+ param.requires_grad = False
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+ self.dropout = torch.nn.Dropout(0.1)
+
+ # total_params = sum(p.numel() for p in self.parameters())
+ # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+ # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params))
+
+ def get_input_embeddings(self):
+ return self.word_embeddings
+
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
+ self.word_embeddings = new_embeddings
+
+ def get_prompt(self, batch_size, device, dtype=torch.half):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
+ past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.num_layers * 2,
+ self.num_attention_heads,
+ self.hidden_size // self.num_attention_heads,
+ )
+ # seq_len, b, nh, hidden_size
+ past_key_values = self.dropout(past_key_values)
+ past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
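+        # after permute + split: num_layers chunks of shape
+        # [2, pre_seq_len, batch, num_heads, head_dim] (key/value stacked on dim 0)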
+ # past_key_values = [(v[0], v[1]) for v in past_key_values]
+ return past_key_values
+
+ @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape[:2]
+ elif inputs_embeds is not None:
+ batch_size, seq_length = inputs_embeds.shape[:2]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ if past_key_values is None:
+ if self.pre_seq_len is not None:
+ past_key_values = self.get_prompt(
+ batch_size=input_ids.shape[0], device=input_ids.device, dtype=inputs_embeds.dtype
+ )
+ else:
+ past_key_values = tuple([None] * len(self.layers))
+
+ if attention_mask is None:
+ attention_mask = self.get_masks(input_ids, device=input_ids.device)
+
+ if position_ids is None:
+ MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
+ seqs = input_ids.tolist()
+
+ mask_positions, use_gmasks = [], []
+ for seq in seqs:
+ mask_token = gMASK if gMASK in seq else MASK
+ use_gmask = mask_token == gMASK
+ mask_positions.append(seq.index(mask_token))
+ use_gmasks.append(use_gmask)
+
+ position_ids = self.get_position_ids(
+ input_ids, mask_positions=mask_positions, device=input_ids.device, use_gmasks=use_gmasks
+ )
+
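+        # NOTE: the all-ones prefix mask flips to all-False after the (< 0.5) test
+        # below, so prefix positions are never masked out.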
+ if self.pre_seq_len is not None and attention_mask is not None:
+ prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
+ attention_mask.device
+ )
+ prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
+
+ # [seq_len, batch, hidden_size]
+ hidden_states = inputs_embeds.transpose(0, 1)
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ if attention_mask is None:
+ attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
+ else:
+ attention_mask = attention_mask.to(hidden_states.device)
+
+ for i, layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+ layer_past = past_key_values[i]
+
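+            # NOTE: arguments are passed positionally because torch.utils.checkpoint's
+            # reentrant mode does not accept keyword arguments.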
+ if self.gradient_checkpointing and self.training:
+ layer_ret = torch.utils.checkpoint.checkpoint(
+ layer,
+ hidden_states,
+ position_ids,
+ attention_mask,
+ torch.tensor(i),
+ layer_past,
+ use_cache,
+ output_attentions,
+ )
+ else:
+ layer_ret = layer(
+ hidden_states,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ layer_id=torch.tensor(i),
+ layer_past=layer_past,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_ret[0]
+
+ if use_cache:
+ presents = presents + (layer_ret[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],)
+
+ # Final layer norm.
+ hidden_states = self.final_layernorm(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
+ def __init__(self, config: ChatGLMConfig, empty_init=True):
+ super().__init__(config)
+ if empty_init:
+ init_method = skip_init
+ else:
+ init_method = default_init
+
+ # self.hidden_size = config.hidden_size
+ # self.params_dtype = torch.half
+ # self.vocab_size = config.vocab_size
+ self.max_sequence_length = config.max_sequence_length
+
+ self.position_encoding_2d = config.position_encoding_2d
+
+ self.transformer = ChatGLMModel(config, empty_init=empty_init)
+
+ self.lm_head = init_method(nn.Linear, config.hidden_size, config.vocab_size, bias=False, dtype=torch.half)
+
+ self.config = config
+
+ self.quantized = False
+
+ if self.config.quantization_bit:
+ self.quantize(self.config.quantization_bit, empty_init=True)
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def _update_model_kwargs_for_generation(
+ self,
+ outputs: ModelOutput,
+ model_kwargs: Dict[str, Any],
+ is_encoder_decoder: bool = False,
+ standardize_cache_format: bool = False,
+ ) -> Dict[str, Any]:
+ # update past_key_values
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+ outputs, standardize_cache_format=standardize_cache_format
+ )
+
+ # update attention mask
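+        # (append a masked key column for the new token, then a query row cloned from
+        #  the last one with the new token's own position unmasked)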
+ if "attention_mask" in model_kwargs:
+ attention_mask = model_kwargs["attention_mask"]
+ if attention_mask is not None and attention_mask.dtype == torch.bool:
+ attention_mask = torch.cat(
+ [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3
+ )
+ new_attention_mask = attention_mask[:, :, -1:].clone()
+ new_attention_mask[..., -1] = False
+ model_kwargs["attention_mask"] = torch.cat([attention_mask, new_attention_mask], dim=2)
+
+ # update position ids
+ if "position_ids" in model_kwargs:
+ position_ids = model_kwargs["position_ids"]
+ new_position_id = position_ids[..., -1:].clone()
+ new_position_id[:, 1, :] += 1
+ model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)
+
+ return model_kwargs
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids: torch.LongTensor,
+ past: Optional[torch.Tensor] = None,
+ past_key_values: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> dict:
+ batch_size, seq_length = input_ids.shape
+ MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
+ seqs = input_ids.tolist()
+ mask_positions, use_gmasks = [], []
+ for seq in seqs:
+ mask_token = gMASK if gMASK in seq else MASK
+ use_gmask = mask_token == gMASK
+ mask_positions.append(seq.index(mask_token))
+ use_gmasks.append(use_gmask)
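+        # NOTE: ChatGLM prompts carry a [MASK]/[gMASK] placeholder; its index anchors
+        # the position ids of everything generated afterwards.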
+
+ # only last token for input_ids if past is not None
+ if past is not None or past_key_values is not None:
+ last_token = input_ids[:, -1].unsqueeze(-1)
+ if attention_mask is not None and attention_mask.dtype == torch.bool:
+ attention_mask = attention_mask[:, :, -1:]
+ else:
+ attention_mask = None
+ if position_ids is not None:
+ position_ids = position_ids[..., -1:]
+ else:
+ context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
+ if self.position_encoding_2d:
+ position_ids = torch.tensor(
+ [
+ [mask_position, seq_length - context_length]
+ for mask_position, context_length in zip(mask_positions, context_lengths)
+ ],
+ dtype=torch.long,
+ device=input_ids.device,
+ ).unsqueeze(-1)
+ else:
+ position_ids = torch.tensor(
+ [mask_position for mask_position in mask_positions], dtype=torch.long, device=input_ids.device
+ ).unsqueeze(-1)
+
+ if past is None:
+ past = past_key_values
+ return {
+ "input_ids": last_token,
+ "past_key_values": past,
+ "position_ids": position_ids,
+ "attention_mask": attention_mask,
+ }
+ else:
+ if attention_mask is not None and attention_mask.dtype != torch.bool:
+ logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
+ attention_mask = None
+ if attention_mask is None:
+ attention_mask = self.get_masks(input_ids, device=input_ids.device)
+ if position_ids is None:
+ position_ids = self.get_position_ids(
+ input_ids, device=input_ids.device, mask_positions=mask_positions, use_gmasks=use_gmasks
+ )
+
+ return {
+ "input_ids": input_ids,
+ "past_key_values": past,
+ "position_ids": position_ids,
+ "attention_mask": attention_mask,
+ }
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = transformer_outputs[0]
+
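+        # logits come out as [seq_len, batch, vocab]; permute back to [batch, seq_len, vocab]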
+ lm_logits = self.lm_head(hidden_states).permute(1, 0, 2).contiguous()
+
+ loss = None
+ if labels is not None:
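+            # compute the loss in fp32 for numerical stability, then cast back below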
+ lm_logits = lm_logits.to(torch.float32)
+
+ # Shift so that tokens < n predict n
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+ lm_logits = lm_logits.to(hidden_states.dtype)
+ loss = loss.to(hidden_states.dtype)
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ @staticmethod
+ def _reorder_cache(
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
+ """
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+ beam_idx at every generation step.
+
+ Output shares the same memory storage as `past`.
+ """
+ return tuple(
+ (
+ layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
+ layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
+ )
+ for layer_past in past
+ )
+
+ def process_response(self, response):
+ response = response.strip()
+ response = response.replace("[[训练时间]]", "2023年")
+ punkts = [
+ [",", ","],
+ ["!", "!"],
+ [":", ":"],
+ [";", ";"],
+ ["\?", "?"],
+ ]
+ for item in punkts:
+ response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
+ response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
+ return response
+
+ @torch.no_grad()
+ def chat(
+ self,
+ tokenizer,
+ query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+ max_length: int = 2048,
+ num_beams=1,
+ do_sample=True,
+ top_p=0.7,
+ temperature=0.95,
+ logits_processor=None,
+ **kwargs,
+ ):
+ if history is None:
+ history = []
+ if logits_processor is None:
+ logits_processor = LogitsProcessorList()
+ logits_processor.append(InvalidScoreLogitsProcessor())
+ gen_kwargs = {
+ "max_length": max_length,
+ "num_beams": num_beams,
+ "do_sample": do_sample,
+ "top_p": top_p,
+ "temperature": temperature,
+ "logits_processor": logits_processor,
+ **kwargs,
+ }
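+        # Build the multi-turn prompt in ChatGLM's "[Round i]\n问:...\n答:..." format.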
+ if not history:
+ prompt = query
+ else:
+ prompt = ""
+ for i, (old_query, response) in enumerate(history):
+ prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
+ prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+ inputs = tokenizer([prompt], return_tensors="pt")
+ inputs = inputs.to(self.device)
+ outputs = self.generate(**inputs, **gen_kwargs)
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
+ response = tokenizer.decode(outputs)
+ response = self.process_response(response)
+ history = history + [(query, response)]
+ return response, history
+
+ @torch.no_grad()
+ def stream_chat(
+ self,
+ tokenizer,
+ query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+ max_length: int = 2048,
+ do_sample=True,
+ top_p=0.7,
+ temperature=0.95,
+ logits_processor=None,
+ **kwargs,
+ ):
+ if history is None:
+ history = []
+ if logits_processor is None:
+ logits_processor = LogitsProcessorList()
+ logits_processor.append(InvalidScoreLogitsProcessor())
+ gen_kwargs = {
+ "max_length": max_length,
+ "do_sample": do_sample,
+ "top_p": top_p,
+ "temperature": temperature,
+ "logits_processor": logits_processor,
+ **kwargs,
+ }
+ if not history:
+ prompt = query
+ else:
+ prompt = ""
+ for i, (old_query, response) in enumerate(history):
+ prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
+ prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+ inputs = tokenizer([prompt], return_tensors="pt")
+ inputs = inputs.to(self.device)
+ for outputs in self.stream_generate(**inputs, **gen_kwargs):
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
+ response = tokenizer.decode(outputs)
+ response = self.process_response(response)
+ new_history = history + [(query, response)]
+ yield response, new_history
+
+ @torch.no_grad()
+ def stream_generate(
+ self,
+ input_ids,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+ **kwargs,
+ ):
+ batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
+
+ if generation_config is None:
+ generation_config = self.generation_config
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs)
+ bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
+
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+ if has_default_max_length and generation_config.max_new_tokens is None:
+ warnings.warn(
+ f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+ "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
+ UserWarning,
+ )
+ elif generation_config.max_new_tokens is not None:
+ generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+ if not has_default_max_length:
+                logger.warning(
+                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length` (="
+                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+                )
+
+ if input_ids_seq_length >= generation_config.max_length:
+ input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+ logger.warning(
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+ " increasing `max_new_tokens`."
+ )
+
+ # 2. Set generation parameters if not already defined
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+ logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_seq_length,
+ encoder_input_ids=input_ids,
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+ logits_processor=logits_processor,
+ )
+
+ stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria
+ )
+ logits_warper = self._get_logits_warper(generation_config)
+
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+ scores = None
+ while True:
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+ # forward pass to get next token
+ outputs = self(
+ **model_inputs,
+ return_dict=True,
+ output_attentions=False,
+ output_hidden_states=False,
+ )
+
+ next_token_logits = outputs.logits[:, -1, :]
+
+ # pre-process distribution
+ next_token_scores = logits_processor(input_ids, next_token_logits)
+ next_token_scores = logits_warper(input_ids, next_token_scores)
+
+ # sample
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
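+            # multinomial sampling from the warped distribution; greedy argmax otherwise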
+ if generation_config.do_sample:
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+ else:
+ next_tokens = torch.argmax(probs, dim=-1)
+
+ # update generated ids, model inputs, and length for next step
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+ model_kwargs = self._update_model_kwargs_for_generation(
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+ )
+ unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
+
+ # stop when each sentence is finished, or if we exceed the maximum length
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+ break
+ yield input_ids
+
+ def quantize(self, bits: int, empty_init=False, **kwargs):
+ if bits == 0:
+ return
+
+ from .quantization import quantize
+
+ if self.quantized:
+ logger.info("Already quantized.")
+ return self
+
+ self.quantized = True
+
+ self.config.quantization_bit = bits
+
+ self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs)
+ return self
diff --git a/applications/Chat/coati/models/deberta/__init__.py b/applications/Chat/coati/models/deberta/__init__.py
deleted file mode 100644
index b66888f34fd0b81b726b28dec39b9ed26d99c945..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/models/deberta/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .deberta_critic import DebertaCritic
-from .deberta_rm import DebertaRM
-
-__all__ = ['DebertaCritic', 'DebertaRM']
diff --git a/applications/Chat/coati/models/deberta/deberta_critic.py b/applications/Chat/coati/models/deberta/deberta_critic.py
deleted file mode 100644
index e84c1dbd8380728a6544c50be2b82146821fb3c3..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/models/deberta/deberta_critic.py
+++ /dev/null
@@ -1,36 +0,0 @@
-from typing import Optional
-
-import torch.nn as nn
-from transformers import DebertaV2Config, DebertaV2Model
-
-from ..base import Critic
-
-
-class DebertaCritic(Critic):
- """
- Deberta Critic model.
-
- Args:
- pretrained (str): Pretrained model name or path.
- config (DebertaV2Config): Model config.
- checkpoint (bool): Enable gradient checkpointing.
- lora_rank (int): Rank of the LO-RA decomposition.
- lora_train_bias (str): LoRA bias training mode.
- """
-
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[DebertaV2Config] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
- if pretrained is not None:
- model = DebertaV2Model.from_pretrained(pretrained)
- elif config is not None:
- model = DebertaV2Model(config)
- else:
- model = DebertaV2Model(DebertaV2Config())
- if checkpoint:
- model.gradient_checkpointing_enable()
- value_head = nn.Linear(model.config.hidden_size, 1)
- super().__init__(model, value_head, lora_rank, lora_train_bias)
diff --git a/applications/Chat/coati/models/deberta/deberta_rm.py b/applications/Chat/coati/models/deberta/deberta_rm.py
deleted file mode 100644
index 2448c879ec859ebfd13afd859508e28b63427c06..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/models/deberta/deberta_rm.py
+++ /dev/null
@@ -1,37 +0,0 @@
-from typing import Optional
-
-import torch.nn as nn
-from transformers import DebertaV2Config, DebertaV2Model
-
-from ..base import RewardModel
-
-
-class DebertaRM(RewardModel):
- """
- Deberta Reward model.
-
- Args:
- pretrained (str): Pretrained model name or path.
- config (DebertaV2Config): Model config.
- checkpoint (bool): Enable gradient checkpointing.
- lora_rank (int): Rank of the LO-RA decomposition.
- lora_train_bias (str): LoRA bias training mode.
- """
-
- def __init__(self,
- pretrained: str = None,
- config: Optional[DebertaV2Config] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
- if pretrained is not None:
- model = DebertaV2Model.from_pretrained(pretrained)
- elif config is not None:
- model = DebertaV2Model(config)
- else:
- model = DebertaV2Model(DebertaV2Config())
- if checkpoint:
- model.gradient_checkpointing_enable()
- value_head = nn.Linear(model.config.hidden_size, 1)
- value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
- super().__init__(model, value_head, lora_rank, lora_train_bias)
diff --git a/applications/Chat/coati/models/generation.py b/applications/Chat/coati/models/generation.py
index f57c9458a271131a14d6b035cd039d2a5de1281b..4ab0cdc8a3eaa21afb5b342f0f12cfea7dddd0b0 100644
--- a/applications/Chat/coati/models/generation.py
+++ b/applications/Chat/coati/models/generation.py
@@ -2,7 +2,9 @@ from typing import Any, Callable, Optional
import torch
import torch.distributed as dist
-import torch.nn as nn
+from transformers import PreTrainedTokenizer
+
+from .base import Actor
try:
from transformers.generation_logits_process import (
@@ -15,9 +17,9 @@ except ImportError:
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
-def prepare_logits_processor(top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- temperature: Optional[float] = None) -> LogitsProcessorList:
+def _prepare_logits_processor(
+ top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
+) -> LogitsProcessorList:
processor_list = LogitsProcessorList()
if temperature is not None and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
@@ -36,32 +38,34 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
return unfinished_sequences.max() == 0
-def sample(model: nn.Module,
- input_ids: torch.Tensor,
- max_length: int,
- early_stopping: bool = False,
- eos_token_id: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- temperature: Optional[float] = None,
- prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
- update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
- **model_kwargs) -> torch.Tensor:
+def _sample(
+ model: Actor,
+ input_ids: torch.Tensor,
+ max_length: int,
+ early_stopping: bool = False,
+ eos_token_id: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ temperature: Optional[float] = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+ **model_kwargs,
+) -> torch.Tensor:
if input_ids.size(1) >= max_length:
return input_ids
- logits_processor = prepare_logits_processor(top_k, top_p, temperature)
+ logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
for _ in range(input_ids.size(1), max_length):
- model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
- 'input_ids': input_ids
- }
+ model_inputs = (
+ prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids}
+ )
outputs = model(**model_inputs)
- next_token_logits = outputs['logits'][:, -1, :]
- # pre-process distribution
+ # NOTE: this is correct only in left padding mode
+ next_token_logits = outputs["logits"][:, -1, :]
next_token_logits = logits_processor(input_ids, next_token_logits)
# sample
probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
@@ -69,8 +73,7 @@ def sample(model: nn.Module,
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
- if pad_token_id is None:
- raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+ assert pad_token_id is not None, "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs for next step
@@ -89,20 +92,22 @@ def sample(model: nn.Module,
return input_ids
-def generate(model: nn.Module,
- input_ids: torch.Tensor,
- max_length: int,
- num_beams: int = 1,
- do_sample: bool = True,
- early_stopping: bool = False,
- eos_token_id: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- temperature: Optional[float] = None,
- prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
- update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
- **model_kwargs) -> torch.Tensor:
+@torch.no_grad()
+def generate(
+ model: Actor,
+ input_ids: torch.Tensor,
+ tokenizer: PreTrainedTokenizer,
+ max_length: int,
+ num_beams: int = 1,
+ do_sample: bool = True,
+ early_stopping: bool = False,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ temperature: Optional[float] = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+ **model_kwargs,
+) -> torch.Tensor:
"""Generate token sequence. The returned sequence is input_ids + generated_tokens.
Args:
@@ -112,34 +117,35 @@ def generate(model: nn.Module,
num_beams (int, optional): number of beams. Defaults to 1.
do_sample (bool, optional): whether to do sample. Defaults to True.
early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False.
- eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None.
- pad_token_id (Optional[int], optional): pad token id. Defaults to None.
top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None.
prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
"""
- is_greedy_gen_mode = ((num_beams == 1) and do_sample is False)
- is_sample_gen_mode = ((num_beams == 1) and do_sample is True)
- is_beam_gen_mode = ((num_beams > 1) and do_sample is False)
+ assert tokenizer.padding_side == "left", "Current generation only supports left padding."
+ is_greedy_gen_mode = (num_beams == 1) and do_sample is False
+ is_sample_gen_mode = (num_beams == 1) and do_sample is True
+ is_beam_gen_mode = (num_beams > 1) and do_sample is False
if is_greedy_gen_mode:
# run greedy search
raise NotImplementedError
elif is_sample_gen_mode:
# run sample
- return sample(model,
- input_ids,
- max_length,
- early_stopping=early_stopping,
- eos_token_id=eos_token_id,
- pad_token_id=pad_token_id,
- top_k=top_k,
- top_p=top_p,
- temperature=temperature,
- prepare_inputs_fn=prepare_inputs_fn,
- update_model_kwargs_fn=update_model_kwargs_fn,
- **model_kwargs)
+ return _sample(
+ model,
+ input_ids,
+ max_length,
+ early_stopping=early_stopping,
+ eos_token_id=tokenizer.eos_token_id,
+ pad_token_id=tokenizer.pad_token_id,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ prepare_inputs_fn=prepare_inputs_fn,
+ update_model_kwargs_fn=update_model_kwargs_fn,
+ **model_kwargs,
+ )
elif is_beam_gen_mode:
raise NotImplementedError
else:
diff --git a/applications/Chat/coati/models/gpt/__init__.py b/applications/Chat/coati/models/gpt/__init__.py
index 63dc5ab0f5ead4450e1f9bffae874520614fa05c..823cf4a75e0d4fe8cdd9de9e37947f0bf45e5657 100644
--- a/applications/Chat/coati/models/gpt/__init__.py
+++ b/applications/Chat/coati/models/gpt/__init__.py
@@ -2,4 +2,4 @@ from .gpt_actor import GPTActor
from .gpt_critic import GPTCritic
from .gpt_rm import GPTRM
-__all__ = ['GPTActor', 'GPTCritic', 'GPTRM']
+__all__ = ["GPTActor", "GPTCritic", "GPTRM"]
diff --git a/applications/Chat/coati/models/gpt/gpt_actor.py b/applications/Chat/coati/models/gpt/gpt_actor.py
index ae9d669f1f5669dc63e829332ccdf7ebf991afe2..a7e4b9bc3e22425da707734a49a902a72368c011 100644
--- a/applications/Chat/coati/models/gpt/gpt_actor.py
+++ b/applications/Chat/coati/models/gpt/gpt_actor.py
@@ -18,13 +18,15 @@ class GPTActor(Actor):
lora_train_bias (str): Bias training strategy for the LoRa layer.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[GPT2Config] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none',
- **kwargs) -> None:
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[GPT2Config] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ **kwargs,
+ ) -> None:
if pretrained is not None:
model = GPT2LMHeadModel.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/gpt/gpt_critic.py b/applications/Chat/coati/models/gpt/gpt_critic.py
index 2e70f5f1fc9632edbd2d83ef7b0067f7accd0eda..22ab36dea276cfeefffed917b2233cacab736b11 100644
--- a/applications/Chat/coati/models/gpt/gpt_critic.py
+++ b/applications/Chat/coati/models/gpt/gpt_critic.py
@@ -14,25 +14,24 @@ class GPTCritic(Critic):
Args:
pretrained (str): Pretrained model name or path.
config (GPT2Config): Model config.
- checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the LO-RA decomposition.
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[GPT2Config] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none',
- **kwargs) -> None:
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[GPT2Config] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ **kwargs,
+ ) -> None:
if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained)
elif config is not None:
model = GPT2Model(config)
else:
model = GPT2Model(GPT2Config())
- if checkpoint:
- model.gradient_checkpointing_enable()
+
value_head = nn.Linear(model.config.n_embd, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
diff --git a/applications/Chat/coati/models/gpt/gpt_rm.py b/applications/Chat/coati/models/gpt/gpt_rm.py
index 054432e1ce863a36c169ea3b7569198ab1994e6c..8edfc4008466bbfd5ab303e36877cd0d8f5d3a50 100644
--- a/applications/Chat/coati/models/gpt/gpt_rm.py
+++ b/applications/Chat/coati/models/gpt/gpt_rm.py
@@ -14,25 +14,23 @@ class GPTRM(RewardModel):
Args:
pretrained (str): Pretrained model name or path.
config (GPT2Config): Model config.
- checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[GPT2Config] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[GPT2Config] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained)
elif config is not None:
model = GPT2Model(config)
else:
model = GPT2Model(GPT2Config())
- if checkpoint:
- model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.n_embd, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1))
diff --git a/applications/Chat/coati/models/llama/__init__.py b/applications/Chat/coati/models/llama/__init__.py
index 9b2a024afdb28d3336d801895da3f26b01f28c56..c87d732538a9f0357902368c5d5a90b9174e1ead 100644
--- a/applications/Chat/coati/models/llama/__init__.py
+++ b/applications/Chat/coati/models/llama/__init__.py
@@ -2,4 +2,4 @@ from .llama_actor import LlamaActor
from .llama_critic import LlamaCritic
from .llama_rm import LlamaRM
-__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']
+__all__ = ["LlamaActor", "LlamaCritic", "LlamaRM"]
diff --git a/applications/Chat/coati/models/llama/llama_actor.py b/applications/Chat/coati/models/llama/llama_actor.py
index 2c7adb390d8bea055e9fc84f75dc66658a8fe0e3..f1d9406835ca889138a465f28eb8fbb997226005 100644
--- a/applications/Chat/coati/models/llama/llama_actor.py
+++ b/applications/Chat/coati/models/llama/llama_actor.py
@@ -1,7 +1,6 @@
from typing import Optional
-import torch
-from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
+from transformers import LlamaConfig, LlamaForCausalLM
from ..base import Actor
@@ -18,13 +17,14 @@ class LlamaActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[LlamaConfig] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
-
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[LlamaConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = LlamaForCausalLM.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/llama/llama_critic.py b/applications/Chat/coati/models/llama/llama_critic.py
index dd9e5e7bfa1ae6da1c9c09789ae0eda602a35720..000dce17ccf013d510f98ac1ac24885bf5af757c 100644
--- a/applications/Chat/coati/models/llama/llama_critic.py
+++ b/applications/Chat/coati/models/llama/llama_critic.py
@@ -13,19 +13,18 @@ class LlamaCritic(Critic):
Args:
pretrained (str): Pretrained model name or path.
config (LlamaConfig): Model config.
- checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[LlamaConfig] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none',
- **kwargs) -> None:
-
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[LlamaConfig] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ **kwargs,
+ ) -> None:
if pretrained is not None:
model = LlamaModel.from_pretrained(pretrained)
elif config is not None:
@@ -33,9 +32,5 @@ class LlamaCritic(Critic):
else:
model = LlamaModel(LlamaConfig())
- if checkpoint:
- model.gradient_checkpointing_enable()
-
value_head = nn.Linear(model.config.hidden_size, 1)
-
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
diff --git a/applications/Chat/coati/models/llama/llama_rm.py b/applications/Chat/coati/models/llama/llama_rm.py
index f936019d62d28bac0bf41161b3f4aaad26e2bbf0..43bc9e638dc7944191cd903cca1134a15dbfe3dc 100644
--- a/applications/Chat/coati/models/llama/llama_rm.py
+++ b/applications/Chat/coati/models/llama/llama_rm.py
@@ -1,7 +1,7 @@
from typing import Optional
import torch.nn as nn
-from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
+from transformers import LlamaConfig, LlamaModel
from ..base import RewardModel
@@ -13,18 +13,17 @@ class LlamaRM(RewardModel):
Args:
pretrained (str): Pretrained model name or path.
config (LlamaConfig): Model config.
- checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[LlamaConfig] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
-
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[LlamaConfig] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = LlamaModel.from_pretrained(pretrained)
elif config is not None:
@@ -32,8 +31,6 @@ class LlamaRM(RewardModel):
else:
model = LlamaModel(LlamaConfig())
- if checkpoint:
- model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.hidden_size, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
diff --git a/applications/Chat/coati/models/lora.py b/applications/Chat/coati/models/lora.py
index 0533a60dc53266d592fe2b5808eb637ac64db875..e9bd7b2ed8f0608e9fd6c4c26a39340d087a454f 100644
--- a/applications/Chat/coati/models/lora.py
+++ b/applications/Chat/coati/models/lora.py
@@ -1,4 +1,6 @@
+import dataclasses
import math
+import warnings
from typing import Optional
import loralib as lora
@@ -7,9 +9,16 @@ import torch.nn as nn
import torch.nn.functional as F
+@dataclasses.dataclass
+class LoRAManager:
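+    """Process-wide switch read by LoraLinear.train(): when merge_weights is True,
+    entering eval() merges the low-rank update into the base weight and entering
+    train() unmerges it."""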
+ merge_weights: bool = False
+
+
+LORA_MANAGER = LoRAManager()
+
+
class LoraLinear(lora.LoRALayer, nn.Module):
- """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.
- """
+ """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
def __init__(
self,
@@ -17,16 +26,12 @@ class LoraLinear(lora.LoRALayer, nn.Module):
bias: Optional[nn.Parameter],
r: int = 0,
lora_alpha: int = 1,
- lora_dropout: float = 0.,
- fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
- merge_weights: bool = True,
+ lora_dropout: float = 0.0,
+ # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+ fan_in_fan_out: bool = False,
):
nn.Module.__init__(self)
- lora.LoRALayer.__init__(self,
- r=r,
- lora_alpha=lora_alpha,
- lora_dropout=lora_dropout,
- merge_weights=merge_weights)
+ lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
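+        # merge_weights is pinned to False: merging is driven globally through
+        # LORA_MANAGER.merge_weights (see train() below) rather than by loralib.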
self.weight = weight
self.bias = bias
@@ -47,39 +52,42 @@ class LoraLinear(lora.LoRALayer, nn.Module):
self.weight.data = self.weight.data.T
def reset_parameters(self):
- if hasattr(self, 'lora_A'):
- # initialize A the same way as the default for nn.Linear and B to zero
+ if hasattr(self, "lora_A"):
+ # Initialize A with the default values for nn.Linear and set B to zero.
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def train(self, mode: bool = True):
-
- def T(w):
- return w.T if self.fan_in_fan_out else w
-
- nn.Module.train(self, mode)
- if self.merge_weights and self.merged:
- # Make sure that the weights are not merged
- if self.r > 0:
- self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
- self.merged = False
-
- def eval(self):
-
def T(w):
return w.T if self.fan_in_fan_out else w
- nn.Module.eval(self)
- if self.merge_weights and not self.merged:
- # Merge the weights and mark it
- if self.r > 0:
- self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
- delattr(self, 'lora_A')
- delattr(self, 'lora_B')
- self.merged = True
+ self.training = mode
+ if LORA_MANAGER.merge_weights:
+ if mode and self.merged:
+ warnings.warn("Invoke module.train() would unmerge LoRA weights.")
+ raise NotImplementedError("LoRA unmerge is not tested.")
+ # Make sure that the weights are not merged
+ if self.r > 0:
+ if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
+ # FIXME(csric): temporary fix
+ self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
+ self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
+ self.reset_parameters()
+ else:
+ self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+ self.merged = False
+ elif not mode and not self.merged:
+ warnings.warn("Invoke module.eval() would merge LoRA weights.")
+ # Merge the weights and mark it
+ if self.r > 0:
+ self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+ delattr(self, "lora_A")
+ delattr(self, "lora_B")
+ self.merged = True
+
+ return self
def forward(self, x: torch.Tensor):
-
def T(w):
return w.T if self.fan_in_fan_out else w
@@ -92,21 +100,23 @@ class LoraLinear(lora.LoRALayer, nn.Module):
return F.linear(x, T(self.weight), bias=self.bias)
-def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
- assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
- lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
+def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
+ assert (
+ lora_rank <= linear.in_features
+ ), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
+ lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank)
return lora_linear
-def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
+def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
for name, child in module.named_children():
if isinstance(child, nn.Linear):
- setattr(module, name, lora_linear_wrapper(child, lora_rank))
+ setattr(module, name, _lora_linear_wrapper(child, lora_rank))
else:
- convert_to_lora_recursively(child, lora_rank)
+ _convert_to_lora_recursively(child, lora_rank)
-def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
+def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module:
"""Convert a torch.nn.Module to a LoRA module.
Args:
@@ -118,7 +128,7 @@ def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: s
"""
if lora_rank <= 0:
return module
- convert_to_lora_recursively(module, lora_rank)
+ _convert_to_lora_recursively(module, lora_rank)
lora.mark_only_lora_as_trainable(module, lora_train_bias)
return module
@@ -134,7 +144,7 @@ class LoRAModule(nn.Module):
Defaults to 'none'.
"""
- def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
+ def __init__(self, lora_rank: int = 0, lora_train_bias: str = "none") -> None:
super().__init__()
self.lora_rank = lora_rank
self.lora_train_bias = lora_train_bias
diff --git a/applications/Chat/coati/models/loss.py b/applications/Chat/coati/models/loss.py
index 926c6e2a4e4131ece0693630bce4894e8cbec5f0..687bd0f7bfe76fc8865569f742fd31b3c8d4f868 100644
--- a/applications/Chat/coati/models/loss.py
+++ b/applications/Chat/coati/models/loss.py
@@ -13,6 +13,7 @@ class GPTLMLoss(nn.Module):
def __init__(self):
super().__init__()
+ # NOTE: default ignore_index is -100, which is equal to IGNORE_INDEX in sft_dataset.py
self.loss = nn.CrossEntropyLoss()
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
@@ -31,11 +32,13 @@ class PolicyLoss(nn.Module):
super().__init__()
self.clip_eps = clip_eps
- def forward(self,
- log_probs: torch.Tensor,
- old_log_probs: torch.Tensor,
- advantages: torch.Tensor,
- action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(
+ self,
+ log_probs: torch.Tensor,
+ old_log_probs: torch.Tensor,
+ advantages: torch.Tensor,
+ action_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
ratio = (log_probs - old_log_probs).exp()
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
@@ -55,44 +58,21 @@ class ValueLoss(nn.Module):
super().__init__()
self.clip_eps = clip_eps
- def forward(self,
- values: torch.Tensor,
- old_values: torch.Tensor,
- reward: torch.Tensor,
- action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(
+ self,
+ values: torch.Tensor,
+ old_values: torch.Tensor,
+ reward: torch.Tensor,
+ action_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
- surr1 = (values_clipped - reward)**2
- surr2 = (values - reward)**2
+ surr1 = (values_clipped - reward) ** 2
+ surr2 = (values - reward) ** 2
loss = torch.max(surr1, surr2)
loss = loss.mean()
return 0.5 * loss
-class PPOPtxActorLoss(nn.Module):
- """
- To Do:
-
- PPO-ptx Actor Loss
- """
-
- def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None:
- super().__init__()
- self.pretrain_coef = pretrain_coef
- self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
- self.pretrain_loss_fn = pretrain_loss_fn
-
- def forward(self,
- log_probs: torch.Tensor,
- old_log_probs: torch.Tensor,
- advantages: torch.Tensor,
- lm_logits: torch.Tensor,
- lm_input_ids: torch.Tensor,
- action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
- policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
- lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
- return policy_loss + self.pretrain_coef * lm_loss
-
-
class LogSigLoss(nn.Module):
"""
Pairwise Loss for Reward Model
diff --git a/applications/Chat/coati/models/opt/__init__.py b/applications/Chat/coati/models/opt/__init__.py
index 334f4df0032a1da8f8a3d23c9c987448c9dbc0d8..e37d6e45c8fc9deee35841878857b7d29fd7eed3 100644
--- a/applications/Chat/coati/models/opt/__init__.py
+++ b/applications/Chat/coati/models/opt/__init__.py
@@ -2,4 +2,4 @@ from .opt_actor import OPTActor
from .opt_critic import OPTCritic
from .opt_rm import OPTRM
-__all__ = ['OPTActor', 'OPTCritic', 'OPTRM']
+__all__ = ["OPTActor", "OPTCritic", "OPTRM"]
diff --git a/applications/Chat/coati/models/opt/opt_actor.py b/applications/Chat/coati/models/opt/opt_actor.py
index c14e4377ffb2b00983f4ea4a1af7c9931e9d1cb9..cd8908e13fb8cc62de664023d491eb6090dc74b6 100644
--- a/applications/Chat/coati/models/opt/opt_actor.py
+++ b/applications/Chat/coati/models/opt/opt_actor.py
@@ -18,12 +18,14 @@ class OPTActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[OPTConfig] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[OPTConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = OPTForCausalLM.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/coati/models/opt/opt_critic.py b/applications/Chat/coati/models/opt/opt_critic.py
index fcfebd8a8b031785d93d709aead85d5aa6f30d08..f37d28812c27e184431245894ee506fe725696db 100644
--- a/applications/Chat/coati/models/opt/opt_critic.py
+++ b/applications/Chat/coati/models/opt/opt_critic.py
@@ -14,25 +14,24 @@ class OPTCritic(Critic):
Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
- checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[OPTConfig] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none',
- **kwargs) -> None:
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[OPTConfig] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ **kwargs,
+ ) -> None:
if pretrained is not None:
model = OPTModel.from_pretrained(pretrained)
elif config is not None:
model = OPTModel(config)
else:
model = OPTModel(OPTConfig())
- if checkpoint:
- model.gradient_checkpointing_enable()
+
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
diff --git a/applications/Chat/coati/models/opt/opt_rm.py b/applications/Chat/coati/models/opt/opt_rm.py
index 50fc0dee8568f86d8d63e712104306e8c1e013b8..893708344ad4c5e0503b64f3118999c4e3aaa91b 100644
--- a/applications/Chat/coati/models/opt/opt_rm.py
+++ b/applications/Chat/coati/models/opt/opt_rm.py
@@ -13,25 +13,23 @@ class OPTRM(RewardModel):
Args:
pretrained (str): Pretrained model name or path.
config (OPTConfig): Model config.
- checkpoint (bool): Enable gradient checkpointing.
lora_rank (int): Rank of the low-rank approximation.
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[OPTConfig] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
+ def __init__(
+ self,
+ pretrained: Optional[str] = None,
+ config: Optional[OPTConfig] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = "none",
+ ) -> None:
if pretrained is not None:
model = OPTModel.from_pretrained(pretrained)
elif config is not None:
model = OPTModel(config)
else:
model = OPTModel(OPTConfig())
- if checkpoint:
- model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.word_embed_proj_dim + 1))
diff --git a/applications/Chat/coati/models/roberta/__init__.py b/applications/Chat/coati/models/roberta/__init__.py
deleted file mode 100644
index 0f4a8de067b1695c4abbc880f2e3dc1f6510ad26..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/models/roberta/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .roberta_actor import RoBERTaActor
-from .roberta_critic import RoBERTaCritic
-from .roberta_rm import RoBERTaRM
-
-__all__ = ['RoBERTaActor', 'RoBERTaCritic', 'RoBERTaRM']
\ No newline at end of file
diff --git a/applications/Chat/coati/models/roberta/roberta_actor.py b/applications/Chat/coati/models/roberta/roberta_actor.py
deleted file mode 100644
index e35fa6eb19a8053a2ea5cec3f2a07a2dd3c80735..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/models/roberta/roberta_actor.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from typing import Optional
-
-from transformers.models.roberta.configuration_roberta import RobertaConfig
-from transformers.models.roberta.modeling_roberta import RobertaForCausalLM
-
-from ..base import Actor
-
-class RoBERTaActor(Actor):
- """
- RoBERTa Actor model.
-
- Args:
- pretrained (str): Pretrained model name or path.
- config (RoBERTaConfig): Model config.
- checkpoint (bool): Enable gradient checkpointing.
- lora_rank (int): Rank of the low-rank approximation.
- lora_train_bias (str): LoRA bias training mode.
- """
-
-
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[RobertaConfig] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
- if pretrained is not None:
- model = RobertaForCausalLM.from_pretrained(pretrained)
- elif config is not None:
- model = RobertaForCausalLM(config)
- else:
- model = RobertaForCausalLM(RobertaConfig())
- if checkpoint:
- model.gradient_checkpointing_enable()
- super().__init__(model, lora_rank, lora_train_bias)
diff --git a/applications/Chat/coati/models/roberta/roberta_critic.py b/applications/Chat/coati/models/roberta/roberta_critic.py
deleted file mode 100644
index c8dc0d9e14f2813907a6345ffbf93784e9b8528c..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/models/roberta/roberta_critic.py
+++ /dev/null
@@ -1,38 +0,0 @@
-from typing import Optional
-
-import torch.nn as nn
-from transformers.models.roberta.configuration_roberta import RobertaConfig
-from transformers.models.roberta.modeling_roberta import RobertaModel
-
-from ..base import Critic
-
-
-class RoBERTaCritic(Critic):
- """
- RoBERTa Critic model.
-
- Args:
- pretrained (str): Pretrained model name or path.
- config (RoBERTa Config): Model config.
- checkpoint (bool): Enable gradient checkpointing.
- lora_rank (int): Rank of the low-rank approximation.
- lora_train_bias (str): LoRA bias training mode.
- """
-
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[RobertaConfig] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none',
- **kwargs) -> None:
- if pretrained is not None:
- model = RobertaModel.from_pretrained(pretrained, add_pooling_layer=False)
- elif config is not None:
- model = RobertaModel(config)
- else:
- model = RobertaModel(RobertaConfig())
- if checkpoint:
- model.gradient_checkpointing_enable()
- value_head = nn.Linear(model.config.hidden_size, 1)
- super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
diff --git a/applications/Chat/coati/models/roberta/roberta_rm.py b/applications/Chat/coati/models/roberta/roberta_rm.py
deleted file mode 100644
index 77075052978b56d9bde336b3ac0473a343f6c332..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/models/roberta/roberta_rm.py
+++ /dev/null
@@ -1,39 +0,0 @@
-from typing import Optional
-
-import torch.nn as nn
-from transformers import RobertaConfig, RobertaModel
-
-
-from ..base import RewardModel
-
-
-class RoBERTaRM(RewardModel):
- """
- RoBERTa Reward model.
-
- Args:
- pretrained (str): Pretrained model name or path.
- config (RoBERTaConfig): Model config.
- checkpoint (bool): Enable gradient checkpointing.
- lora_rank (int): Rank of the low-rank approximation.
- lora_train_bias (str): LoRA bias training mode.
- """
-
- def __init__(self,
- pretrained: Optional[str] = None,
- config: Optional[RobertaConfig] = None,
- checkpoint: bool = False,
- lora_rank: int = 0,
- lora_train_bias: str = 'none') -> None:
- if pretrained is not None:
- model = RobertaModel.from_pretrained(pretrained, add_pooling_layer=False)
- elif config is not None:
- model = RobertaModel(config)
- else:
- model = RobertaModel(RobertaConfig())
- if checkpoint:
- model.gradient_checkpointing_enable()
-
- value_head = nn.Linear(model.config.hidden_size, 1)
- value_head.weight.data.normal_(mean=0.0, std=1/(model.config.hidden_size + 1))
- super().__init__(model, value_head, lora_rank, lora_train_bias)
\ No newline at end of file
diff --git a/applications/Chat/coati/models/utils.py b/applications/Chat/coati/models/utils.py
index 0ff13181fcd2f9d36aa413c0ff4a3f2b04a1ca0d..1aaef16620d2b18bb452b1083d0a988190840116 100644
--- a/applications/Chat/coati/models/utils.py
+++ b/applications/Chat/coati/models/utils.py
@@ -1,14 +1,12 @@
from typing import Optional, Union
-import loralib as lora
import torch
-import torch.nn as nn
import torch.nn.functional as F
-def compute_approx_kl(log_probs: torch.Tensor,
- log_probs_base: torch.Tensor,
- action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+def _compute_approx_kl(
+ log_probs: torch.Tensor, log_probs_base: torch.Tensor, action_mask: Optional[torch.Tensor] = None
+) -> torch.Tensor:
"""
Compute the approximate KL divergence between two distributions.
Schulman blog: http://joschu.net/blog/kl-approx.html
@@ -19,7 +17,7 @@ def compute_approx_kl(log_probs: torch.Tensor,
action_mask: Mask for actions.
"""
- log_ratio = log_probs - log_probs_base
+ log_ratio = log_probs_base - log_probs
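+    # k3 estimator from the blog above: since samples come from the current
+    # policy, log_ratio must be log(base) - log(current) for
+    # (exp(log_ratio) - 1) - log_ratio to be a non-negative KL estimate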
approx_kl = (log_ratio.exp() - 1) - log_ratio
if action_mask is not None:
approx_kl = masked_mean(approx_kl, action_mask, dim=1)
@@ -28,65 +26,44 @@ def compute_approx_kl(log_probs: torch.Tensor,
return approx_kl
-def compute_reward(r: Union[torch.Tensor, float],
- kl_coef: float,
- log_probs: torch.Tensor,
- log_probs_base: torch.Tensor,
- action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+def compute_reward(
+ r: Union[torch.Tensor, float],
+ kl_coef: float,
+ log_probs: torch.Tensor,
+ log_probs_base: torch.Tensor,
+ action_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
if kl_coef <= 0.0:
return r
- kl = compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
+ kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
reward = r - kl_coef * kl
return reward
-def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
log_probs = F.log_softmax(logits, dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
return log_probs_labels.squeeze(-1)
+def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
+ """Calculate action log probs.
+
+ Args:
+        logits (torch.Tensor): Output logits tensor of Actor.forward.
+ sequences (torch.LongTensor): Input sequences.
+ num_actions (int): Number of actions.
+
+ Returns:
+ torch.Tensor: Action log probs.
+ """
+ log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
+ return log_probs[:, -num_actions:]
+
+
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
tensor = tensor * mask
tensor = tensor.sum(dim=dim)
mask_sum = mask.sum(dim=dim)
mean = tensor / (mask_sum + 1e-8)
return mean
-
-
-def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
- tensor = tensor * mask
- mean = masked_mean(tensor, mask, dim=dim)
- mean_centered = tensor - mean
- var = masked_mean(mean_centered**2, mask, dim=dim)
- return mean_centered * var.clamp(min=eps).rsqrt()
-
-
-def normalize(tensor: torch.Tensor, dim: int = 0, eps: float = 1e-8) -> torch.Tensor:
- mean = tensor.mean(dim)
- mean_centered = tensor - mean
- var = (mean_centered**2).mean(dim)
- norm = mean_centered * var.clamp(min=eps).rsqrt()
- return norm
-
-
-def convert_to_lora(model: nn.Module,
- input_size: int,
- output_size: int,
- lora_rank: int = 16,
- lora_alpha: int = 1,
- lora_dropout: float = 0.,
- fan_in_fan_out: bool = False,
- merge_weights: bool = True):
- if lora_rank > min(input_size, output_size):
- raise ValueError(f"LoRA rank {lora_rank} must be less or equal than {min(input_size, output_size)}")
-
- for name, module in model.named_modules():
- if isinstance(module, nn.Linear):
- module._modules[name] = lora.Linear(input_size,
- output_size,
- r=lora_rank,
- lora_alpha=lora_alpha,
- lora_dropout=lora_dropout,
- fan_in_fan_out=fan_in_fan_out,
- merge_weights=merge_weights)
diff --git a/applications/Chat/coati/quant/__init__.py b/applications/Chat/coati/quant/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1765b8091bc3c599c493dcfdf93e8d30c5b340cb
--- /dev/null
+++ b/applications/Chat/coati/quant/__init__.py
@@ -0,0 +1,7 @@
+from .llama_gptq import load_quant as llama_load_quant
+from .utils import low_resource_init
+
+__all__ = [
+ "llama_load_quant",
+ "low_resource_init",
+]
diff --git a/applications/Chat/coati/quant/llama_gptq/__init__.py b/applications/Chat/coati/quant/llama_gptq/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..51d5233586ad91a8d092c7ee3095e608385bce8d
--- /dev/null
+++ b/applications/Chat/coati/quant/llama_gptq/__init__.py
@@ -0,0 +1,5 @@
+from .loader import load_quant
+
+__all__ = [
+ "load_quant",
+]
diff --git a/applications/Chat/coati/quant/llama_gptq/loader.py b/applications/Chat/coati/quant/llama_gptq/loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..50486337a7ab55e148ad2c1530dca3491669b587
--- /dev/null
+++ b/applications/Chat/coati/quant/llama_gptq/loader.py
@@ -0,0 +1,27 @@
+import torch
+import torch.nn as nn
+
+from .model_utils import find_layers
+from .quant import make_quant
+
+
+def load_quant(model: nn.Module, checkpoint: str, wbits: int, groupsize: int):
+ model = model.eval()
+    layers = find_layers(model)
+
+    # ignore lm head
+ for name in ["lm_head"]:
+ if name in layers:
+ del layers[name]
+
+ make_quant(model, layers, wbits, groupsize)
+
+ if checkpoint.endswith(".safetensors"):
+ from safetensors.torch import load_file as safe_load
+
+ model.load_state_dict(safe_load(checkpoint))
+ else:
+ model.load_state_dict(torch.load(checkpoint))
+
+ return model
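+
+
+# Example usage (a sketch; the checkpoint path, bit width and group size are
+# hypothetical, and `low_resource_init` comes from coati.quant.utils):
+#
+#   with low_resource_init():
+#       model = LlamaForCausalLM(config)  # build the skeleton without weight init
+#   model = load_quant(model, "llama-7b-4bit.safetensors", wbits=4, groupsize=128)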
diff --git a/applications/Chat/coati/quant/llama_gptq/model_utils.py b/applications/Chat/coati/quant/llama_gptq/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..18e4e476150019b231781fea059b9a4d8232bbab
--- /dev/null
+++ b/applications/Chat/coati/quant/llama_gptq/model_utils.py
@@ -0,0 +1,12 @@
+# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py
+
+import torch.nn as nn
+
+
+def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
+ if type(module) in layers:
+ return {name: module}
+ res = {}
+ for name1, child in module.named_children():
+ res.update(find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1))
+ return res
diff --git a/applications/Chat/inference/llama_gptq/quant.py b/applications/Chat/coati/quant/llama_gptq/quant.py
similarity index 90%
rename from applications/Chat/inference/llama_gptq/quant.py
rename to applications/Chat/coati/quant/llama_gptq/quant.py
index f7d5b7ce4bd8217bf246abbef0736c78be3869a6..5a7e2e72dfc55faa9c1e9172e43a6a7e91fb228f 100644
--- a/applications/Chat/inference/llama_gptq/quant.py
+++ b/applications/Chat/coati/quant/llama_gptq/quant.py
@@ -13,14 +13,13 @@ def quantize(x, scale, zero, maxq):
class Quantizer(nn.Module):
-
def __init__(self, shape=1):
super(Quantizer, self).__init__()
- self.register_buffer('maxq', torch.tensor(0))
- self.register_buffer('scale', torch.zeros(shape))
- self.register_buffer('zero', torch.zeros(shape))
+ self.register_buffer("maxq", torch.tensor(0))
+ self.register_buffer("scale", torch.zeros(shape))
+ self.register_buffer("zero", torch.zeros(shape))
- def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8):
+ def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=0.8):
self.maxq = torch.tensor(2**bits - 1)
self.perchannel = perchannel
self.sym = sym
@@ -68,7 +67,7 @@ class Quantizer(nn.Module):
self.zero = torch.round(-xmin / self.scale)
if self.mse:
- best = torch.full([x.shape[0]], float('inf'), device=dev)
+ best = torch.full([x.shape[0]], float("inf"), device=dev)
for i in range(int(self.maxshrink * self.grid)):
p = 1 - i / self.grid
xmin1 = p * xmin
@@ -123,13 +122,12 @@ class Quantizer(nn.Module):
try:
import quant_cuda
except:
- print('CUDA extension not installed.')
+ print("CUDA extension not installed.")
# Assumes layer is perfectly divisible into 256 * 256 blocks
class QuantLinear(nn.Module):
-
def __init__(self, bits, groupsize, infeatures, outfeatures):
super().__init__()
if bits not in [2, 3, 4, 8]:
@@ -142,11 +140,11 @@ class QuantLinear(nn.Module):
groupsize = groupsize if groupsize != -1 else infeatures
self.groupsize = groupsize
self.register_buffer(
- 'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)),
- dtype=torch.int))
- self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
- self.register_buffer('bias', torch.zeros(outfeatures))
- self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
+ "qzeros", torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int)
+ )
+ self.register_buffer("scales", torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
+ self.register_buffer("bias", torch.zeros(outfeatures))
+ self.register_buffer("qweight", torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
self._initialized_quant_state = False
def pack(self, linear, scales, zeros):
@@ -161,8 +159,10 @@ class QuantLinear(nn.Module):
for idx in range(self.infeatures):
g_idx = idx // self.groupsize
intweight.append(
- torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:,
- None])
+ torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[
+ :, None
+ ]
+ )
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
@@ -271,13 +271,13 @@ class QuantLinear(nn.Module):
return y.reshape(outshape)
-def make_quant(module, names, bits, groupsize, name=''):
+def make_quant(module, names, bits, groupsize, name=""):
if isinstance(module, QuantLinear):
return
for attr in dir(module):
tmp = getattr(module, attr)
- name1 = name + '.' + attr if name != '' else attr
+ name1 = name + "." + attr if name != "" else attr
if name1 in names:
setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features))
for name1, child in module.named_children():
- make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
+ make_quant(child, names, bits, groupsize, name + "." + name1 if name != "" else name1)
diff --git a/applications/Chat/coati/quant/utils.py b/applications/Chat/coati/quant/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d102bb30f52dff6e7996a2d4c1d37f5a4c6d0815
--- /dev/null
+++ b/applications/Chat/coati/quant/utils.py
@@ -0,0 +1,27 @@
+from contextlib import contextmanager
+
+import torch
+
+
+def _noop(*args, **kwargs):
+ pass
+
+
+@contextmanager
+def low_resource_init():
+ """This context manager disables weight initialization and sets the default float dtype to half."""
+ old_kaiming_uniform_ = torch.nn.init.kaiming_uniform_
+ old_uniform_ = torch.nn.init.uniform_
+ old_normal_ = torch.nn.init.normal_
+ dtype = torch.get_default_dtype()
+ try:
+ torch.nn.init.kaiming_uniform_ = _noop
+ torch.nn.init.uniform_ = _noop
+ torch.nn.init.normal_ = _noop
+ torch.set_default_dtype(torch.half)
+ yield
+ finally:
+ torch.nn.init.kaiming_uniform_ = old_kaiming_uniform_
+ torch.nn.init.uniform_ = old_uniform_
+ torch.nn.init.normal_ = old_normal_
+ torch.set_default_dtype(dtype)
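+
+
+# Example (a sketch; the model class is illustrative):
+#
+#   with low_resource_init():
+#       model = LlamaForCausalLM(config)  # built in half precision, no random init
+#   # on exit, torch's default dtype and init functions are restored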
diff --git a/applications/Chat/coati/ray/README.md b/applications/Chat/coati/ray/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..79b1db3478279cd55a9a70245f445c75420c6b72
--- /dev/null
+++ b/applications/Chat/coati/ray/README.md
@@ -0,0 +1,175 @@
+:warning: **This content may be outdated since the major update of Colossal Chat. We will update this content soon.**
+
+# Distributed PPO Training (RLHF Stage 3)
+
+## Detach Experience Makers and Trainers
+
+We can completely separate the trainers and makers.
+
+
+
+
+
+- The experience maker performs inference, produces experience, and remotely delivers it to the trainer (1).
+- The trainer consumes experience to train models, and periodically transmits new model parameters to the maker (2.1, 2.2).
+- An experience buffer is used to overlap transmission and computation.
+
+In this manner, each node works continuously without model idle time, and different optimization strategies can be applied to inference and training to trade off speed against memory. This design also helps with scalability.
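+
+A minimal sketch of this loop (illustrative pseudocode: `make_experience` is assumed, while `buffer_append`, `training_step` and `_update_remote_makers` mirror the trainer API in this repo):
+
+```python
+# maker side: produce experience and deliver it to the trainer's buffer (1)
+for _ in range(num_steps):
+    experience = experience_maker.make_experience(batch)
+    trainer.buffer_append.remote(experience)
+
+# trainer side: consume buffered experience, then sync new weights (2.1, 2.2)
+for step in range(total_steps):
+    experience = buffer.sample()        # blocks until a batch is ready
+    trainer.training_step(experience)
+    if (step + 1) % update_steps == 0:
+        trainer._update_remote_makers()  # push new params to the makers
+```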
+
+`DetachedPPOTrainer` and `ExperienceMakerHolder` are Ray Actors (not to be confused with the RLHF Actor model), representing the Trainer and the Experience Maker in the figure above, respectively.
+
+[More about Ray Core](https://docs.ray.io/en/latest/ray-core/walkthrough.html)
+
+## Usage
+
+See the examples at `ColossalAI/applications/Chat/examples/ray`.
+
+### Setup Makers
+
+- define the makers' environment variables:
+
+ ```python
+ env_info_makers = [{
+ 'local_rank': '0',
+ 'rank': str(rank),
+ 'world_size': str(num_makers),
+ 'master_port': maker_port,
+ 'master_addr': master_addr
+ } for rank in range(num_makers)]
+
+ ```
+
+- define the maker models:
+
+ ```python
+ def model_fn():
+ actor = get_actor_from_args(...)
+ critic = get_critic_from_args(...)
+ reward_model = get_reward_model_from_args(...)
+ initial_model = get_actor_from_args(...)
+ return actor, critic, reward_model, initial_model
+
+ ```
+
+- set `experience_holder_refs`:
+
+ ```python
+ experience_holder_refs = [
+ ExperienceMakerHolder.options(
+ name=f"maker_{i}",
+ num_gpus=1,
+ max_concurrency=2
+ ).remote(
+ detached_trainer_name_list=[f"trainer_{x}" for x in target_trainers(...)],
+ model_fn=model_fn,
+ ...)
+ for i, env_info_maker in enumerate(env_info_makers)
+ ]
+ ```
+
+ The names in the `detached_trainer_name_list` refer to the target trainers that the maker should send experience to.
+  We set a trainer's name in the same way as a maker's, via `.options(name="...")`; see below.
+
+### Setup Trainers
+
+- define the trainers' environment variables:
+ ```python
+ env_info_trainers = [{
+ 'local_rank': '0',
+ 'rank': str(rank),
+ 'world_size': str(num_trainers),
+ 'master_port': trainer_port,
+ 'master_addr': master_addr
+ } for rank in range(num_trainers)]
+ ```
+- define the trainer models:
+
+ ```python
+ def trainer_model_fn():
+ actor = get_actor_from_args(...)
+ critic = get_critic_from_args(...)
+ return actor, critic
+ ```
+
+- set `trainer_refs`:
+ ```python
+ trainer_refs = [
+ DetachedPPOTrainer.options(
+ name=f"trainer{i}",
+ num_gpus=1,
+ max_concurrency=2
+ ).remote(
+ experience_maker_holder_name_list=[f"maker{x}" for x in target_makers(...)],
+ model_fn = trainer_model_fn(),
+ ...)
+ for i, env_info_trainer in enumerate(env_info_trainers)
+ ]
+ ```
+ The names in `experience_maker_holder_name_list` refer to the target makers that the trainer should send updated models to.
+ By setting `detached_trainer_name_list` and `experience_maker_holder_name_list`, we can customize the transmission graph.
+
+### Launch Jobs
+
+- define the data loader:
+
+ ```python
+ def data_loader_fn():
+        return torch.utils.data.DataLoader(dataset=dataset)
+
+ ```
+
+- launch the makers:
+
+ ```python
+ wait_tasks = []
+ for experience_holder_ref in experience_holder_refs:
+ wait_tasks.append(
+ experience_holder_ref.workingloop.remote(data_loader_fn(),
+ num_steps=experience_steps))
+
+ ```
+
+- launch the trainers:
+
+ ```python
+ for trainer_ref in trainer_refs:
+ wait_tasks.append(trainer_ref.fit.remote(total_steps, update_steps, train_epochs))
+ ```
+
+- wait for completion:
+ ```python
+ ray.get(wait_tasks)
+ ```
+
+## Flexible Structure
+
+We can deploy different strategies to makers and trainers. Here are some example topologies.
+
+### 2 Makers 1 Trainer
+
+
+
+
+
+### 2 Makers 2 Trainers
+
+
+
+
+
+### Maker Inference Quantization
+
+
+
+
+
+### Tensor Parallel
+
+
+
+
+
+## TODO
+
+- [ ] Support LoRA
+- [ ] Support TP & PP
diff --git a/applications/Chat/coati/ray/__init__.py b/applications/Chat/coati/ray/__init__.py
index 5802c05bc03feb3b755754de7385a2141e547673..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644
--- a/applications/Chat/coati/ray/__init__.py
+++ b/applications/Chat/coati/ray/__init__.py
@@ -1,2 +0,0 @@
-from .src.detached_replay_buffer import DetachedReplayBuffer
-from .src.detached_trainer_ppo import DetachedPPOTrainer
diff --git a/applications/Chat/coati/ray/callbacks/__init__.py b/applications/Chat/coati/ray/callbacks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f5e488f383e3846b428ff15a8c013ea4df07663
--- /dev/null
+++ b/applications/Chat/coati/ray/callbacks/__init__.py
@@ -0,0 +1,9 @@
+from .base import MakerCallback, TrainerCallback
+from .performance_evaluator import ExperienceMakerPerformanceEvaluator, TrainerPerformanceEvaluator
+
+__all__ = [
+ "TrainerCallback",
+ "MakerCallback",
+ "ExperienceMakerPerformanceEvaluator",
+ "TrainerPerformanceEvaluator",
+]
diff --git a/applications/Chat/coati/ray/callbacks/base.py b/applications/Chat/coati/ray/callbacks/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c5bd8a67776843493a0ca4d4412196a093259d7
--- /dev/null
+++ b/applications/Chat/coati/ray/callbacks/base.py
@@ -0,0 +1,65 @@
+from abc import ABC
+
+from coati.experience_maker import Experience
+
+
+class TrainerCallback(ABC):
+ """
+ Base callback class. It defines the interface for callbacks.
+ """
+
+ def on_fit_start(self) -> None:
+ pass
+
+ def on_fit_end(self) -> None:
+ pass
+
+ def on_episode_start(self, episode: int) -> None:
+ pass
+
+ def on_episode_end(self, episode: int) -> None:
+ pass
+
+ def on_epoch_start(self, epoch: int) -> None:
+ pass
+
+ def on_epoch_end(self, epoch: int) -> None:
+ pass
+
+ def on_batch_start(self) -> None:
+ pass
+
+ def on_batch_end(self, metrics: dict, experience: Experience) -> None:
+ pass
+
+ def on_update_start(self) -> None:
+ pass
+
+ def on_update_end(self) -> None:
+ pass
+
+
+class MakerCallback(ABC):
+ def on_loop_start(self) -> None:
+ pass
+
+ def on_loop_end(self) -> None:
+ pass
+
+ def on_make_experience_start(self) -> None:
+ pass
+
+ def on_make_experience_end(self, experience: Experience) -> None:
+ pass
+
+ def on_send_start(self) -> None:
+ pass
+
+ def on_send_end(self) -> None:
+ pass
+
+ def on_batch_start(self) -> None:
+ pass
+
+ def on_batch_end(self) -> None:
+ pass
diff --git a/applications/Chat/coati/ray/callbacks/performance_evaluator.py b/applications/Chat/coati/ray/callbacks/performance_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..18798bce7dcebe449ae280c86f4a219b56f3586c
--- /dev/null
+++ b/applications/Chat/coati/ray/callbacks/performance_evaluator.py
@@ -0,0 +1,214 @@
+from time import time
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+from coati.experience_maker import Experience
+
+from .base import MakerCallback, TrainerCallback
+
+
+def get_world_size() -> int:
+ if dist.is_initialized():
+ return dist.get_world_size()
+ return 1
+
+
+def print_rank_0(*args, **kwargs) -> None:
+ if not dist.is_initialized() or dist.get_rank() == 0:
+ print(*args, **kwargs)
+
+
+@torch.no_grad()
+def all_reduce_mean(x: float, world_size: int) -> float:
+ if world_size == 1:
+ return x
+ tensor = torch.tensor([x], device=torch.cuda.current_device())
+ dist.all_reduce(tensor)
+ tensor = tensor / world_size
+ return tensor.item()
+
+
+class Timer:
+ def __init__(self) -> None:
+ self.start_time: Optional[float] = None
+ self.duration: float = 0.0
+
+ def start(self) -> None:
+ self.start_time = time()
+
+ def end(self) -> None:
+ self.duration += time() - self.start_time
+
+ def reset(self) -> None:
+ self.duration = 0.0
+
+
+class ExperienceMakerPerformanceEvaluator(MakerCallback):
+ def __init__(
+ self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int, reward_model_num_params: int
+ ) -> None:
+ super().__init__()
+ self.world_size = get_world_size()
+ self.actor_num_params = actor_num_params
+ self.critic_num_params = critic_num_params
+ self.initial_model_num_params = initial_model_num_params
+ self.reward_model_num_params = reward_model_num_params
+
+ self.batch_timer = Timer()
+ self.send_timer = Timer()
+ self.make_experience_timer = Timer()
+ self.total_samples: int = 0
+ self.make_experience_flop: int = 0
+
+ print_rank_0(
+ f"ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}"
+ )
+
+ def on_make_experience_start(self) -> None:
+ self.make_experience_timer.start()
+
+ def on_make_experience_end(self, experience: Experience) -> None:
+ self.make_experience_timer.end()
+
+ batch_size, seq_len = experience.sequences.shape
+
+ self.total_samples += batch_size
+
+ # actor generate
+ num_actions = experience.action_mask.size(1)
+ input_len = seq_len - num_actions
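+        # the context grows by one token per generation step, so the total
+        # attended length is an arithmetic series from input_len to seq_len - 1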
+ total_seq_len = (input_len + seq_len - 1) * num_actions / 2
+ self.make_experience_flop += self.actor_num_params * batch_size * total_seq_len * 2
+ # actor forward
+ self.make_experience_flop += self.actor_num_params * batch_size * seq_len * 2
+ # critic forward
+ self.make_experience_flop += self.critic_num_params * batch_size * seq_len * 2
+ # initial model forward
+ self.make_experience_flop += self.initial_model_num_params * batch_size * seq_len * 2
+ # reward model forward
+ self.make_experience_flop += self.reward_model_num_params * batch_size * seq_len * 2
+
+ def on_send_start(self) -> None:
+ self.send_timer.start()
+
+ def on_send_end(self) -> None:
+ self.send_timer.end()
+
+ def on_batch_start(self) -> None:
+ self.batch_timer.start()
+
+ def on_batch_end(self) -> None:
+ self.batch_timer.end()
+
+ def on_loop_end(self) -> None:
+ avg_make_experience_duration = all_reduce_mean(self.make_experience_timer.duration, self.world_size)
+ avg_overall_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
+ avg_send_duration = all_reduce_mean(self.send_timer.duration, self.world_size)
+
+ avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12)
+ avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
+ avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size)
+ avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / (
+ self.total_samples * self.world_size
+ )
+ avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
+
+ print_rank_0(
+ "Making Experience Performance Summary:\n"
+ + f"Throughput: {avg_throughput:.3f} samples/sec\n"
+ + f"TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n"
+ + f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
+ + f"Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n"
+ + f"Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n"
+ )
+
+
+class TrainerPerformanceEvaluator(TrainerCallback):
+ def __init__(
+ self,
+ actor_num_params: int,
+ critic_num_params: int,
+ enable_grad_checkpoint: bool = False,
+ ignore_first_episodes: int = 1,
+ ) -> None:
+ super().__init__()
+ self.world_size = get_world_size()
+ self.actor_num_params = actor_num_params
+ self.critic_num_params = critic_num_params
+ self.enable_grad_checkpoint = enable_grad_checkpoint
+ self.ignore_first_episodes = ignore_first_episodes
+ self.ignore_this_episode = False
+
+ self.episode_timer = Timer()
+ self.batch_timer = Timer()
+ self.update_timer = Timer()
+ self.total_samples: int = 0
+ self.learn_flop: int = 0
+
+ print_rank_0(
+ f"Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}"
+ )
+
+ def on_episode_start(self, episodes: int) -> None:
+ self.ignore_this_episode = episodes < self.ignore_first_episodes
+ if self.ignore_this_episode:
+ return
+ self.episode_timer.start()
+
+ def on_episode_end(self, episodes: int) -> None:
+ if self.ignore_this_episode:
+ return
+ self.episode_timer.end()
+
+ def on_batch_start(self) -> None:
+ if self.ignore_this_episode:
+ return
+ self.batch_timer.start()
+
+ def on_batch_end(self, metrics: dict, experience: Experience) -> None:
+ if self.ignore_this_episode:
+ return
+ self.batch_timer.end()
+
+ batch_size, seq_len = experience.sequences.shape
+
+ self.total_samples += batch_size
+
+ # actor forward-backward, 3 means forward(1) + backward(2)
+ self.learn_flop += self.actor_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
+ # critic forward-backward
+ self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
+
+ def on_update_start(self) -> None:
+ if self.ignore_this_episode:
+ return
+ self.update_timer.start()
+
+ def on_update_end(self) -> None:
+ if self.ignore_this_episode:
+ return
+ self.update_timer.end()
+
+ def on_fit_end(self) -> None:
+ if self.total_samples == 0:
+ print_rank_0("No samples are collected, skip trainer performance evaluation")
+ return
+ avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
+ avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size)
+ avg_episode_duration = all_reduce_mean(self.episode_timer.duration, self.world_size)
+
+ avg_throughput = self.total_samples * self.world_size / (avg_episode_duration + 1e-12)
+ avg_learn_tflops = self.learn_flop / 1e12 / (avg_train_duration + 1e-12)
+ avg_time_per_sample = (avg_episode_duration + 1e-12) / (self.total_samples * self.world_size)
+ avg_train_time_per_sample = (avg_train_duration + 1e-12) / (self.total_samples * self.world_size)
+ avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
+
+ print_rank_0(
+ "Learning Performance Summary:\n"
+ + f"Throughput: {avg_throughput:.3f} samples/sec\n"
+ + f"TFLOPS per GPU: {avg_learn_tflops:.3f}\n"
+ + f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
+ + f"Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n"
+ + f"Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n"
+ )
diff --git a/applications/Chat/coati/ray/detached_replay_buffer.py b/applications/Chat/coati/ray/detached_replay_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..92dab17292f7ba36de6c3e96d5fc8d6be17d6e89
--- /dev/null
+++ b/applications/Chat/coati/ray/detached_replay_buffer.py
@@ -0,0 +1,70 @@
+from typing import List
+
+import torch
+from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
+from coati.experience_maker.base import Experience
+
+# from torch.multiprocessing import Queue
+from ray.util.queue import Queue
+
+
+class DetachedReplayBuffer:
+ """
+    Detached replay buffer. Shares Experience across workers on the same node;
+    a trainer node is therefore expected to have only one instance.
+    It is the ExperienceMakerHolder's duty to call the append(exp) method remotely.
+
+    Args:
+        sample_batch_size: Batch size when sampling. Experience items are not enqueued until they form a full batch.
+        limit: Limit on the number of experience sample batches. A number <= 0 means unlimited. Defaults to 0.
+ """
+
+ def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
+ self.sample_batch_size = sample_batch_size
+ self.limit = limit
+ self.items = Queue(self.limit, actor_options={"num_cpus": 1})
+ self.batch_collector: List[BufferItem] = []
+
+ @torch.no_grad()
+ def append(self, experience: Experience) -> None:
+ """
+ Expected to be called remotely.
+ """
+ items = split_experience_batch(experience)
+ self.extend(items)
+
+ @torch.no_grad()
+ def extend(self, items: List[BufferItem]) -> None:
+ """
+ Expected to be called remotely.
+ """
+ self.batch_collector.extend(items)
+ while len(self.batch_collector) >= self.sample_batch_size:
+ items = self.batch_collector[: self.sample_batch_size]
+ experience = make_experience_batch(items)
+ self.items.put(experience, block=True)
+ self.batch_collector = self.batch_collector[self.sample_batch_size :]
+
+ def clear(self) -> None:
+        self.items.shutdown()
+        self.items = Queue(self.limit, actor_options={"num_cpus": 1})
+ self.batch_collector = []
+
+ @torch.no_grad()
+ def sample(self, worker_rank=0, to_device="cpu") -> Experience:
+ ret = self._sample_and_erase()
+ ret.to_device(to_device)
+ return ret
+
+ @torch.no_grad()
+ def _sample_and_erase(self) -> Experience:
+ ret = self.items.get(block=True)
+ return ret
+
+ def get_length(self) -> int:
+ ret = self.items.qsize()
+ return ret
diff --git a/applications/Chat/coati/ray/detached_trainer_base.py b/applications/Chat/coati/ray/detached_trainer_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcf0a472df9e19d51b6a64fea3389d5823fdad60
--- /dev/null
+++ b/applications/Chat/coati/ray/detached_trainer_base.py
@@ -0,0 +1,179 @@
+import os
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List
+
+import ray
+import torch
+from coati.experience_buffer.utils import BufferItem
+from coati.experience_maker import Experience
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from .callbacks import TrainerCallback
+from .detached_replay_buffer import DetachedReplayBuffer
+from .utils import is_rank_0
+
+
+class DetachedTrainer(ABC):
+ """
+    Base class for detached RLHF trainers.
+    'Detached' means that the experience maker runs separately from the trainer.
+    Please set the name attribute during init:
+        >>> trainer = DetachedTrainer.options(..., name="xxx", ...).remote()
+    so that an ExperienceMakerHolder can reach the detached replay buffer via the Ray actor's name.
+    Args:
+        experience_maker_holder_name_list (List[str]): names of the makers to send updated models to
+        train_batch_size (int, defaults to 8): the batch size to use for training
+        buffer_limit (int, defaults to 0): the max size of the replay buffer; <= 0 means unlimited
+        dataloader_pin_memory (bool, defaults to True): whether to pin memory for the data loader
+        callbacks (List[TrainerCallback], defaults to []): the callbacks to call during the training process
+        debug (bool, defaults to False): whether to print debug messages
+
+ """
+
+ def __init__(
+ self,
+ experience_maker_holder_name_list: List[str],
+ train_batch_size: int = 8,
+ buffer_limit: int = 0,
+ dataloader_pin_memory: bool = True,
+ callbacks: List[TrainerCallback] = [],
+ debug: bool = False,
+ ) -> None:
+ super().__init__()
+ self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
+ self.dataloader_pin_memory = dataloader_pin_memory
+ self.callbacks = callbacks
+ self.target_holder_name_list = experience_maker_holder_name_list
+ self.target_holder_list = []
+ self._is_target_holder_initialized = False
+ self._debug = debug
+
+ def update_target_holder_list(self):
+        # target_holder_list may legitimately be empty, so initialization is tracked with a bool flag
+ if not self._is_target_holder_initialized:
+ for name in self.target_holder_name_list:
+ self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
+ self._is_target_holder_initialized = True
+
+ @abstractmethod
+ def _update_remote_makers(self, fully_update: bool = False, **kwargs):
+ pass
+
+ def sync_models_to_remote_makers(self, **kwargs):
+ self._update_remote_makers(fully_update=True, **kwargs)
+
+ @abstractmethod
+ def training_step(self, experience: Experience) -> Dict[str, Any]:
+ pass
+
+ def _learn(self, update_steps: int, train_epochs: int) -> None:
+ data = []
+ # warmup
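+        # (the warmup epoch pulls batches from the remote buffer and caches them
+        # in `data`; later epochs replay the cached batches via a local DataLoader)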
+ pbar = tqdm(range(update_steps), desc=f"Train epoch [1/{train_epochs}]", disable=not is_rank_0())
+ self._on_epoch_start(0)
+ self._learn_epoch(pbar, data)
+ self._on_epoch_end(0)
+ # item is already a batch
+ dataloader = DataLoader(
+ data, batch_size=1, shuffle=True, pin_memory=self.dataloader_pin_memory, collate_fn=lambda x: x[0]
+ )
+ for epoch in range(1, train_epochs):
+ pbar = tqdm(dataloader, desc=f"Train epoch [{epoch + 1}/{train_epochs}]", disable=not is_rank_0())
+ self._on_epoch_start(epoch)
+ self._learn_epoch(pbar, data)
+ self._on_epoch_end(epoch)
+
+ def _learn_epoch(self, pbar: tqdm, data: List[Experience]) -> None:
+ is_warmup = len(data) == 0
+ for x in pbar:
+ if self._debug:
+ print("[trainer] training step")
+ # sample a batch and then train to avoid waiting
+ experience = x if not is_warmup else self._buffer_sample()
+ experience.to_device(torch.cuda.current_device())
+ self._on_batch_start()
+ metrics = self.training_step(experience)
+ self._on_batch_end(metrics, experience)
+
+ if self._debug:
+ print("[trainer] step over")
+ experience.to_device("cpu")
+ if is_warmup:
+ data.append(experience)
+ pbar.set_postfix(metrics)
+
+ def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
+ self._on_fit_start()
+ for i in tqdm(range(total_steps // update_steps), desc="Trainer", disable=not is_rank_0()):
+ self._on_episode_start(i)
+ self._learn(update_steps, train_epochs)
+ self._on_update_start()
+ self._update_remote_makers()
+ self._on_update_end()
+ self._on_episode_end(i)
+ self._on_fit_end()
+
+ @ray.method(concurrency_group="buffer_length")
+ def buffer_get_length(self):
+ # called by ExperienceMakerHolder
+ if self._debug:
+ print("[trainer] telling length")
+ return self.detached_replay_buffer.get_length()
+
+ @ray.method(concurrency_group="buffer_append")
+ def buffer_append(self, experience: Experience):
+ # called by ExperienceMakerHolder
+ if self._debug:
+ print(f"[trainer] receiving exp.")
+ self.detached_replay_buffer.append(experience)
+
+ @ray.method(concurrency_group="buffer_append")
+ def buffer_extend(self, items: List[BufferItem]):
+ # called by ExperienceMakerHolder
+ if self._debug:
+ print(f"[trainer] receiving exp.")
+ self.detached_replay_buffer.extend(items)
+
+ @ray.method(concurrency_group="buffer_sample")
+ def _buffer_sample(self):
+ return self.detached_replay_buffer.sample()
+
+ def _on_fit_start(self) -> None:
+ for callback in self.callbacks:
+ callback.on_fit_start()
+
+ def _on_fit_end(self) -> None:
+ for callback in self.callbacks:
+ callback.on_fit_end()
+
+ def _on_episode_start(self, episode: int) -> None:
+ for callback in self.callbacks:
+ callback.on_episode_start(episode)
+
+ def _on_episode_end(self, episode: int) -> None:
+ for callback in self.callbacks:
+ callback.on_episode_end(episode)
+
+ def _on_epoch_start(self, epoch: int) -> None:
+ for callback in self.callbacks:
+ callback.on_epoch_start(epoch)
+
+ def _on_epoch_end(self, epoch: int) -> None:
+ for callback in self.callbacks:
+ callback.on_epoch_end(epoch)
+
+ def _on_batch_start(self) -> None:
+ for callback in self.callbacks:
+ callback.on_batch_start()
+
+ def _on_batch_end(self, metrics: dict, experience: Experience) -> None:
+ for callback in self.callbacks:
+ callback.on_batch_end(metrics, experience)
+
+ def _on_update_start(self) -> None:
+ for callback in self.callbacks:
+ callback.on_update_start()
+
+ def _on_update_end(self) -> None:
+ for callback in self.callbacks:
+ callback.on_update_end()
diff --git a/applications/Chat/coati/ray/detached_trainer_ppo.py b/applications/Chat/coati/ray/detached_trainer_ppo.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef84a1ddba48d63c4debeb575fe76b77d75fbfae
--- /dev/null
+++ b/applications/Chat/coati/ray/detached_trainer_ppo.py
@@ -0,0 +1,191 @@
+from typing import Callable, Dict, List, Tuple
+
+import ray
+import torch
+from coati.experience_maker import Experience
+from coati.models.base import Actor, Critic
+from coati.models.loss import PolicyLoss, ValueLoss
+from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy
+from torch.optim import Adam
+
+from colossalai.nn.optimizer import HybridAdam
+
+from .callbacks import TrainerCallback, TrainerPerformanceEvaluator
+from .detached_trainer_base import DetachedTrainer
+from .lora_constructor import LoRAConstructor
+from .utils import get_model_numel, get_rank, set_dist_env, state_dict_to
+
+
+@ray.remote(
+ concurrency_groups={"buffer_length": 1, "buffer_append": 1, "buffer_sample": 1, "model_io": 1, "compute": 1}
+)
+class DetachedPPOTrainer(DetachedTrainer):
+ """
+    Detached trainer for the PPO algorithm.
+    Args:
+        experience_maker_holder_name_list (List[str]): names of the makers to send updated models to
+        strategy_fn (Callable[[], Strategy]): builds the strategy to use for training
+        model_fn (Callable[[], Tuple[Actor, Critic]]): builds the actor and critic models
+        env_info (Dict[str, str], optional): distributed environment variables (rank, world_size, etc.)
+        train_batch_size (int, defaults to 8): the batch size to use for training
+        buffer_limit (int, defaults to 0): the max size of the replay buffer; <= 0 means unlimited
+        eps_clip (float, defaults to 0.2): the clip coefficient of the policy loss
+        value_clip (float, defaults to 0.4): the clip coefficient of the value loss
+        dataloader_pin_memory (bool, defaults to True): whether to pin memory for the data loader
+        callbacks (List[TrainerCallback], defaults to []): the callbacks to call during the training process
+        eval_performance (bool, defaults to False): whether to attach a TrainerPerformanceEvaluator
+        debug (bool, defaults to False): whether to print debug messages
+        update_lora_weights (bool, defaults to False): whether to send only LoRA weights when updating makers
+ """
+
+ def __init__(
+ self,
+ experience_maker_holder_name_list: List[str],
+ strategy_fn: Callable[[], Strategy],
+ model_fn: Callable[[], Tuple[Actor, Critic]],
+ env_info: Dict[str, str] = None,
+ train_batch_size: int = 8,
+ buffer_limit: int = 0,
+ eps_clip: float = 0.2,
+ value_clip: float = 0.4,
+ dataloader_pin_memory: bool = True,
+ callbacks: List[TrainerCallback] = [],
+ eval_performance: bool = False,
+ debug: bool = False,
+ update_lora_weights: bool = False,
+ ) -> None:
+ # set environment variables
+ if env_info:
+ set_dist_env(env_info=env_info)
+ # configure strategy
+ self.strategy = strategy_fn()
+ # configure models, loss and optimizers
+ with self.strategy.model_init_context():
+ self.actor, self.critic = model_fn()
+
+ if eval_performance:
+ actor_numel = get_model_numel(self.actor)
+ critic_numel = get_model_numel(self.critic)
+ evaluator = TrainerPerformanceEvaluator(actor_numel, critic_numel)
+ callbacks = callbacks + [evaluator]
+
+ if isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy)):
+ self.actor_optim = HybridAdam(self.actor.parameters(), lr=1e-7)
+ self.critic_optim = HybridAdam(self.critic.parameters(), lr=1e-7)
+ else:
+ self.actor_optim = Adam(self.actor.parameters(), lr=1e-7)
+ self.critic_optim = Adam(self.critic.parameters(), lr=1e-7)
+
+ (self.actor, self.actor_optim), (self.critic, self.critic_optim) = self.strategy.prepare(
+ (self.actor, self.actor_optim), (self.critic, self.critic_optim)
+ )
+
+ # configure trainer
+ self.actor_loss_fn = PolicyLoss(eps_clip)
+ self.critic_loss_fn = ValueLoss(value_clip)
+
+ super().__init__(
+ experience_maker_holder_name_list,
+ train_batch_size=train_batch_size,
+ buffer_limit=buffer_limit,
+ dataloader_pin_memory=dataloader_pin_memory,
+ callbacks=callbacks,
+ debug=debug,
+ )
+ if self._debug:
+ print(f"[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}")
+
+ self._update_lora_weights = update_lora_weights
+
+ @ray.method(concurrency_group="model_io")
+ @torch.no_grad()
+ def _update_remote_makers(self, fully_update: bool = False, **config):
+ # TODO: balance duties
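+        # update protocol: signal chunk_start to every maker, stream the actor
+        # and critic state-dict shards, then signal chunk_end so the makers
+        # know the transfer is complete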
+ if not fully_update:
+ config["requires_grad_only"] = True
+ self.update_target_holder_list()
+ # mark start, ensure order
+ tasks = []
+ for target_holder in self.target_holder_list:
+ tasks.append(target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update))
+ ray.get(tasks)
+ # sending loop
+ tasks = []
+
+ for state_dict_shard in self._get_model_state_dict_shard(self.actor, fully_update=fully_update, **config):
+ for target_holder in self.target_holder_list:
+ tasks.append(
+ target_holder.update_experience_maker.remote(
+ new_actor_state_dict=state_dict_shard,
+ new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor),
+ fully_update=fully_update,
+ )
+ )
+ # sending loop
+ for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config):
+ for target_holder in self.target_holder_list:
+ tasks.append(
+ target_holder.update_experience_maker.remote(
+ new_critic_state_dict=state_dict_shard,
+ new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic),
+ fully_update=fully_update,
+ )
+ )
+ ray.get(tasks)
+ # mark end
+ for target_holder in self.target_holder_list:
+ target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update)
+
+ @ray.method(concurrency_group="compute")
+ def training_step(self, experience: Experience) -> Dict[str, float]:
+ self.actor.train()
+ self.critic.train()
+
+ num_actions = experience.action_mask.size(1)
+ action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
+ actor_loss = self.actor_loss_fn(
+ action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
+ )
+ self.strategy.backward(actor_loss, self.actor, self.actor_optim)
+ self.strategy.optimizer_step(self.actor_optim)
+ self.actor_optim.zero_grad()
+
+ values = self.critic(
+ experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
+ )
+ critic_loss = self.critic_loss_fn(
+ values, experience.values, experience.reward, action_mask=experience.action_mask
+ )
+
+ self.strategy.backward(critic_loss, self.critic, self.critic_optim)
+ self.strategy.optimizer_step(self.critic_optim)
+ self.critic_optim.zero_grad()
+ return {"actor_loss": actor_loss.item(), "critic_loss": critic_loss.item()}
+
+ def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None:
+ self.strategy.save_model(self.actor, path, only_rank0)
+
+ def strategy_save_critic(self, path: str, only_rank0: bool = False) -> None:
+ self.strategy.save_model(self.critic, path, only_rank0)
+
+ def strategy_save_actor_optim(self, path: str, only_rank0: bool = False) -> None:
+ self.strategy.save_optimizer(self.actor_optim, path, only_rank0)
+
+ def strategy_save_critic_optim(self, path: str, only_rank0: bool = False) -> None:
+ self.strategy.save_optimizer(self.critic_optim, path, only_rank0)
+
+ def _get_model_state_dict_shard(self, model: torch.nn.Module, fully_update=False, **config):
+ for state_dict in self.strategy.get_model_state_dict_shard(model, **config):
+ if not self._update_lora_weights or fully_update:
+ yield state_dict_to(state_dict)
+ else:
+ state_dict_lora, _ = LoRAConstructor.filter_state_dict_lora(state_dict)
+ yield state_dict_to(state_dict_lora)
+
+ def _get_model_lora_config_dict(self, model: torch.nn.Module):
+ if not self._update_lora_weights:
+ return None
+ unwrapped_model = self.strategy.unwrap_model(model)
+ return LoRAConstructor.extract_lora_config(unwrapped_model)
diff --git a/applications/Chat/coati/ray/example/1m1t.py b/applications/Chat/coati/ray/example/1m1t.py
deleted file mode 100644
index a6527370505b9ce87a1985b8b3d45bcbf0c103a3..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/ray/example/1m1t.py
+++ /dev/null
@@ -1,153 +0,0 @@
-import argparse
-from copy import deepcopy
-
-import pandas as pd
-import torch
-from coati.trainer import PPOTrainer
-
-
-from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
-from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer
-
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
-from coati.experience_maker import NaiveExperienceMaker
-from torch.optim import Adam
-from transformers import AutoTokenizer, BloomTokenizerFast
-from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
-
-from colossalai.nn.optimizer import HybridAdam
-
-import ray
-import os
-import socket
-
-def get_free_port():
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(('', 0))
- return s.getsockname()[1]
-
-
-def get_local_ip():
- with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
- s.connect(('8.8.8.8', 80))
- return s.getsockname()[0]
-
-def main(args):
- master_addr = str(get_local_ip())
- # trainer_env_info
- trainer_port = str(get_free_port())
- env_info_trainer = {'local_rank' : '0',
- 'rank' : '0',
- 'world_size' : '1',
- 'master_port' : trainer_port,
- 'master_addr' : master_addr}
-
- # maker_env_info
- maker_port = str(get_free_port())
- env_info_maker = {'local_rank' : '0',
- 'rank' : '0',
- 'world_size' : '1',
- 'master_port' : maker_port,
- 'master_addr' : master_addr}
-
- # configure tokenizer
- if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'bloom':
- tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'opt':
- tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
- else:
- raise ValueError(f'Unsupported model "{args.model}"')
-
- # configure Trainer
- trainer_ref = DetachedPPOTrainer.options(name="trainer1", num_gpus=1, max_concurrency=2).remote(
- experience_maker_holder_name_list=["maker1"],
- strategy=args.trainer_strategy,
- model=args.model,
- env_info = env_info_trainer,
- pretrained=args.pretrain,
- lora_rank=args.lora_rank,
- train_batch_size=args.train_batch_size,
- buffer_limit=16,
- experience_batch_size=args.experience_batch_size,
- max_epochs=args.max_epochs,
- #kwargs:
- max_length=128,
- do_sample=True,
- temperature=1.0,
- top_k=50,
- pad_token_id=tokenizer.pad_token_id,
- eos_token_id=tokenizer.eos_token_id,
- debug=args.debug,
- )
-
- # configure Experience Maker
- experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
- detached_trainer_name_list=["trainer1"],
- strategy=args.maker_strategy,
- env_info = env_info_maker,
- experience_batch_size=args.experience_batch_size,
- kl_coef=0.1,
- #kwargs:
- max_length=128,
- do_sample=True,
- temperature=1.0,
- top_k=50,
- pad_token_id=tokenizer.pad_token_id,
- eos_token_id=tokenizer.eos_token_id,
- debug=args.debug,
- )
-
- # trainer send its actor and critic to experience holders.
- ray.get(trainer_ref.initialize_remote_makers.remote())
-
- # configure sampler
- dataset = pd.read_csv(args.prompt_path)['prompt']
-
- def tokenize_fn(texts):
- # MUST padding to max length to ensure inputs of all ranks have the same length
- # Different length may lead to hang when using gemini, as different generation steps
- batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
- return {k: v.cuda() for k, v in batch.items()}
-
- trainer_done_ref = trainer_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
- num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs + 3 # +3 for fault tolerance
- maker_done_ref = experience_holder_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
-
- ray.get([trainer_done_ref, maker_done_ref])
-
- # save model checkpoint after fitting
- trainer_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
- # save optimizer checkpoint on all ranks
- if args.need_optim_ckpt:
- trainer_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument('prompt_path')
- parser.add_argument('--trainer_strategy',
- choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='naive')
- parser.add_argument('--maker_strategy',
- choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='naive')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--num_episodes', type=int, default=10)
- parser.add_argument('--max_timesteps', type=int, default=10)
- parser.add_argument('--update_timesteps', type=int, default=10)
- parser.add_argument('--max_epochs', type=int, default=5)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
-
- parser.add_argument('--debug', action='store_true')
- args = parser.parse_args()
- ray.init(namespace=os.environ["RAY_NAMESPACE"])
- main(args)
diff --git a/applications/Chat/coati/ray/example/1m1t.sh b/applications/Chat/coati/ray/example/1m1t.sh
deleted file mode 100644
index f7c5054c800eb376ae973a5103482aec97b0511f..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/ray/example/1m1t.sh
+++ /dev/null
@@ -1,23 +0,0 @@
-set_n_least_used_CUDA_VISIBLE_DEVICES() {
- local n=${1:-"9999"}
- echo "GPU Memory Usage:"
- local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
- | tail -n +2 \
- | nl -v 0 \
- | tee /dev/tty \
- | sort -g -k 2 \
- | awk '{print $1}' \
- | head -n $n)
- export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
- echo "Now CUDA_VISIBLE_DEVICES is set to:"
- echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
-}
-
-set_n_least_used_CUDA_VISIBLE_DEVICES 2
-
-export RAY_NAMESPACE="admin"
-
-python 1m1t.py "/path/to/prompts.csv" \
- --trainer_strategy colossalai_zero2 --maker_strategy naive --lora_rank 2 --pretrain "facebook/opt-350m" --model 'opt' \
- --num_episodes 10 --max_timesteps 10 --update_timesteps 10 \
- --max_epochs 10 --debug
diff --git a/applications/Chat/coati/ray/example/1m2t.py b/applications/Chat/coati/ray/example/1m2t.py
deleted file mode 100644
index 3883c364a8e02fe0adcfa742599870d90f78631d..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/ray/example/1m2t.py
+++ /dev/null
@@ -1,186 +0,0 @@
-import argparse
-from copy import deepcopy
-
-import pandas as pd
-import torch
-from coati.trainer import PPOTrainer
-
-
-from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
-from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer
-
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
-from coati.experience_maker import NaiveExperienceMaker
-from torch.optim import Adam
-from transformers import AutoTokenizer, BloomTokenizerFast
-from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
-
-from colossalai.nn.optimizer import HybridAdam
-
-import ray
-import os
-import socket
-
-
-def get_free_port():
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(('', 0))
- return s.getsockname()[1]
-
-
-def get_local_ip():
- with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
- s.connect(('8.8.8.8', 80))
- return s.getsockname()[0]
-
-def main(args):
- master_addr = str(get_local_ip())
- # trainer_env_info
- trainer_port = str(get_free_port())
- env_info_trainer_1 = {'local_rank' : '0',
- 'rank' : '0',
- 'world_size' : '2',
- 'master_port' : trainer_port,
- 'master_addr' : master_addr}
- env_info_trainer_2 = {'local_rank' : '0',
- 'rank' : '1',
- 'world_size' : '2',
- 'master_port' : trainer_port,
- 'master_addr' : master_addr}
- # maker_env_info
- maker_port = str(get_free_port())
- env_info_maker_1 = {'local_rank' : '0',
- 'rank' : '0',
- 'world_size' : '2',
- 'master_port' : maker_port,
- 'master_addr' : master_addr}
- print([env_info_trainer_1,
- env_info_trainer_2,
- env_info_maker_1])
- ray.init(dashboard_port = 1145)
- # configure tokenizer
- if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'bloom':
- tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'opt':
- tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
- else:
- raise ValueError(f'Unsupported model "{args.model}"')
-
- # configure Trainer
- trainer_1_ref = DetachedPPOTrainer.options(name="trainer1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
- experience_maker_holder_name_list=["maker1"],
- strategy=args.trainer_strategy,
- model=args.model,
- env_info=env_info_trainer_1,
- pretrained=args.pretrain,
- lora_rank=args.lora_rank,
- train_batch_size=args.train_batch_size,
- buffer_limit=16,
- experience_batch_size=args.experience_batch_size,
- max_epochs=args.max_epochs,
- #kwargs:
- max_length=128,
- do_sample=True,
- temperature=1.0,
- top_k=50,
- pad_token_id=tokenizer.pad_token_id,
- eos_token_id=tokenizer.eos_token_id,
- debug=args.debug,
- )
-
- trainer_2_ref = DetachedPPOTrainer.options(name="trainer2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
- experience_maker_holder_name_list=["maker1"],
- strategy=args.trainer_strategy,
- model=args.model,
- env_info=env_info_trainer_2,
- pretrained=args.pretrain,
- lora_rank=args.lora_rank,
- train_batch_size=args.train_batch_size,
- buffer_limit=16,
- experience_batch_size=args.experience_batch_size,
- max_epochs=args.max_epochs,
- #kwargs:
- max_length=128,
- do_sample=True,
- temperature=1.0,
- top_k=50,
- pad_token_id=tokenizer.pad_token_id,
- eos_token_id=tokenizer.eos_token_id,
- debug= args.debug,
- )
-
- # configure Experience Maker
- experience_holder_1_ref = ExperienceMakerHolder.options(name="maker1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
- detached_trainer_name_list=["trainer1", "trainer2"],
- strategy=args.maker_strategy,
- env_info=env_info_maker_1,
- experience_batch_size=args.experience_batch_size,
- kl_coef=0.1,
- #kwargs:
- max_length=128,
- do_sample=True,
- temperature=1.0,
- top_k=50,
- pad_token_id=tokenizer.pad_token_id,
- eos_token_id=tokenizer.eos_token_id,
- debug=args.debug,
- )
-
- # trainer send its actor and critic to experience holders.
- # TODO: balance duty
- ray.get(trainer_1_ref.initialize_remote_makers.remote())
-
- # configure sampler
- dataset = pd.read_csv(args.prompt_path)['prompt']
-
- def tokenize_fn(texts):
- # MUST padding to max length to ensure inputs of all ranks have the same length
- # Different length may lead to hang when using gemini, as different generation steps
- batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
- return {k: v.cuda() for k, v in batch.items()}
-
- trainer_1_done_ref = trainer_1_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
- trainer_2_done_ref = trainer_2_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
- num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs * 2 + 3 # +3 for fault tolerance
- maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
-
- ray.get([trainer_1_done_ref, trainer_2_done_ref, maker_1_done_ref])
- # save model checkpoint after fitting
- trainer_1_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
- trainer_2_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
- # save optimizer checkpoint on all ranks
- if args.need_optim_ckpt:
- trainer_1_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
- trainer_2_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
-
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument('prompt_path')
- parser.add_argument('--trainer_strategy',
- choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='naive')
- parser.add_argument('--maker_strategy',
- choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='naive')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--num_episodes', type=int, default=10)
- parser.add_argument('--max_timesteps', type=int, default=10)
- parser.add_argument('--update_timesteps', type=int, default=10)
- parser.add_argument('--max_epochs', type=int, default=5)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
-
- parser.add_argument('--debug', action='store_true')
- args = parser.parse_args()
- main(args)
diff --git a/applications/Chat/coati/ray/example/1m2t.sh b/applications/Chat/coati/ray/example/1m2t.sh
deleted file mode 100644
index 669f4141026c25ef405d2502a506d1bb019520ea..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/ray/example/1m2t.sh
+++ /dev/null
@@ -1,23 +0,0 @@
-set_n_least_used_CUDA_VISIBLE_DEVICES() {
- local n=${1:-"9999"}
- echo "GPU Memory Usage:"
- local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
- | tail -n +2 \
- | nl -v 0 \
- | tee /dev/tty \
- | sort -g -k 2 \
- | awk '{print $1}' \
- | head -n $n)
- export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
- echo "Now CUDA_VISIBLE_DEVICES is set to:"
- echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
-}
-
-set_n_least_used_CUDA_VISIBLE_DEVICES 2
-
-export RAY_NAMESPACE="admin"
-
-python 1m2t.py "/path/to/prompts.csv" --model gpt2 \
- --maker_strategy naive --trainer_strategy ddp --lora_rank 2 \
- --num_episodes 10 --max_timesteps 10 --update_timesteps 10 \
- --max_epochs 10 #--debug
\ No newline at end of file
diff --git a/applications/Chat/coati/ray/example/2m1t.py b/applications/Chat/coati/ray/example/2m1t.py
deleted file mode 100644
index b655de1ab1fa987987fbd5954b6c94db98b27c50..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/ray/example/2m1t.py
+++ /dev/null
@@ -1,140 +0,0 @@
-import argparse
-from copy import deepcopy
-
-import pandas as pd
-import torch
-from coati.trainer import PPOTrainer
-
-
-from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
-from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer
-
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
-from coati.experience_maker import NaiveExperienceMaker
-from torch.optim import Adam
-from transformers import AutoTokenizer, BloomTokenizerFast
-from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
-
-from colossalai.nn.optimizer import HybridAdam
-
-import ray
-import os
-import socket
-
-
-def main(args):
- # configure tokenizer
- if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'bloom':
- tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'opt':
- tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
- else:
- raise ValueError(f'Unsupported model "{args.model}"')
-
- # configure Trainer
- trainer_ref = DetachedPPOTrainer.options(name="trainer1", num_gpus=1, max_concurrency=2).remote(
- experience_maker_holder_name_list=["maker1", "maker2"],
- strategy=args.trainer_strategy,
- model=args.model,
- pretrained=args.pretrain,
- lora_rank=args.lora_rank,
- train_batch_size=args.train_batch_size,
- buffer_limit=16,
- experience_batch_size=args.experience_batch_size,
- max_epochs=args.max_epochs,
- #kwargs:
- max_length=128,
- do_sample=True,
- temperature=1.0,
- top_k=50,
- pad_token_id=tokenizer.pad_token_id,
- eos_token_id=tokenizer.eos_token_id,
- debug=args.debug,
- )
-
- # configure Experience Maker
- experience_holder_1_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
- detached_trainer_name_list=["trainer1"],
- strategy=args.maker_strategy,
- experience_batch_size=args.experience_batch_size,
- kl_coef=0.1,
- #kwargs:
- max_length=128,
- do_sample=True,
- temperature=1.0,
- top_k=50,
- pad_token_id=tokenizer.pad_token_id,
- eos_token_id=tokenizer.eos_token_id,
- debug=args.debug,
- )
-
- experience_holder_2_ref = ExperienceMakerHolder.options(name="maker2", num_gpus=1, max_concurrency=2).remote(
- detached_trainer_name_list=["trainer1"],
- strategy=args.maker_strategy,
- experience_batch_size=args.experience_batch_size,
- kl_coef=0.1,
- #kwargs:
- max_length=128,
- do_sample=True,
- temperature=1.0,
- top_k=50,
- pad_token_id=tokenizer.pad_token_id,
- eos_token_id=tokenizer.eos_token_id,
- debug=args.debug,
- )
-
- # trainer send its actor and critic to experience holders.
- ray.get(trainer_ref.initialize_remote_makers.remote())
-
- # configure sampler
- dataset = pd.read_csv(args.prompt_path)['prompt']
-
- def tokenize_fn(texts):
- # MUST padding to max length to ensure inputs of all ranks have the same length
- # Different length may lead to hang when using gemini, as different generation steps
- batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
- return {k: v.cuda() for k, v in batch.items()}
-
- trainer_done_ref = trainer_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
- num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs // 2 + 3 # +3 for fault tolerance
- maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
- maker_2_done_ref = experience_holder_2_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
-
- ray.get([trainer_done_ref, maker_1_done_ref, maker_2_done_ref])
-
- # save model checkpoint after fitting
- trainer_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
- # save optimizer checkpoint on all ranks
- if args.need_optim_ckpt:
- trainer_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument('prompt_path')
- parser.add_argument('--trainer_strategy',
- choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='naive')
- parser.add_argument('--maker_strategy',
- choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='naive')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--num_episodes', type=int, default=10)
- parser.add_argument('--max_timesteps', type=int, default=10)
- parser.add_argument('--update_timesteps', type=int, default=10)
- parser.add_argument('--max_epochs', type=int, default=5)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
-
- parser.add_argument('--debug', action='store_true')
- args = parser.parse_args()
- ray.init(namespace=os.environ["RAY_NAMESPACE"])
- main(args)
diff --git a/applications/Chat/coati/ray/example/2m1t.sh b/applications/Chat/coati/ray/example/2m1t.sh
deleted file mode 100644
index a207d4118d605a8c3a17882bcf3cc6b1a8f32eb3..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/ray/example/2m1t.sh
+++ /dev/null
@@ -1,23 +0,0 @@
-set_n_least_used_CUDA_VISIBLE_DEVICES() {
- local n=${1:-"9999"}
- echo "GPU Memory Usage:"
- local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
- | tail -n +2 \
- | nl -v 0 \
- | tee /dev/tty \
- | sort -g -k 2 \
- | awk '{print $1}' \
- | head -n $n)
- export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
- echo "Now CUDA_VISIBLE_DEVICES is set to:"
- echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
-}
-
-set_n_least_used_CUDA_VISIBLE_DEVICES 3
-
-export RAY_NAMESPACE="admin"
-
-python 2m1t.py "/path/to/prompts.csv" \
- --trainer_strategy naive --maker_strategy naive --lora_rank 2 --pretrain "facebook/opt-350m" --model 'opt' \
- --num_episodes 10 --max_timesteps 10 --update_timesteps 10 \
- --max_epochs 10 # --debug
diff --git a/applications/Chat/coati/ray/example/2m2t.py b/applications/Chat/coati/ray/example/2m2t.py
deleted file mode 100644
index 435c71915fc2820b9b12bbd2204823f12ec2438f..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/ray/example/2m2t.py
+++ /dev/null
@@ -1,209 +0,0 @@
-import argparse
-from copy import deepcopy
-
-import pandas as pd
-import torch
-from coati.trainer import PPOTrainer
-
-
-from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
-from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer
-
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
-from coati.experience_maker import NaiveExperienceMaker
-from torch.optim import Adam
-from transformers import AutoTokenizer, BloomTokenizerFast
-from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
-
-from colossalai.nn.optimizer import HybridAdam
-
-import ray
-import os
-import socket
-
-
-def get_free_port():
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(('', 0))
- return s.getsockname()[1]
-
-
-def get_local_ip():
- with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
- s.connect(('8.8.8.8', 80))
- return s.getsockname()[0]
-
-def main(args):
- master_addr = str(get_local_ip())
- # trainer_env_info
- trainer_port = str(get_free_port())
- env_info_trainer_1 = {'local_rank' : '0',
- 'rank' : '0',
- 'world_size' : '2',
- 'master_port' : trainer_port,
- 'master_addr' : master_addr}
- env_info_trainer_2 = {'local_rank' : '0',
- 'rank' : '1',
- 'world_size' : '2',
- 'master_port' : trainer_port,
- 'master_addr' : master_addr}
- # maker_env_info
- maker_port = str(get_free_port())
- env_info_maker_1 = {'local_rank' : '0',
- 'rank' : '0',
- 'world_size' : '2',
- 'master_port' : maker_port,
- 'master_addr' : master_addr}
- env_info_maker_2 = {'local_rank' : '0',
- 'rank' : '1',
- 'world_size' : '2',
- 'master_port': maker_port,
- 'master_addr' : master_addr}
- print([env_info_trainer_1,
- env_info_trainer_2,
- env_info_maker_1,
- env_info_maker_2])
- ray.init()
- # configure tokenizer
- if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'bloom':
- tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'opt':
- tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
- else:
- raise ValueError(f'Unsupported model "{args.model}"')
-
- # configure Trainer
- trainer_1_ref = DetachedPPOTrainer.options(name="trainer1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
- experience_maker_holder_name_list=["maker1", "maker2"],
- strategy=args.trainer_strategy,
- model=args.model,
- env_info=env_info_trainer_1,
- pretrained=args.pretrain,
- lora_rank=args.lora_rank,
- train_batch_size=args.train_batch_size,
- buffer_limit=16,
- experience_batch_size=args.experience_batch_size,
- max_epochs=args.max_epochs,
- #kwargs:
- max_length=128,
- do_sample=True,
- temperature=1.0,
- top_k=50,
- pad_token_id=tokenizer.pad_token_id,
- eos_token_id=tokenizer.eos_token_id,
- debug=args.debug,
- )
-
- trainer_2_ref = DetachedPPOTrainer.options(name="trainer2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
- experience_maker_holder_name_list=["maker1", "maker2"],
- strategy=args.trainer_strategy,
- model=args.model,
- env_info=env_info_trainer_2,
- pretrained=args.pretrain,
- lora_rank=args.lora_rank,
- train_batch_size=args.train_batch_size,
- buffer_limit=16,
- experience_batch_size=args.experience_batch_size,
- max_epochs=args.max_epochs,
- #kwargs:
- max_length=128,
- do_sample=True,
- temperature=1.0,
- top_k=50,
- pad_token_id=tokenizer.pad_token_id,
- eos_token_id=tokenizer.eos_token_id,
- debug=args.debug,
- )
-
- # configure Experience Maker
- experience_holder_1_ref = ExperienceMakerHolder.options(name="maker1", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
- detached_trainer_name_list=["trainer1", "trainer2"],
- strategy=args.maker_strategy,
- env_info=env_info_maker_1,
- experience_batch_size=args.experience_batch_size,
- kl_coef=0.1,
- #kwargs:
- max_length=128,
- do_sample=True,
- temperature=1.0,
- top_k=50,
- pad_token_id=tokenizer.pad_token_id,
- eos_token_id=tokenizer.eos_token_id,
- debug=args.debug,
- )
-
- experience_holder_2_ref = ExperienceMakerHolder.options(name="maker2", namespace=os.environ["RAY_NAMESPACE"], num_gpus=1, max_concurrency=2).remote(
- detached_trainer_name_list=["trainer1", "trainer2"],
- strategy=args.maker_strategy,
- env_info=env_info_maker_2,
- experience_batch_size=args.experience_batch_size,
- kl_coef=0.1,
- #kwargs:
- max_length=128,
- do_sample=True,
- temperature=1.0,
- top_k=50,
- pad_token_id=tokenizer.pad_token_id,
- eos_token_id=tokenizer.eos_token_id,
- debug=args.debug,
- )
-
- # trainer send its actor and critic to experience holders.
- # TODO: balance duty
- ray.get(trainer_1_ref.initialize_remote_makers.remote())
-
- # configure sampler
- dataset = pd.read_csv(args.prompt_path)['prompt']
-
- def tokenize_fn(texts):
- # MUST padding to max length to ensure inputs of all ranks have the same length
- # Different length may lead to hang when using gemini, as different generation steps
- batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
- return {k: v.cuda() for k, v in batch.items()}
-
- trainer_1_done_ref = trainer_1_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
- trainer_2_done_ref = trainer_2_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
- num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs + 3 # +3 for fault tolerance
- maker_1_done_ref = experience_holder_1_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
- maker_2_done_ref = experience_holder_2_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
-
- ray.get([trainer_1_done_ref, trainer_2_done_ref, maker_1_done_ref, maker_2_done_ref])
- # save model checkpoint after fitting
- trainer_1_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
- trainer_2_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
- # save optimizer checkpoint on all ranks
- if args.need_optim_ckpt:
- trainer_1_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
- trainer_2_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
-
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument('prompt_path')
- parser.add_argument('--trainer_strategy',
- choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='naive')
- parser.add_argument('--maker_strategy',
- choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='naive')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--num_episodes', type=int, default=10)
- parser.add_argument('--max_timesteps', type=int, default=10)
- parser.add_argument('--update_timesteps', type=int, default=10)
- parser.add_argument('--max_epochs', type=int, default=5)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
-
- parser.add_argument('--debug', action='store_true')
- args = parser.parse_args()
- main(args)
diff --git a/applications/Chat/coati/ray/example/2m2t.sh b/applications/Chat/coati/ray/example/2m2t.sh
deleted file mode 100644
index fb4024766c54182efcd752f991bfc11c1db7588e..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/ray/example/2m2t.sh
+++ /dev/null
@@ -1,23 +0,0 @@
-set_n_least_used_CUDA_VISIBLE_DEVICES() {
- local n=${1:-"9999"}
- echo "GPU Memory Usage:"
- local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
- | tail -n +2 \
- | nl -v 0 \
- | tee /dev/tty \
- | sort -g -k 2 \
- | awk '{print $1}' \
- | head -n $n)
- export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
- echo "Now CUDA_VISIBLE_DEVICES is set to:"
- echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
-}
-
-set_n_least_used_CUDA_VISIBLE_DEVICES 2
-
-export RAY_NAMESPACE="admin"
-
-python 2m2t.py "path/to/prompts.csv" \
- --maker_strategy naive --trainer_strategy colossalai_zero2 --lora_rank 2 \
- --num_episodes 10 --max_timesteps 10 --update_timesteps 10 \
- --max_epochs 10 --debug
\ No newline at end of file
diff --git a/applications/Chat/coati/ray/experience_maker_holder.py b/applications/Chat/coati/ray/experience_maker_holder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d290f4aba886434dfa86b12e7db05c14ac15d9c
--- /dev/null
+++ b/applications/Chat/coati/ray/experience_maker_holder.py
@@ -0,0 +1,274 @@
+import os
+import time
+import tracemalloc
+from threading import Lock
+from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
+
+import ray
+import torch
+from coati.experience_buffer.utils import split_experience_batch
+from coati.experience_maker import Experience, NaiveExperienceMaker
+from coati.models.base import Actor, Critic, RewardModel
+from coati.trainer.strategies import Strategy
+from torch import Tensor
+from tqdm import tqdm
+
+from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
+from .lora_constructor import LoRAConstructor
+from .utils import get_model_numel, get_rank, is_rank_0, set_dist_env, state_dict_to
+
+
+@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
+class ExperienceMakerHolder:
+ """
+ Args:
+        detached_trainer_name_list: list of Ray actor names used to fetch trainer handles
+        strategy_fn: a function that returns the Strategy used for experience making
+        kl_coef: the coefficient of the KL divergence loss
+        sync_models_from_trainers: whether to sync models from trainers. If True, trainers must call sync_models_to_remote_makers() to sync models.
+ """
+
+ def __init__(
+ self,
+ detached_trainer_name_list: List[str],
+ strategy_fn: Callable[[], Strategy],
+        # a function that returns (actor, critic, reward_model, initial_model)
+ model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
+ env_info: Dict[str, str] = None,
+ sync_models_from_trainers: bool = False,
+ buffer_cpu_offload: bool = True,
+ kl_coef: float = 0.1,
+ callbacks: List[MakerCallback] = [],
+ eval_performance: bool = False,
+ debug: bool = False,
+ update_lora_weights: bool = False,
+ **generate_kwargs,
+ ):
+ # set environment variables
+ if env_info:
+ set_dist_env(env_info=env_info)
+ self.target_trainer_list = []
+ assert len(detached_trainer_name_list) > 0
+ self._detached_trainer_name_list = detached_trainer_name_list
+ self.strategy = strategy_fn()
+ self.buffer_cpu_offload = buffer_cpu_offload
+ self.kl_coef = kl_coef
+ # init models
+ with self.strategy.model_init_context():
+ actor, critic, reward_model, initial_model = model_fn()
+ self.generate_kwargs = _set_default_generate_kwargs(generate_kwargs, actor)
+ if eval_performance:
+ actor_numel = get_model_numel(actor)
+ critic_numel = get_model_numel(critic)
+ initial_model_numel = get_model_numel(initial_model)
+ reward_model_numel = get_model_numel(reward_model)
+ evaluator = ExperienceMakerPerformanceEvaluator(
+ actor_numel, critic_numel, initial_model_numel, reward_model_numel
+ )
+ callbacks = callbacks + [evaluator]
+
+ actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model)
+ self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef)
+ self.callbacks = callbacks
+
+ self._model_visit_lock = Lock()
+
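+        # When syncing models from trainers, stay uninitialized until update_experience_maker
+        # receives a fully_update; _get_ready() blocks the working loop until then.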
+ self._is_fully_initialized = not sync_models_from_trainers
+
+ self._debug = debug
+ self._update_lora_weights = update_lora_weights
+ if self._update_lora_weights:
+ self.actor_lora_constructor = LoRAConstructor()
+ self.critic_lora_constructor = LoRAConstructor()
+
+ self.target_auto_balance = False
+
+ self._target_idx = 0
+
+ if self._debug:
+ print(f"[maker{get_rank()}] will send items to {self._detached_trainer_name_list}")
+ if not self._is_fully_initialized:
+ print(f"[maker{get_rank()}] Waiting for INIT")
+
+ def _get_ready(self):
+ while not self._fully_initialized():
+ time.sleep(1.0)
+
+ def _fully_initialized(self):
+ return self._is_fully_initialized
+
+ def _init_target_trainer_list(self):
+ if len(self.target_trainer_list) > 0:
+ return
+ for name in self._detached_trainer_name_list:
+ self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
+
+    # copied from ../trainer/base.py
+ @ray.method(concurrency_group="compute")
+ def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
+ if isinstance(inputs, Tensor):
+ return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
+ elif isinstance(inputs, dict):
+ return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
+ else:
+ raise ValueError(f'Unsupported input type "{type(inputs)}"')
+
+ @ray.method(concurrency_group="experience_io")
+ def _send_items(self, experience: Experience) -> None:
+ self._init_target_trainer_list()
+ items = split_experience_batch(experience)
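+        # Deal items out round-robin (self._target_idx persists across calls) so that
+        # every target trainer receives a balanced share of the experience batch.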
+ items_per_trainer = [[] for _ in range(len(self.target_trainer_list))]
+ for item in items:
+ items_per_trainer[self._target_idx].append(item)
+ self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list)
+ for i, target_trainer in enumerate(self.target_trainer_list):
+ if len(items_per_trainer[i]) > 0:
+ target_trainer.buffer_extend.remote(items_per_trainer[i])
+
+ def _inference_step(self, batch) -> None:
+ self._on_batch_start()
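+        # Hold the model-visit lock during generation so update_experience_maker cannot
+        # swap weights mid-rollout (it acquires the same lock on chunk_start).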
+ with self._model_visit_lock:
+ self._on_make_experience_start()
+ experience = self._make_experience(batch)
+ self._on_make_experience_end(experience)
+ self._on_send_start()
+ if self.buffer_cpu_offload:
+ experience.to_device("cpu")
+ self._send_items(experience)
+ self._on_send_end()
+ self._on_batch_end()
+
+ def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1, num_steps: int = 0):
+ """Working loop of the experience maker.
+
+ Args:
+ dataloader_fn (Callable[[], Iterable]): A function that returns a dataloader.
+            num_epochs (int, optional): Number of epochs to iterate the dataloader for. Defaults to 1.
+            num_steps (int, optional): Number of steps to iterate the dataloader for. If this value > 0, num_epochs is ignored. Defaults to 0.
+ """
+ self._get_ready()
+ self._on_loop_start()
+ dataloader = dataloader_fn()
+ if num_steps > 0:
+            # num_steps takes precedence, so num_epochs is ignored
+ it = iter(dataloader)
+ for _ in tqdm(range(num_steps), desc="ExperienceMaker", disable=not is_rank_0()):
+ try:
+ batch = next(it)
+ except StopIteration:
+ it = iter(dataloader)
+ batch = next(it)
+ self._inference_step(batch)
+ else:
+ with tqdm(total=num_epochs * len(dataloader), desc="ExperienceMaker", disable=not is_rank_0()) as pbar:
+ for _ in range(num_epochs):
+ for batch in dataloader:
+ self._inference_step(batch)
+ pbar.update()
+ self._on_loop_end()
+
+ @ray.method(concurrency_group="model_io")
+ def update_experience_maker(
+ self,
+ new_actor_state_dict: Dict[str, Any] = None,
+ new_actor_lora_config_dict: Dict[str, Any] = None,
+ new_critic_state_dict: Dict[str, Any] = None,
+ new_critic_lora_config_dict: Dict[str, Any] = None,
+ fully_update: bool = False,
+        chunk_start: bool = False,
+        chunk_end: bool = False,
+ ):
+ """
+        Called by the trainer.
+        chunk_start: set True on the first call, before sending any state_dict chunks
+        chunk_end: set True on the last call, after all state_dict chunks have been sent
+        fully_update: set True to fully sync the models, e.g. when initializing
+
+ TODO: load_state_dict integrate with model-sharding strategy
+ """
+ _watch_memory = self._debug
+ if chunk_start:
+ if self._debug:
+ print("[maker] UPDATE ")
+ if _watch_memory:
+ tracemalloc.start()
+ self._model_visit_lock.acquire()
+
+ with torch.no_grad():
+ if new_actor_state_dict is not None:
+ if not self._update_lora_weights or fully_update:
+ self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False)
+ else:
+ new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
+ state_dict_increase = self.actor_lora_constructor.reconstruct_increase(
+ new_actor_state_dict, new_actor_lora_config_dict
+ )
+ self.actor_lora_constructor.load_state_dict_increase(
+ self.experience_maker.actor.model, state_dict_increase
+ )
+ if new_critic_state_dict is not None:
+ if not self._update_lora_weights or fully_update:
+ self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
+ else:
+ new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
+ state_dict_increase = self.critic_lora_constructor.reconstruct_increase(
+ new_critic_state_dict, new_critic_lora_config_dict
+ )
+ self.critic_lora_constructor.load_state_dict_increase(
+ self.experience_maker.critic, state_dict_increase
+ )
+
+        # the lock must be released only after both the actor and the critic have been updated
+ if chunk_end:
+ self._model_visit_lock.release()
+ if _watch_memory:
+ current, peak = tracemalloc.get_traced_memory()
+ print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB")
+ tracemalloc.stop()
+ if fully_update:
+ self._is_fully_initialized = True
+
+ def _on_make_experience_start(self) -> None:
+ for callback in self.callbacks:
+ callback.on_make_experience_start()
+
+ def _on_make_experience_end(self, experience: Experience) -> None:
+ for callback in self.callbacks:
+ callback.on_make_experience_end(experience)
+
+ def _on_loop_start(self) -> None:
+ for callback in self.callbacks:
+ callback.on_loop_start()
+
+ def _on_loop_end(self) -> None:
+ for callback in self.callbacks:
+ callback.on_loop_end()
+
+ def _on_send_start(self) -> None:
+ for callback in self.callbacks:
+ callback.on_send_start()
+
+ def _on_send_end(self) -> None:
+ for callback in self.callbacks:
+ callback.on_send_end()
+
+ def _on_batch_start(self) -> None:
+ for callback in self.callbacks:
+ callback.on_batch_start()
+
+ def _on_batch_end(self) -> None:
+ for callback in self.callbacks:
+ callback.on_batch_end()
+
+
+def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> dict:
+ origin_model = actor.model
+ new_kwargs = {**generate_kwargs}
+    # use the HuggingFace model's methods directly
+ if "prepare_inputs_fn" not in generate_kwargs and hasattr(origin_model, "prepare_inputs_for_generation"):
+ new_kwargs["prepare_inputs_fn"] = origin_model.prepare_inputs_for_generation
+
+ if "update_model_kwargs_fn" not in generate_kwargs and hasattr(origin_model, "_update_model_kwargs_for_generation"):
+ new_kwargs["update_model_kwargs_fn"] = origin_model._update_model_kwargs_for_generation
+
+ return new_kwargs
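+
+
+# A minimal launch sketch, assuming ray.init() was called with the namespace stored in
+# RAY_NAMESPACE, a trainer actor named "trainer1" already exists there, and that
+# `build_models` / `build_dataloader` are user-supplied helpers (hypothetical names):
+#
+#     from coati.trainer.strategies import NaiveStrategy
+#
+#     maker = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
+#         detached_trainer_name_list=["trainer1"],
+#         strategy_fn=lambda: NaiveStrategy(),    # any Strategy factory works here
+#         model_fn=build_models,                  # returns (actor, critic, reward_model, initial_model)
+#         sync_models_from_trainers=True,         # wait for the trainer's fully_update before making experience
+#         kl_coef=0.1,
+#     )
+#     done_ref = maker.workingloop.remote(build_dataloader, num_steps=100)
+#     ray.get(done_ref)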
diff --git a/applications/Chat/coati/ray/lora_constructor.py b/applications/Chat/coati/ray/lora_constructor.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e9f78700e29fcc897245d486fd8c7c2bc1e95bc
--- /dev/null
+++ b/applications/Chat/coati/ray/lora_constructor.py
@@ -0,0 +1,123 @@
+from collections import OrderedDict
+from dataclasses import dataclass
+from typing import Any, Dict
+
+import torch.nn as nn
+from coati.models.lora import LoraLinear
+
+
+@dataclass
+class LoRAConfig:
+ r: int = 0
+ lora_alpha: int = 1
+ lora_dropout: float = 0
+ fan_in_fan_out: bool = False
+
+
+class LoRAConstructor:
+ """
+    Tools for reconstructing a model's weights from remotely transferred LoRA weights.
+    (Transferring only the LoRA weights costs much less bandwidth!)
+ Usage:
+ Step 1 (Sender):
+ filter_state_dict_lora()
+
+ Step 2 (Sender, Optional):
+ extract_lora_config()
+
+ Step 3 (Sender):
+ send state_dict_lora and lora_config_dict
+
+ Step 4 (Receiver):
+ reconstruct_increase()
+
+ Step 5 (Receiver):
+ load_state_dict_increase()
+
+ """
+
+ def __init__(self):
+ self.lora_config_dict = None
+
+ def register_lora_config(self, lora_config_dict: Dict[str, Any]):
+ self.lora_config_dict = lora_config_dict
+
+ def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]):
+ """
+ xxx.lora_A, xxx.lora_B -->> xxx.weight
+        Warning: the xxx.weight produced here is actually the weight increment, not the full weight.
+ """
+ if lora_config_dict is not None:
+ self.register_lora_config(lora_config_dict)
+
+ state_dict_increase = OrderedDict()
+ config_iter = iter(self.lora_config_dict.items())
+ lora_A, lora_B, layer_prefix = None, None, None
+ for k, v in state_dict_lora.items():
+ if k.rpartition(".")[-1] == "lora_A":
+ lora_A = v
+ layer_prefix = k.rpartition(".")[0]
+ elif k.rpartition(".")[-1] == "lora_B":
+ assert layer_prefix == k.rpartition(".")[0], "unmatched (lora_A, lora_B) pair"
+ layer_prefix_2, config = next(config_iter)
+ assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair"
+ lora_B = v
+ weight_data_increase = self._compute(lora_A, lora_B, config)
+ state_dict_increase[layer_prefix + ".weight"] = weight_data_increase
+ lora_A, lora_B, layer_prefix = None, None, None
+ else:
+ raise ValueError("unexpected key")
+ return state_dict_increase
+
+ def _compute(self, lora_A, lora_B, config=LoRAConfig()):
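+        # LoRA weight increment: delta_W = (lora_B @ lora_A) * (lora_alpha / r),
+        # transposed when the layer stores weights as fan_in_fan_out.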
+ def T(w):
+ return w.T if config.fan_in_fan_out else w
+
+ if config.r > 0:
+ scaling = config.lora_alpha / config.r
+ weight_data_increase = T(lora_B @ lora_A) * scaling
+ return weight_data_increase
+ return 0
+
+ def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]):
+ """
+ The final reconstruction step
+ """
+ # naive approach
+ model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}, strict=False)
+
+ @staticmethod
+ def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
+ """
+        If keep_non_lora is True, also return the non-LoRA state_dict.
+ """
+ state_dict_lora = OrderedDict()
+ state_dict_non_lora = OrderedDict()
+ for k, v in state_dict.items():
+ if "lora_A" in k or "lora_B" in k:
+ state_dict_lora[k] = v
+ elif keep_non_lora:
+ state_dict_non_lora[k] = v
+ if keep_non_lora:
+ return state_dict_lora, state_dict_non_lora
+ else:
+ return state_dict_lora, None
+
+ @staticmethod
+ def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]:
+ """
+        Extract the LoRA config of every LoraLinear module.
+        Returns an OrderedDict mapping module name -> LoRAConfig.
+ """
+ lora_config_dict = OrderedDict()
+
+ for name, child in model.named_modules():
+ if isinstance(child, LoraLinear):
+ lora_config_dict[name] = LoRAConfig(
+ r=child.r,
+ lora_alpha=child.lora_alpha,
+ lora_dropout=child.lora_dropout,
+ fan_in_fan_out=child.fan_in_fan_out,
+ )
+
+ return lora_config_dict
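+
+
+# A minimal round-trip sketch of the five steps in the class docstring; `trainer_model`
+# and `maker_model` are hypothetical modules sharing the same LoraLinear layout:
+#
+#     # sender (trainer) side
+#     state_dict_lora, _ = LoRAConstructor.filter_state_dict_lora(trainer_model.state_dict())
+#     lora_config_dict = LoRAConstructor.extract_lora_config(trainer_model)
+#     # ... ship state_dict_lora and lora_config_dict to the maker (e.g. via a Ray call) ...
+#
+#     # receiver (experience maker) side
+#     constructor = LoRAConstructor()
+#     increase = constructor.reconstruct_increase(state_dict_lora, lora_config_dict)
+#     constructor.load_state_dict_increase(maker_model, increase)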
diff --git a/applications/Chat/coati/ray/src/detached_replay_buffer.py b/applications/Chat/coati/ray/src/detached_replay_buffer.py
deleted file mode 100644
index 855eee48c5a5ccb252ccd3e697073d728afcd3df..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/ray/src/detached_replay_buffer.py
+++ /dev/null
@@ -1,88 +0,0 @@
-import torch
-import random
-from typing import List, Any
-# from torch.multiprocessing import Queue
-from ray.util.queue import Queue
-import ray
-import asyncio
-from coati.experience_maker.base import Experience
-from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
-from coati.replay_buffer import ReplayBuffer
-from threading import Lock
-import copy
-
-class DetachedReplayBuffer:
- '''
- Detached replay buffer. Share Experience across workers on the same node.
- Therefore a trainer node is expected to have only one instance.
- It is ExperienceMakerHolder's duty to call append(exp) method, remotely.
-
- Args:
- sample_batch_size: Batch size when sampling. Exp won't enqueue until they formed a batch.
- tp_world_size: Number of workers in the same tp group
- limit: Limit of number of experience sample BATCHs. A number <= 0 means unlimited. Defaults to 0.
- cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True.
- '''
-
- def __init__(self, sample_batch_size: int, tp_world_size: int = 1, limit : int = 0, cpu_offload: bool = True) -> None:
- self.cpu_offload = cpu_offload
- self.sample_batch_size = sample_batch_size
- self.limit = limit
- self.items = Queue(self.limit, actor_options={"num_cpus":1})
- self.batch_collector : List[BufferItem] = []
-
- '''
- Workers in the same tp group share this buffer and need same sample for one step.
- Therefore a held_sample should be returned tp_world_size times before it could be dropped.
- worker_state records wheter a worker got the held_sample
- '''
- self.tp_world_size = tp_world_size
- self.worker_state = [False] * self.tp_world_size
- self.held_sample = None
- self._worker_state_lock = Lock()
-
- @torch.no_grad()
- def append(self, experience: Experience) -> None:
- '''
- Expected to be called remotely.
- '''
- if self.cpu_offload:
- experience.to_device(torch.device('cpu'))
- items = split_experience_batch(experience)
- self.batch_collector.extend(items)
- while len(self.batch_collector) >= self.sample_batch_size:
- items = self.batch_collector[:self.sample_batch_size]
- experience = make_experience_batch(items)
- self.items.put(experience, block=True)
- self.batch_collector = self.batch_collector[self.sample_batch_size:]
-
- def clear(self) -> None:
- # self.items.close()
- self.items.shutdown()
- self.items = Queue(self.limit)
- self.worker_state = [False] * self.tp_world_size
- self.batch_collector = []
-
- @torch.no_grad()
- def sample(self, worker_rank = 0, to_device = "cpu") -> Experience:
- self._worker_state_lock.acquire()
- if not any(self.worker_state):
- self.held_sample = self._sample_and_erase()
- self.worker_state[worker_rank] = True
- if all(self.worker_state):
- self.worker_state = [False] * self.tp_world_size
- ret = self.held_sample
- else:
- ret = copy.deepcopy(self.held_sample)
- self._worker_state_lock.release()
- ret.to_device(to_device)
- return ret
-
- @torch.no_grad()
- def _sample_and_erase(self) -> Experience:
- ret = self.items.get(block=True)
- return ret
-
- def get_length(self) -> int:
- ret = self.items.qsize()
- return ret
\ No newline at end of file
diff --git a/applications/Chat/coati/ray/src/detached_trainer_base.py b/applications/Chat/coati/ray/src/detached_trainer_base.py
deleted file mode 100644
index f1ed1ec71499a9e7f84c80a52e357ef4c2a92201..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/ray/src/detached_trainer_base.py
+++ /dev/null
@@ -1,121 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional, Union
-from tqdm import tqdm
-from coati.trainer.callbacks import Callback
-from coati.experience_maker import Experience
-import ray
-import os
-
-from .detached_replay_buffer import DetachedReplayBuffer
-from .utils import is_rank_0
-
-class DetachedTrainer(ABC):
- '''
- Base class for detached rlhf trainers.
- 'detach' means that the experience maker is detached compared to a normal Trainer.
- Please set name attribute during init:
- >>> trainer = DetachedTrainer.options(..., name = "xxx", ...).remote()
- So an ExperienceMakerHolder can reach the detached_replay_buffer by Actor's name.
- Args:
- detached_strategy (DetachedStrategy): the strategy to use for training
- detached_replay_buffer_ref (ObjectRef[DetachedReplayBuffer]): the replay buffer to use for training
- experience_batch_size (int, defaults to 8): the batch size to use for experience generation
- max_epochs (int, defaults to 1): the number of epochs of training process
- data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
- callbacks (List[Callback], defaults to []): the callbacks to call during training process
- generate_kwargs (dict, optional): the kwargs to use while model generating
- '''
-
- def __init__(self,
- experience_maker_holder_name_list: List[str],
- train_batch_size: int = 8,
- buffer_limit: int = 0,
- buffer_cpu_offload: bool = True,
- experience_batch_size: int = 8,
- max_epochs: int = 1,
- dataloader_pin_memory: bool = True,
- callbacks: List[Callback] = [],
- **generate_kwargs) -> None:
- super().__init__()
- self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit, cpu_offload=buffer_cpu_offload)
- self.experience_batch_size = experience_batch_size
- self.max_epochs = max_epochs
- self.dataloader_pin_memory = dataloader_pin_memory
- self.callbacks = callbacks
- self.generate_kwargs = generate_kwargs
- self.target_holder_name_list = experience_maker_holder_name_list
- self.target_holder_list = []
-
- def update_target_holder_list(self, experience_maker_holder_name_list):
- self.target_holder_name_list = experience_maker_holder_name_list
- self.target_holder_list = []
- for name in self.target_holder_name_list:
- self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
-
- @abstractmethod
- def _update_remote_makers(self):
- pass
-
- @abstractmethod
- def training_step(self, experience: Experience) -> Dict[str, Any]:
- pass
-
- def _learn(self):
- pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
- for _ in pbar:
- if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
- print("[trainer] sampling exp")
- experience = self._buffer_sample()
- if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
- print("[trainer] training step")
- metrics = self.training_step(experience)
- if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
- print("[trainer] step over")
- pbar.set_postfix(metrics)
-
- def fit(self, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None:
- self._on_fit_start()
- for episode in range(num_episodes):
- self._on_episode_start(episode)
- for timestep in tqdm(range(max_timesteps // update_timesteps),
- desc=f'Episode [{episode+1}/{num_episodes}]',
- disable=not is_rank_0()):
- self._learn()
- self._update_remote_makers()
- self._on_episode_end(episode)
- self._on_fit_end()
-
- @ray.method(concurrency_group="buffer_length")
- def buffer_get_length(self):
- # called by ExperienceMakerHolder
- if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
- print("[trainer] telling length")
- return self.detached_replay_buffer.get_length()
-
- @ray.method(concurrency_group="buffer_append")
- def buffer_append(self, experience: Experience):
- # called by ExperienceMakerHolder
- if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
- # print(f"[trainer] receiving exp. Current buffer length: {self.detached_replay_buffer.get_length()}")
- print(f"[trainer] receiving exp.")
- self.detached_replay_buffer.append(experience)
-
- @ray.method(concurrency_group="buffer_sample")
- def _buffer_sample(self):
- return self.detached_replay_buffer.sample()
-
- def _on_fit_start(self) -> None:
- for callback in self.callbacks:
- callback.on_fit_start()
-
- def _on_fit_end(self) -> None:
- for callback in self.callbacks:
- callback.on_fit_end()
-
- def _on_episode_start(self, episode: int) -> None:
- for callback in self.callbacks:
- callback.on_episode_start(episode)
-
- def _on_episode_end(self, episode: int) -> None:
- for callback in self.callbacks:
- callback.on_episode_end(episode)
diff --git a/applications/Chat/coati/ray/src/detached_trainer_ppo.py b/applications/Chat/coati/ray/src/detached_trainer_ppo.py
deleted file mode 100644
index 838e82d07f4addb8924e1f808c2a47fed363c1f8..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/ray/src/detached_trainer_ppo.py
+++ /dev/null
@@ -1,192 +0,0 @@
-from typing import Any, Callable, Dict, List, Optional
-import torch
-from torch.optim import Adam
-
-from coati.experience_maker import Experience, NaiveExperienceMaker
-from coati.models.base import Actor, Critic
-from coati.models.generation_utils import update_model_kwargs_fn
-from coati.models.loss import PolicyLoss, ValueLoss
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy, Strategy
-from coati.trainer.callbacks import Callback
-
-from colossalai.nn.optimizer import HybridAdam
-
-import ray
-
-
-from .utils import is_rank_0, get_cuda_actor_critic_from_args, get_strategy_from_args, set_dist_env
-from .detached_trainer_base import DetachedTrainer
-
-
-@ray.remote(concurrency_groups={"buffer_length": 1, "buffer_append":1, "buffer_sample":1,"model_io": 1, "compute": 1})
-class DetachedPPOTrainer(DetachedTrainer):
- '''
- Detached Trainer for PPO algorithm
- Args:
- strategy (Strategy): the strategy to use for training
- model (str) : for actor / critic init
- pretrained (str) : for actor / critic init
- lora_rank (int) : for actor / critic init
- train_batch_size (int, defaults to 8): the batch size to use for training
- train_batch_size (int, defaults to 8): the batch size to use for training
- buffer_limit (int, defaults to 0): the max_size limitation of replay buffer
- buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
- eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
- value_clip (float, defaults to 0.4): the clip coefficient of value loss
- experience_batch_size (int, defaults to 8): the batch size to use for experience generation
- max_epochs (int, defaults to 1): the number of epochs of training process
- dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
- callbacks (List[Callback], defaults to []): the callbacks to call during training process
- generate_kwargs (dict, optional): the kwargs to use while model generating
- '''
-
- def __init__(self,
- experience_maker_holder_name_list: List[str],
- strategy: str,
- model: str,
- env_info: Dict[str, str] = None,
- pretrained: str = None,
- lora_rank: int = 0,
- train_batch_size: int = 8,
- buffer_limit: int = 0,
- buffer_cpu_offload: bool = True,
- eps_clip: float = 0.2,
- value_clip: float = 0.4,
- experience_batch_size: int = 8,
- max_epochs: int = 10,
- dataloader_pin_memory: bool = True,
- callbacks: List[Callback] = [],
- **generate_kwargs) -> None:
- # set environment variables
- if env_info:
- set_dist_env(env_info=env_info)
- # configure strategy
- self.strategy = get_strategy_from_args(strategy)
- # configure models, loss and optimizers
- with self.strategy.model_init_context():
- self.actor, self.critic = get_cuda_actor_critic_from_args(model, pretrained, lora_rank)
-
- if strategy != 'colossalai_gemini':
- self.actor.to(torch.float16).to(torch.cuda.current_device())
- self.critic.to(torch.float16).to(torch.cuda.current_device())
-
- if strategy.startswith('colossalai'):
- self.actor_optim = HybridAdam(self.actor.parameters(), lr=5e-6)
- self.critic_optim = HybridAdam(self.critic.parameters(), lr=5e-6)
- else:
- self.actor_optim = Adam(self.actor.parameters(), lr=5e-6)
- self.critic_optim = Adam(self.critic.parameters(), lr=5e-6)
-
- (self.actor, self.actor_optim), (self.critic, self.critic_optim) = \
- self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))
- generate_kwargs = _set_default_generate_kwargs(self.strategy, generate_kwargs, self.actor)
-
- self.actor_loss_fn = PolicyLoss(eps_clip)
- self.critic_loss_fn = ValueLoss(value_clip)
-
- super().__init__(experience_maker_holder_name_list,
- train_batch_size=train_batch_size,
- buffer_limit=buffer_limit,
- buffer_cpu_offload=buffer_cpu_offload,
- experience_batch_size=experience_batch_size,
- max_epochs=max_epochs,
- dataloader_pin_memory=dataloader_pin_memory,
- callbacks=callbacks,
- **generate_kwargs)
-
- @ray.method(concurrency_group="model_io")
- def _update_remote_makers(self):
- # TODO: balance duties
- if is_rank_0():
- self.update_target_holder_list(self.target_holder_name_list)
- for target_holder in self.target_holder_list:
- # TODO: reduce malloc
- with torch.no_grad():
- ray.get(target_holder.update_experience_maker.remote(self._get_unwrapped_actor(), self._get_unwrapped_critic()))
-
- @ray.method(concurrency_group="model_io")
- def initialize_remote_makers(self):
- # TODO: balance duties
- if is_rank_0():
- self.update_target_holder_list(self.target_holder_name_list)
- for target_holder in self.target_holder_list:
- # TODO: reduce malloc
- with torch.no_grad():
- ray.get(target_holder.initialize_experience_maker.remote(self._get_unwrapped_actor(), self._get_unwrapped_critic()))
-
- @ray.method(concurrency_group="compute")
- def training_step(self, experience: Experience) -> Dict[str, float]:
- self.actor.train()
- self.critic.train()
-
- experience.to_device(torch.cuda.current_device())
- num_actions = experience.action_mask.size(1)
- action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
- actor_loss = self.actor_loss_fn(action_log_probs,
- experience.action_log_probs,
- experience.advantages,
- action_mask=experience.action_mask)
- self.strategy.backward(actor_loss, self.actor, self.actor_optim)
- self.strategy.optimizer_step(self.actor_optim)
- self.actor_optim.zero_grad()
-
- values = self.critic(experience.sequences,
- action_mask=experience.action_mask,
- attention_mask=experience.attention_mask)
- critic_loss = self.critic_loss_fn(values,
- experience.values,
- experience.reward,
- action_mask=experience.action_mask)
-
- self.strategy.backward(critic_loss, self.critic, self.critic_optim)
- self.strategy.optimizer_step(self.critic_optim)
- self.critic_optim.zero_grad()
- return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
-
- def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None:
- self.strategy.save_model(self.actor, path, only_rank0)
-
- def strategy_save_critic(self, path: str, only_rank0: bool = False) -> None:
- self.strategy.save_model(self.critic, path, only_rank0)
-
- def strategy_save_actor_optim(self, path: str, only_rank0: bool = False) -> None:
- self.strategy.save_optimizer(self.actor_optim, path, only_rank0)
-
- def strategy_save_critic_optim(self, path: str, only_rank0: bool = False) -> None:
- self.strategy.save_optimizer(self.critic_optim, path, only_rank0)
-
- def _get_unwrapped_actor(self):
- if False:
- pass
- elif isinstance(self.strategy, ColossalAIStrategy):
- ret = Actor(self.strategy._unwrap_model(self.actor))
- return ret
- elif isinstance(self.strategy, DDPStrategy):
- return Actor(self.strategy._unwrap_actor(self.actor))
- elif isinstance(self.strategy, NaiveStrategy):
- return self.actor
-
- def _get_unwrapped_critic(self):
- if False:
- pass
- elif isinstance(self.strategy, ColossalAIStrategy):
- ret = self.strategy._unwrap_model(self.critic)
- return ret
- elif isinstance(self.strategy, DDPStrategy):
- return self.critic.module
- elif isinstance(self.strategy, NaiveStrategy):
- return self.critic
-
-
-def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
- origin_model = strategy._unwrap_actor(actor)
- new_kwargs = {**generate_kwargs}
- # use huggingface models method directly
- if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
- new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
-
- if 'update_model_kwargs_fn' not in generate_kwargs:
- new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
-
- return new_kwargs
-
\ No newline at end of file
diff --git a/applications/Chat/coati/ray/src/experience_maker_holder.py b/applications/Chat/coati/ray/src/experience_maker_holder.py
deleted file mode 100644
index 94e4a3d537a57733ad8dbf510e940ef6c9129f6d..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/ray/src/experience_maker_holder.py
+++ /dev/null
@@ -1,172 +0,0 @@
-import torch
-from typing import Any, Callable, Dict, List, Optional, Union
-import ray
-from ray.exceptions import GetTimeoutError
-from torch import Tensor
-import torch.nn as nn
-from coati.models.base import Actor, Critic, RewardModel
-from coati.trainer.strategies.sampler import DistributedSampler
-from coati.trainer.strategies import Strategy
-from coati.experience_maker import NaiveExperienceMaker, Experience, ExperienceMaker
-
-from copy import deepcopy
-from threading import Lock
-import time
-import os
-
-
-from .utils import is_rank_0, get_strategy_from_args, set_dist_env
-
-
-@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
-class ExperienceMakerHolder:
- '''
- Args:
- detached_trainer_name_list: str list to get ray actor handleskkk
- strategy:
- experience_batch_size: batch size of generated experience
- kl_coef: the coefficient of kl divergence loss
- '''
-
- def __init__(self,
- detached_trainer_name_list: List[str],
- strategy: str,
- env_info: Dict[str, str] = None,
- experience_batch_size: int = 8,
- kl_coef: float = 0.1,
- **generate_kwargs):
- # set environment variables
- if env_info:
- set_dist_env(env_info=env_info)
- self.target_trainer_list = []
- for name in detached_trainer_name_list:
- self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
- self.strategy_str = strategy
- self.strategy = get_strategy_from_args(strategy)
- self.experience_batch_size = experience_batch_size
- self.kl_coef = kl_coef
- self.generate_kwargs = generate_kwargs
- # Need a trainer to give an actor and a critic via initialize_experience_maker(...)
- actor, critic, reward_model, initial_model = None, None, None, None
- self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef)
- self._model_visit_lock = Lock()
- self.fully_initialized = False
- if self.generate_kwargs.get('debug', False):
- print('[maker] Waiting for INIT')
-
- def _get_ready(self):
- while not self.fully_initialized:
- time.sleep(1.0)
-
- def update_target_trainer_list(self, detached_trainer_name_list):
- self.target_trainer_list = []
- for name in detached_trainer_name_list:
- self.target_trainer_list.append(ray.get_actor(name))
-
- # copy from ../trainer/base.py
- @ray.method(concurrency_group="compute")
- def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
- self._get_ready()
- if isinstance(inputs, Tensor):
- return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
- elif isinstance(inputs, dict):
- return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
- else:
- raise ValueError(f'Unsupported input type "{type(inputs)}"')
-
- @ray.method(concurrency_group="experience_io")
- def _send_experience(self, experience):
- '''
- NOTE: the load-balancing dispatch below is disabled; kept for reference only.
-
- # choose a trainer that has the least experience batch in its detached_replay_buffer
- chosen_trainer = None
- min_length = None
- if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
- print("[maker] choosing target trainer")
- while chosen_trainer is None:
- for target_trainer in self.target_trainer_list:
- try:
- temp_length = ray.get(target_trainer.buffer_get_length.remote(), timeout=0.1)
- if min_length is None:
- min_length = temp_length
- chosen_trainer = target_trainer
- else:
- if temp_length < min_length:
- min_length = temp_length
- chosen_trainer = target_trainer
- except GetTimeoutError:
- pass
-
- if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
- print(f"[maker] sending exp to {chosen_trainer}")
- chosen_trainer.buffer_append.remote(experience)
- '''
- # round-robin dispatch instead
- if not hasattr(self, "_target_idx"):
- self._target_idx = 0
- chosen_trainer = self.target_trainer_list[self._target_idx]
- if self.generate_kwargs.get('debug', False):
- print(f"[maker] sending exp to {chosen_trainer}")
- chosen_trainer.buffer_append.remote(experience)
- self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list)
-
- def workingloop(self, dataset, tokenizer: Optional[Callable[[Any], dict]] = None, times=5000 * 50000):
- self._get_ready()
- sampler = self.strategy.setup_sampler(dataset)
- for _ in range(times):
- rand_prompts = sampler.sample(self.experience_batch_size)
- if tokenizer is not None:
- inputs = tokenizer(rand_prompts)
- else:
- inputs = rand_prompts
- self._model_visit_lock.acquire()
- experience = self._make_experience(inputs=inputs)
- self._model_visit_lock.release()
- self._send_experience(experience=experience)
-
- @ray.method(concurrency_group="model_io")
- def initialize_experience_maker(self, init_actor: Actor, init_critic: Critic):
- '''
- Called by a trainer, exactly once, to provide the initial actor and critic.
- '''
- # TODO: reduce malloc
- if self.fully_initialized:
- return
- if self.generate_kwargs.get('debug', False):
- print('[maker] INIT')
- with torch.no_grad():
- with self.strategy.model_init_context():
- actor = init_actor
- critic = init_critic
- initial_model = deepcopy(actor)
- reward_model = RewardModel(deepcopy(critic.model),
- deepcopy(critic.value_head)).to(torch.cuda.current_device())
- if self.strategy_str != 'colossalai_gemini':
- actor.to(torch.float16).to(torch.cuda.current_device())
- critic.to(torch.float16).to(torch.cuda.current_device())
- initial_model.to(torch.float16).to(torch.cuda.current_device())
- reward_model.to(torch.float16).to(torch.cuda.current_device())
-
- self.experience_maker.actor = self.strategy.prepare(actor)
- self.experience_maker.critic = self.strategy.prepare(critic)
- self.experience_maker.initial_model = self.strategy.prepare(initial_model)
- self.experience_maker.reward_model = self.strategy.prepare(reward_model)
- self.fully_initialized = True
-
- @ray.method(concurrency_group="model_io")
- def update_experience_maker(self, new_actor: Actor, new_critic: Critic):
- '''
- Called by a trainer to push updated actor and critic weights.
- '''
- # TODO: reduce malloc
- self._model_visit_lock.acquire()
- with torch.no_grad():
- if self.generate_kwargs.get('debug', False):
- print("[maker] UPDATE ")
- if self.strategy_str != 'colossalai_gemini':
- new_actor.to(torch.float16).to(torch.cuda.current_device())
- new_critic.to(torch.float16).to(torch.cuda.current_device())
- self.experience_maker.actor = self.strategy.prepare(new_actor)
- self.experience_maker.critic = self.strategy.prepare(new_critic)
- self._model_visit_lock.release()
diff --git a/applications/Chat/coati/ray/src/pipeline_strategy.py b/applications/Chat/coati/ray/src/pipeline_strategy.py
deleted file mode 100644
index 1780839c62ee3477ce84cd412ccf7ac60ab57357..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/ray/src/pipeline_strategy.py
+++ /dev/null
@@ -1,105 +0,0 @@
-# WIP
-
-
-from coati.trainer.strategies import Strategy
-from coati.trainer.strategies import NaiveStrategy
-from coati.models.base import Actor, RewardModel, Critic
-
-import numpy as np
-import torch
-from torch._C._distributed_rpc import _is_current_rpc_agent_set
-
-import colossalai
-from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
-from colossalai.fx import ColoTracer
-from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
-from colossalai.pipeline.middleware.adaptor import get_fx_topology
-
-
-import os
-from functools import partial
-import random
-
-rpc_is_initialized = _is_current_rpc_agent_set
-
-class PipelineModel(torch.nn.Module):
- '''
- An Actor has two kinds of jobs, forward and generate, so it is
- better to pipeline only the inner model.
- '''
- def __init__(self,
- model: torch.nn.Module,
- stage_num: int,
- num_microbatches: int,
- data_kwargs = None,
- ):
- super().__init__()
- # create partition module
- def create_partition_module(pp_rank:int, stage_num: int, model, data_kwargs):
- model.eval()
- tracer = ColoTracer()
- meta_args = {k: v.to('meta') for k, v in data_kwargs.items()}
- graph = tracer.trace(root=model, meta_args=meta_args)
- gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
- annotated_model = balanced_split_pass(gm, stage_num)
- top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True)
- topo = get_fx_topology(top_module)
- for submodule in split_submodules:
- if isinstance(submodule, torch.fx.GraphModule):
- setattr(submodule, '_topo', topo)
- return split_submodules[pp_rank + 1]
-
- def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int):
- partition = create_partition_module(pp_rank, stage_num, model, data_kwargs)
- return partition
- self.inference_engine = OneFOneBPipelineEngine(
- partition_fn=partial(partition, model, data_kwargs),
- stage_num=stage_num,
- num_microbatches=num_microbatches,
- device='cuda',
- )
-
- def forward(self,
- **model_inputs):
- return self.inference_engine.forward_backward(**model_inputs, forward_only=True)
-
-
-
-class PPStrategy(NaiveStrategy):
- """
- Strategy for Pipeline inference (inference only!)
-
- master node only
- """
- def __init__(
- self,
- seed: int = 42
- ):
- self.seed = seed
- super().__init__()
-
-
- def setup_distributed(self) -> None:
- colossalai.launch_from_torch({}, seed=self.seed)
- ppg.set_global_info(rank = int(os.environ['RANK']),
- world_size=int(os.environ['WORLD_SIZE']),
- dp_degree=1,
- tp_degree=1,
- num_worker_threads=128,
- device="cuda")
-
- def model_init_context(self):
- return super().model_init_context()
-
- def setup_model(self, model: torch.nn.Module) -> torch.nn.Module:
- if isinstance(model, (Actor, RewardModel, Critic)):
- model.model = PipelineModel(model.model)
- return model
-
- def set_seed(self, seed: int) -> None:
- random.seed(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
-
diff --git a/applications/Chat/coati/ray/src/utils.py b/applications/Chat/coati/ray/src/utils.py
deleted file mode 100644
index c750879b6d187b90bef14c15ddf50cd9ca33f2a5..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/ray/src/utils.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import torch.distributed as dist
-from typing import Any, Callable, Dict, List, Optional
-from coati.models.bloom import BLOOMActor, BLOOMCritic
-from coati.models.gpt import GPTActor, GPTCritic
-from coati.models.opt import OPTActor, OPTCritic
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
-import torch
-import os
-
-def is_rank_0() -> bool:
- return not dist.is_initialized() or dist.get_rank() == 0
-
-
-def get_cuda_actor_critic_from_args(model: str, pretrained: str = None, lora_rank=0):
- if model == 'gpt2':
- actor = GPTActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
- critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
- elif model == 'bloom':
- actor = BLOOMActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
- critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
- elif model == 'opt':
- actor = OPTActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
- critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
- else:
- raise ValueError(f'Unsupported model "{model}"')
- return actor, critic
-
-
-def get_strategy_from_args(strategy: str):
- if strategy == 'naive':
- strategy_ = NaiveStrategy()
- elif strategy == 'ddp':
- strategy_ = DDPStrategy()
- elif strategy == 'colossalai_gemini':
- strategy_ = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
- elif strategy == 'colossalai_zero2':
- strategy_ = ColossalAIStrategy(stage=2, placement_policy='cuda')
- else:
- raise ValueError(f'Unsupported strategy "{strategy}"')
- return strategy_
-
-
-def set_dist_env(env_info: Dict[str, str]):
- os.environ["RANK"] = env_info['rank']
- os.environ["LOCAL_RANK"] = env_info['local_rank']
- os.environ["WORLD_SIZE"] = env_info['world_size']
- os.environ['MASTER_PORT'] = env_info['master_port']
- os.environ['MASTER_ADDR'] = env_info['master_addr']
diff --git a/applications/Chat/coati/ray/utils.py b/applications/Chat/coati/ray/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88140c0e036d323a50a6f5cf448e42483b66d8c
--- /dev/null
+++ b/applications/Chat/coati/ray/utils.py
@@ -0,0 +1,140 @@
+import os
+from collections import OrderedDict
+from typing import Any, Dict
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
+from coati.models.gpt import GPTRM, GPTActor, GPTCritic
+from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
+from coati.models.opt import OPTRM, OPTActor, OPTCritic
+from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
+from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer
+
+
+def is_rank_0() -> bool:
+ return not dist.is_initialized() or dist.get_rank() == 0
+
+
+def get_rank() -> int:
+ return dist.get_rank() if dist.is_initialized() else 0
+
+
+def get_world_size() -> int:
+ return dist.get_world_size() if dist.is_initialized() else 1
+
+
+def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
+ if model == "gpt2":
+ actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
+ elif model == "bloom":
+ actor = BLOOMActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
+ elif model == "opt":
+ actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
+ elif model == "llama":
+ actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
+ else:
+ raise ValueError(f'Unsupported actor model "{model}"')
+ return actor
+
+
+def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
+ if model == "gpt2":
+ critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
+ elif model == "bloom":
+ critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
+ elif model == "opt":
+ critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
+ elif model == "llama":
+ critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
+ else:
+ raise ValueError(f'Unsupported critic model "{model}"')
+ return critic
+
+
+def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
+ if model == "gpt2":
+ reward_model = GPTRM(pretrained=pretrained, config=config)
+ elif model == "bloom":
+ reward_model = BLOOMRM(pretrained=pretrained, config=config)
+ elif model == "opt":
+ reward_model = OPTRM(pretrained=pretrained, config=config)
+ elif model == "llama":
+ reward_model = LlamaRM(pretrained=pretrained, config=config)
+ else:
+ raise ValueError(f'Unsupported reward model "{model}"')
+ return reward_model
+
+
+def get_strategy_from_args(strategy: str):
+ if strategy == "ddp":
+ strategy_ = DDPStrategy()
+ elif strategy == "colossalai_gemini":
+ strategy_ = GeminiStrategy(placement_policy="static", initial_scale=2**5)
+ elif strategy == "colossalai_zero2":
+ strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
+ elif strategy == "colossalai_gemini_cpu":
+ strategy_ = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
+ elif strategy == "colossalai_zero2_cpu":
+ strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
+ else:
+ raise ValueError(f'Unsupported strategy "{strategy}"')
+ return strategy_
+
+
+def get_tokenizer_from_args(model: str, **kwargs):
+ if model == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+ elif model == "bloom":
+ tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
+ elif model == "opt":
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+ elif model == "llama":
+ pretrain_path = kwargs["pretrain"]
+ tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
+ else:
+ raise ValueError(f'Unsupported model "{model}"')
+
+ tokenizer.pad_token = tokenizer.eos_token
+ return tokenizer
+
+
+def set_dist_env(env_info: Dict[str, str]):
+ os.environ["RANK"] = env_info["rank"]
+ os.environ["LOCAL_RANK"] = env_info["local_rank"]
+ os.environ["WORLD_SIZE"] = env_info["world_size"]
+ os.environ["MASTER_PORT"] = env_info["master_port"]
+ os.environ["MASTER_ADDR"] = env_info["master_addr"]
+
+
+def get_model_numel(model: nn.Module) -> int:
+ numel = sum(p.numel() for p in model.parameters())
+ return numel
+
+
+def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: int, allow_idle_sender: bool) -> list:
+ target_receivers = []
+ if num_senders <= num_receivers or allow_idle_sender:
+ # a sender will send data to one or more receivers
+ # a receiver only has one sender
+ for i in range(num_receivers):
+ if i % num_senders == sender_idx:
+ target_receivers.append(i)
+ else:
+ # a sender will send data to one receiver
+ # a receiver may have more than one sender
+ target_receivers.append(sender_idx % num_receivers)
+ return target_receivers
+
+
+def state_dict_to(
+ state_dict: Dict[str, Any], dtype: torch.dtype = torch.float16, device: torch.device = torch.device("cpu")
+):
+ """
+ keep state_dict intact
+ """
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ new_state_dict[k] = v.to(dtype=dtype, device=device)
+ return new_state_dict
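
The routing helper get_receivers_per_sender added above assigns receivers round-robin when senders are scarce, and maps each surplus sender onto a single shared receiver otherwise. A minimal sanity check of that mapping (assuming the coati package from this diff is importable):

    from coati.ray.utils import get_receivers_per_sender

    # 2 senders, 4 receivers: sender i feeds every receiver j with j % 2 == i
    assert get_receivers_per_sender(0, 2, 4, allow_idle_sender=False) == [0, 2]
    assert get_receivers_per_sender(1, 2, 4, allow_idle_sender=False) == [1, 3]
    # 4 senders, 2 receivers: each sender targets one receiver; receivers share senders
    assert get_receivers_per_sender(2, 4, 2, allow_idle_sender=False) == [0]
    assert get_receivers_per_sender(3, 4, 2, allow_idle_sender=False) == [1]
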
diff --git a/applications/Chat/coati/replay_buffer/__init__.py b/applications/Chat/coati/replay_buffer/__init__.py
deleted file mode 100644
index 1ebf60382913ead7247197a6ae7b021ceb7e5d71..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/replay_buffer/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .base import ReplayBuffer
-from .naive import NaiveReplayBuffer
-
-__all__ = ['ReplayBuffer', 'NaiveReplayBuffer']
diff --git a/applications/Chat/coati/replay_buffer/base.py b/applications/Chat/coati/replay_buffer/base.py
deleted file mode 100644
index 4c3812461a10120c358d4ddbccdffde4eff7fdfa..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/replay_buffer/base.py
+++ /dev/null
@@ -1,43 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Any
-
-from coati.experience_maker.base import Experience
-
-
-class ReplayBuffer(ABC):
- """Replay buffer base class. It stores experience.
-
- Args:
- sample_batch_size (int): Batch size when sampling.
- limit (int, optional): Maximum number of stored experience samples; a value <= 0 means unlimited. Defaults to 0.
- """
-
- def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
- super().__init__()
- self.sample_batch_size = sample_batch_size
- # limit <= 0 means unlimited
- self.limit = limit
-
- @abstractmethod
- def append(self, experience: Experience) -> None:
- pass
-
- @abstractmethod
- def clear(self) -> None:
- pass
-
- @abstractmethod
- def sample(self) -> Experience:
- pass
-
- @abstractmethod
- def __len__(self) -> int:
- pass
-
- @abstractmethod
- def __getitem__(self, idx: int) -> Any:
- pass
-
- @abstractmethod
- def collate_fn(self, batch: Any) -> Experience:
- pass
diff --git a/applications/Chat/coati/replay_buffer/naive.py b/applications/Chat/coati/replay_buffer/naive.py
deleted file mode 100644
index 938f500643c96c7370314bf6e60d3e38d765bc12..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/replay_buffer/naive.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import random
-from typing import List
-
-import torch
-from coati.experience_maker.base import Experience
-
-from .base import ReplayBuffer
-from .utils import BufferItem, make_experience_batch, split_experience_batch
-
-
-class NaiveReplayBuffer(ReplayBuffer):
- """Naive replay buffer class. It stores experience.
-
- Args:
- sample_batch_size (int): Batch size when sampling.
- limit (int, optional): Maximum number of stored experience samples; a value <= 0 means unlimited. Defaults to 0.
- cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True.
- """
-
- def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None:
- super().__init__(sample_batch_size, limit)
- self.cpu_offload = cpu_offload
- self.target_device = torch.device(f'cuda:{torch.cuda.current_device()}')
- # TODO(ver217): add prefetch
- self.items: List[BufferItem] = []
-
- @torch.no_grad()
- def append(self, experience: Experience) -> None:
- if self.cpu_offload:
- experience.to_device(torch.device('cpu'))
- items = split_experience_batch(experience)
- self.items.extend(items)
- if self.limit > 0:
- samples_to_remove = len(self.items) - self.limit
- if samples_to_remove > 0:
- self.items = self.items[samples_to_remove:]
-
- def clear(self) -> None:
- self.items.clear()
-
- @torch.no_grad()
- def sample(self) -> Experience:
- items = random.sample(self.items, self.sample_batch_size)
- experience = make_experience_batch(items)
- if self.cpu_offload:
- experience.to_device(self.target_device)
- return experience
-
- def __len__(self) -> int:
- return len(self.items)
-
- def __getitem__(self, idx: int) -> BufferItem:
- return self.items[idx]
-
- def collate_fn(self, batch) -> Experience:
- experience = make_experience_batch(batch)
- return experience
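
The removed buffer enforced its size limit FIFO-style: after each append, the oldest items beyond the limit were dropped. The same rule in plain Python:

    items, limit = list(range(7)), 5
    overflow = len(items) - limit
    if overflow > 0:
        items = items[overflow:]   # keep only the most recent items
    print(items)  # [2, 3, 4, 5, 6]
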
diff --git a/applications/Chat/coati/replay_buffer/utils.py b/applications/Chat/coati/replay_buffer/utils.py
deleted file mode 100644
index 6ad0db2c3b609e0a3cfa4b1126550ebb44dec6a2..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/replay_buffer/utils.py
+++ /dev/null
@@ -1,73 +0,0 @@
-from dataclasses import dataclass
-from typing import List, Optional
-
-import torch
-import torch.nn.functional as F
-from coati.experience_maker.base import Experience
-
-
-@dataclass
-class BufferItem:
- """BufferItem is an item of experience data.
-
- Shapes of each tensor:
- sequences: (S)
- action_log_probs: (A)
- values: (1)
- reward: (1)
- advantages: (1)
- attention_mask: (S)
- action_mask: (A)
-
- "A" is the number of actions.
- """
- sequences: torch.Tensor
- action_log_probs: torch.Tensor
- values: torch.Tensor
- reward: torch.Tensor
- advantages: torch.Tensor
- attention_mask: Optional[torch.LongTensor]
- action_mask: Optional[torch.BoolTensor]
-
-
-def split_experience_batch(experience: Experience) -> List[BufferItem]:
- batch_size = experience.sequences.size(0)
- batch_kwargs = [{} for _ in range(batch_size)]
- keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
- for key in keys:
- value = getattr(experience, key)
- if isinstance(value, torch.Tensor):
- vals = torch.unbind(value)
- else:
- # None
- vals = [value for _ in range(batch_size)]
- assert batch_size == len(vals)
- for i, v in enumerate(vals):
- batch_kwargs[i][key] = v
- items = [BufferItem(**kwargs) for kwargs in batch_kwargs]
- return items
-
-
-def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
- assert side in ('left', 'right')
- max_len = max(seq.size(0) for seq in sequences)
- padded_sequences = []
- for seq in sequences:
- pad_len = max_len - seq.size(0)
- padding = (pad_len, 0) if side == 'left' else (0, pad_len)
- padded_sequences.append(F.pad(seq, padding))
- return torch.stack(padded_sequences, dim=0)
-
-
-def make_experience_batch(items: List[BufferItem]) -> Experience:
- kwargs = {}
- to_pad_keys = {'action_log_probs', 'action_mask'}
- keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
- for key in keys:
- vals = [getattr(item, key) for item in items]
- if key in to_pad_keys:
- batch_data = zero_pad_sequences(vals)
- else:
- batch_data = torch.stack(vals, dim=0)
- kwargs[key] = batch_data
- return Experience(**kwargs)
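
Before deletion, make_experience_batch left-padded the variable-length action tensors so they could be stacked into one batch. A standalone sketch of that left-padding, using only torch:

    import torch
    import torch.nn.functional as F

    def zero_pad_left(seqs):
        # pad shorter sequences with zeros on the left, then stack
        max_len = max(s.size(0) for s in seqs)
        return torch.stack([F.pad(s, (max_len - s.size(0), 0)) for s in seqs], dim=0)

    print(zero_pad_left([torch.tensor([1, 2, 3]), torch.tensor([7])]))
    # tensor([[1, 2, 3],
    #         [0, 0, 7]])
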
diff --git a/applications/Chat/coati/trainer/__init__.py b/applications/Chat/coati/trainer/__init__.py
index 525b57bf21d351e6522102a8d39ff5eb1aec5250..4be5d27f93b17c28cb947fec815f2ac2c138efc3 100644
--- a/applications/Chat/coati/trainer/__init__.py
+++ b/applications/Chat/coati/trainer/__init__.py
@@ -1,6 +1,6 @@
-from .base import Trainer
+from .base import OnPolicyTrainer, SLTrainer
from .ppo import PPOTrainer
from .rm import RewardModelTrainer
from .sft import SFTTrainer
-__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer']
+__all__ = ["SLTrainer", "OnPolicyTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer"]
diff --git a/applications/Chat/coati/trainer/base.py b/applications/Chat/coati/trainer/base.py
index ac3a878be88430f94af1bf19bcde4aaf1d0ec7f2..0a41d450d41ef52553f1d036fff8593271630a69 100644
--- a/applications/Chat/coati/trainer/base.py
+++ b/applications/Chat/coati/trainer/base.py
@@ -1,54 +1,106 @@
from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional, Union
+from contextlib import contextmanager
+from typing import List
-import torch
+import torch.nn as nn
+import tqdm
+from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience
+from torch.optim import Optimizer
from .callbacks import Callback
from .strategies import Strategy
+from .utils import is_rank_0
-class Trainer(ABC):
+class SLTrainer(ABC):
"""
- Base class for rlhf trainers.
+ Base class for supervised learning trainers.
Args:
strategy (Strategy): the strategy to use for training
max_epochs (int, defaults to 1): the number of training epochs
+ model (nn.Module): the model to train
+ optimizer (Optimizer): the optimizer to use for training
+ """
+
+ def __init__(
+ self,
+ strategy: Strategy,
+ max_epochs: int,
+ model: nn.Module,
+ optimizer: Optimizer,
+ ) -> None:
+ super().__init__()
+ self.strategy = strategy
+ self.max_epochs = max_epochs
+ self.model = model
+ self.optimizer = optimizer
+
+ @abstractmethod
+ def _train(self, epoch):
+ raise NotImplementedError()
+
+ @abstractmethod
+ def _eval(self, epoch):
+ raise NotImplementedError()
+
+ def _before_fit(self):
+ raise NotImplementedError()
+
+ def fit(self, *args, **kwargs):
+ self._before_fit(*args, **kwargs)
+ for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0()):
+ self._train(epoch)
+ self._eval(epoch)
+
+
+class OnPolicyTrainer(ABC):
+ """
+ Base class for on-policy RL trainers, e.g. PPO.
+
+ Args:
+ strategy (Strategy): the strategy to use for training
+ data_buffer (NaiveExperienceBuffer): the buffer to collect experiences
+ sample_buffer (bool, defaults to False): whether to sample from buffer
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
- generate_kwargs (dict, optional): the kwargs to use while model generating
"""
- def __init__(self,
- strategy: Strategy,
- max_epochs: int = 1,
- dataloader_pin_memory: bool = True,
- callbacks: List[Callback] = [],
- **generate_kwargs) -> None:
+ def __init__(
+ self,
+ strategy: Strategy,
+ data_buffer: NaiveExperienceBuffer,
+ sample_buffer: bool,
+ dataloader_pin_memory: bool,
+ callbacks: List[Callback] = [],
+ ) -> None:
super().__init__()
self.strategy = strategy
- self.max_epochs = max_epochs
- self.generate_kwargs = generate_kwargs
+ self.data_buffer = data_buffer
+ self.sample_buffer = sample_buffer
self.dataloader_pin_memory = dataloader_pin_memory
self.callbacks = callbacks
- # TODO(ver217): maybe simplify these code using context
- def _on_fit_start(self) -> None:
+ @contextmanager
+ def _fit_ctx(self) -> None:
for callback in self.callbacks:
callback.on_fit_start()
-
- def _on_fit_end(self) -> None:
- for callback in self.callbacks:
- callback.on_fit_end()
-
- def _on_episode_start(self, episode: int) -> None:
+ try:
+ yield
+ finally:
+ for callback in self.callbacks:
+ callback.on_fit_end()
+
+ @contextmanager
+ def _episode_ctx(self, episode: int) -> None:
for callback in self.callbacks:
callback.on_episode_start(episode)
-
- def _on_episode_end(self, episode: int) -> None:
- for callback in self.callbacks:
- callback.on_episode_end(episode)
+ try:
+ yield
+ finally:
+ for callback in self.callbacks:
+ callback.on_episode_end(episode)
def _on_make_experience_start(self) -> None:
for callback in self.callbacks:
@@ -70,6 +122,67 @@ class Trainer(ABC):
for callback in self.callbacks:
callback.on_learn_batch_start()
- def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
+ def _on_learn_batch_end(self, experience: Experience) -> None:
for callback in self.callbacks:
- callback.on_learn_batch_end(metrics, experience)
+ callback.on_learn_batch_end(experience)
+
+ @abstractmethod
+ def _make_experience(self, collect_step: int):
+ """
+ Implement this method to make experience.
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def _learn(self, update_step: int):
+ """
+ Implement this method to learn from experience, either
+ sample from buffer or transform buffer into dataloader.
+ """
+ raise NotImplementedError()
+
+ def _collect_phase(self, collect_step: int):
+ self._on_make_experience_start()
+ experience = self._make_experience(collect_step)
+ self._on_make_experience_end(experience)
+ self.data_buffer.append(experience)
+
+ def _update_phase(self, update_step: int):
+ self._on_learn_epoch_start(update_step)
+ self._learn(update_step)
+ self._on_learn_epoch_end(update_step)
+
+ def _before_fit(self, *args, **kwargs):
+ raise NotImplementedError()
+
+ def fit(
+ self,
+ num_episodes: int,
+ num_collect_steps: int,
+ num_update_steps: int,
+ *args,
+ **kwargs,
+ ):
+ """
+ The main training loop of on-policy rl trainers.
+
+ Args:
+ num_episodes (int): the number of episodes to train
+ num_collect_steps (int): the number of collect steps per episode
+ num_update_steps (int): the number of update steps per episode
+ """
+ self._before_fit(*args, **kwargs)
+ with self._fit_ctx():
+ for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()):
+ with self._episode_ctx(episode):
+ for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps", disable=not is_rank_0()):
+ self._collect_phase(collect_step)
+ if not self.sample_buffer:
+ # HACK(cwher): according to the design of the boost API, the dataloader should also be boosted,
+ # but it is impractical to adapt this pattern to RL training. Thus, the dataloader is left
+ # unboosted and only strategy.setup_dataloader() is called to set it up.
+ self.dataloader = self.strategy.setup_dataloader(self.data_buffer, self.dataloader_pin_memory)
+ for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()):
+ self._update_phase(update_step)
+ # NOTE: this is for on-policy algorithms
+ self.data_buffer.clear()
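
With this refactor, a concrete trainer only fills in the three hooks; fit() owns the episode loop, the callback contexts, and the buffer lifecycle. A hypothetical skeleton (PPOTrainer in ppo.py below is the real implementation):

    from coati.trainer.base import OnPolicyTrainer

    class MyTrainer(OnPolicyTrainer):
        def _before_fit(self, *args, **kwargs):
            # stash dataloaders/writers; receives fit()'s extra arguments
            ...

        def _make_experience(self, collect_step: int):
            # generate rollouts and return an Experience;
            # fit() appends it to self.data_buffer
            ...

        def _learn(self, update_step: int):
            # consume self.data_buffer (sample_buffer=True) or
            # self.dataloader (sample_buffer=False) and step the optimizers
            ...
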
diff --git a/applications/Chat/coati/trainer/callbacks/__init__.py b/applications/Chat/coati/trainer/callbacks/__init__.py
index 9ed0ee6f764002fd085f0ecdf957f2dba4d576c8..29c8c4f00a5c62f1888b82206c72a8efadfd28b3 100644
--- a/applications/Chat/coati/trainer/callbacks/__init__.py
+++ b/applications/Chat/coati/trainer/callbacks/__init__.py
@@ -2,4 +2,4 @@ from .base import Callback
from .performance_evaluator import PerformanceEvaluator
from .save_checkpoint import SaveCheckpoint
-__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint']
+__all__ = ["Callback", "PerformanceEvaluator", "SaveCheckpoint"]
diff --git a/applications/Chat/coati/trainer/callbacks/base.py b/applications/Chat/coati/trainer/callbacks/base.py
index f5616048855b26ceac836e812d1ecdee8035f025..c6e30f04885ccc86b42366f8eca5b65d8ed62868 100644
--- a/applications/Chat/coati/trainer/callbacks/base.py
+++ b/applications/Chat/coati/trainer/callbacks/base.py
@@ -5,7 +5,7 @@ from coati.experience_maker import Experience
class Callback(ABC):
"""
- Base callback class. It defines the interface for callbacks.
+ Base callback class. It defines the interface for callbacks.
"""
def on_fit_start(self) -> None:
@@ -35,5 +35,5 @@ class Callback(ABC):
def on_learn_batch_start(self) -> None:
pass
- def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
+ def on_learn_batch_end(self, experience: Experience) -> None:
pass
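
Since on_learn_batch_end no longer receives a metrics dict, custom callbacks derive their statistics from the Experience itself. A hypothetical example:

    from coati.experience_maker import Experience
    from coati.trainer.callbacks import Callback

    class RewardLogger(Callback):
        def on_learn_batch_end(self, experience: Experience) -> None:
            # experience.reward holds the per-sample rewards of the batch
            print(f"mean reward: {experience.reward.mean().item():.4f}")
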
diff --git a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py
index 925455444597095d67330698cffaca151f8c5e0b..b286c766c2637860a3f0d1fedb5abe9216aa5d6f 100644
--- a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py
+++ b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py
@@ -21,9 +21,9 @@ def print_rank_0(*args, **kwargs) -> None:
def divide(x: float, y: float) -> float:
if y == 0:
- return float('inf')
- elif y == float('inf'):
- return float('nan')
+ return float("inf")
+ elif y == float("inf"):
+ return float("nan")
return x / y
@@ -38,10 +38,9 @@ def all_reduce_mean(x: float, world_size: int) -> float:
class Timer:
-
def __init__(self) -> None:
self.start_time: Optional[float] = None
- self.duration: float = 0.
+ self.duration: float = 0.0
def start(self) -> None:
self.start_time = time()
@@ -52,7 +51,7 @@ class Timer:
self.start_time = None
def reset(self) -> None:
- self.duration = 0.
+ self.duration = 0.0
class PerformanceEvaluator(Callback):
@@ -67,13 +66,15 @@ class PerformanceEvaluator(Callback):
ignore_episodes: The number of episodes to ignore when calculating the performance.
"""
- def __init__(self,
- actor_num_params: int,
- critic_num_params: int,
- initial_model_num_params: int,
- reward_model_num_params: int,
- enable_grad_checkpoint: bool = False,
- ignore_episodes: int = 0) -> None:
+ def __init__(
+ self,
+ actor_num_params: int,
+ critic_num_params: int,
+ initial_model_num_params: int,
+ reward_model_num_params: int,
+ enable_grad_checkpoint: bool = False,
+ ignore_episodes: int = 0,
+ ) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
@@ -136,7 +137,7 @@ class PerformanceEvaluator(Callback):
return
self.learn_timer.start()
- def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
+ def on_learn_batch_end(self, experience: Experience) -> None:
if self.disable:
return
self.learn_timer.end()
@@ -155,8 +156,9 @@ class PerformanceEvaluator(Callback):
avg_learn_duration = all_reduce_mean(self.learn_timer.duration, self.world_size)
avg_overall_duration = all_reduce_mean(self.overall_timer.duration, self.world_size)
- avg_make_experience_throughput = self.make_experience_num_samples * \
- self.world_size / (avg_make_experience_duration + 1e-12)
+ avg_make_experience_throughput = (
+ self.make_experience_num_samples * self.world_size / (avg_make_experience_duration + 1e-12)
+ )
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
avg_learn_throughput = self.learn_num_samples * self.world_size / (avg_learn_duration + 1e-12)
@@ -171,13 +173,11 @@ class PerformanceEvaluator(Callback):
learn_time_per_sample = divide(avg_learn_duration, num_effective_samples)
print_rank_0(
- f'Performance summary:\n' +
- f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n'
- +
- f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n'
- + f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n' +
- f'Overall time per sample: {overall_time_per_sample:.2f} s\n' +
- f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n'
- +
- f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%'
+ f"Performance summary:\n"
+ + f"Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n"
+ + f"Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n"
+ + f"Overall throughput: {avg_overall_throughput:.2f} samples/s\n"
+ + f"Overall time per sample: {overall_time_per_sample:.2f} s\n"
+ + f"Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n"
+ + f"Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%"
)
diff --git a/applications/Chat/coati/trainer/callbacks/save_checkpoint.py b/applications/Chat/coati/trainer/callbacks/save_checkpoint.py
index d2dcc0dd4c65f0aae86a87b188cc2801fa0e5281..0d70b6c53073a5e14a57a18a38818416d9db6daa 100644
--- a/applications/Chat/coati/trainer/callbacks/save_checkpoint.py
+++ b/applications/Chat/coati/trainer/callbacks/save_checkpoint.py
@@ -1,7 +1,7 @@
import os
import torch.distributed as dist
-from coati.trainer.strategies import ColossalAIStrategy, Strategy
+from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy
from coati.trainer.utils import is_rank_0
from torch import nn
from torch.optim import Optimizer
@@ -36,40 +36,41 @@ class SaveCheckpoint(Callback):
"""
- def __init__(self,
- path: str,
- interval: int,
- strategy: Strategy,
- actor: nn.Module = None,
- critic: nn.Module = None,
- actor_optim: Optimizer = None,
- critic_optim: Optimizer = None) -> None:
+ def __init__(
+ self,
+ path: str,
+ interval: int,
+ strategy: Strategy,
+ actor: nn.Module = None,
+ critic: nn.Module = None,
+ actor_optim: Optimizer = None,
+ critic_optim: Optimizer = None,
+ ) -> None:
super().__init__()
- self.path = os.path.join(path, 'checkpoint')
+ self.path = os.path.join(path, "checkpoint")
self.interval = interval
self.strategy = strategy
- self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]}
+ self.model_dict = {"actor": [actor, actor_optim], "critic": [critic, critic_optim]}
def on_episode_end(self, episode: int) -> None:
if (episode + 1) % self.interval != 0:
return
- base_path = os.path.join(self.path, f'episode_{episode}')
+ base_path = os.path.join(self.path, f"episode_{episode}")
if not os.path.exists(base_path):
os.makedirs(base_path)
for model in self.model_dict.keys():
-
# save model
if self.model_dict[model][0] is None:
# saving only optimizer states is meaningless, so it would be skipped
continue
- model_path = os.path.join(base_path, f'{model}.pt')
+ model_path = os.path.join(base_path, f"{model}.pt")
self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True)
# save optimizer
if self.model_dict[model][1] is None:
continue
- only_rank0 = not isinstance(self.strategy, ColossalAIStrategy)
+ only_rank0 = not isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy))
rank = 0 if is_rank_0() else dist.get_rank()
- optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt')
+ optim_path = os.path.join(base_path, f"{model}-optim-rank-{rank}.pt")
self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0)
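
For reference, the layout SaveCheckpoint writes every interval episodes, with paths assembled exactly as in on_episode_end above (rank-suffixed optimizer files beyond rank 0 appear only for the sharded LowLevelZeroStrategy/GeminiStrategy cases):

    import os

    path, episode, rank = "output", 9, 0
    base_path = os.path.join(path, "checkpoint", f"episode_{episode}")
    print(os.path.join(base_path, "actor.pt"))                     # output/checkpoint/episode_9/actor.pt
    print(os.path.join(base_path, f"actor-optim-rank-{rank}.pt"))  # output/checkpoint/episode_9/actor-optim-rank-0.pt
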
diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py
index fe5ae48d9c2f05af14235a0106c43e74959e4c11..d6966689885ec2760247c211ef51dc6813e447fe 100644
--- a/applications/Chat/coati/trainer/ppo.py
+++ b/applications/Chat/coati/trainer/ppo.py
@@ -1,26 +1,38 @@
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Dict, List, Optional
-import torch
-import torch.nn as nn
+from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience, NaiveExperienceMaker
-from coati.models.base import Actor, Critic
+from coati.models.base import Actor, Critic, RewardModel, get_base_model
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
-from coati.replay_buffer import NaiveReplayBuffer
-from torch import Tensor
+from coati.models.utils import calc_action_log_probs
from torch.optim import Optimizer
-from torch.utils.data import DistributedSampler
+from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase
from colossalai.utils import get_current_device
-from .base import Trainer
+from .base import OnPolicyTrainer
from .callbacks import Callback
-from .strategies import Strategy
-from .utils import is_rank_0, to_device
+from .strategies import GeminiStrategy, Strategy
+from .utils import CycledDataLoader, is_rank_0, to_device
-class PPOTrainer(Trainer):
+def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
+ unwrapped_model = strategy.unwrap_model(actor)
+ hf_model = get_base_model(unwrapped_model)
+ new_kwargs = {**generate_kwargs}
+ # use huggingface models method directly
+ if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"):
+ new_kwargs["prepare_inputs_fn"] = hf_model.prepare_inputs_for_generation
+
+ if "update_model_kwargs_fn" not in generate_kwargs and hasattr(hf_model, "_update_model_kwargs_for_generation"):
+ new_kwargs["update_model_kwargs_fn"] = hf_model._update_model_kwargs_for_generation
+
+ return new_kwargs
+
+
+class PPOTrainer(OnPolicyTrainer):
"""
Trainer for PPO algorithm.
@@ -28,60 +40,61 @@ class PPOTrainer(Trainer):
strategy (Strategy): the strategy to use for training
actor (Actor): the actor model in ppo algorithm
critic (Critic): the critic model in ppo algorithm
- reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences
+ reward_model (RewardModel): the reward model in the RLHF algorithm, used to score generated sequences
initial_model (Actor): the initial model in the RLHF algorithm, used to generate reference logits that limit the update of the actor
actor_optim (Optimizer): the optimizer to use for actor model
critic_optim (Optimizer): the optimizer to use for critic model
+ tokenizer (PreTrainedTokenizerBase): the tokenizer used when building experiences
kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
train_batch_size (int, defaults to 8): the batch size to use for training
- buffer_limit (int, defaults to 0): the max_size limitation of replay buffer
- buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
+ buffer_limit (int, defaults to 0): the max_size limitation of buffer
+ buffer_cpu_offload (bool, defaults to True): whether to offload buffer to cpu
eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
vf_coef (float, defaults to 1.0): the coefficient of value loss
ptx_coef (float, defaults to 0.9): the coefficient of ptx loss
value_clip (float, defaults to 0.4): the clip coefficient of value loss
- max_epochs (int, defaults to 1): the number of epochs of training process
- sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
+ sample_buffer (bool, defaults to False): whether to sample from buffer
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
"""
- def __init__(self,
- strategy: Strategy,
- actor: Actor,
- critic: Critic,
- reward_model: nn.Module,
- initial_model: Actor,
- actor_optim: Optimizer,
- critic_optim: Optimizer,
- kl_coef: float = 0.1,
- ptx_coef: float = 0.9,
- train_batch_size: int = 8,
- buffer_limit: int = 0,
- buffer_cpu_offload: bool = True,
- eps_clip: float = 0.2,
- vf_coef: float = 1.0,
- value_clip: float = 0.4,
- max_epochs: int = 1,
- sample_replay_buffer: bool = False,
- dataloader_pin_memory: bool = True,
- offload_inference_models: bool = True,
- callbacks: List[Callback] = [],
- **generate_kwargs) -> None:
- experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
- replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
- generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
- super().__init__(strategy, max_epochs, dataloader_pin_memory, callbacks, **generate_kwargs)
-
- self.experience_maker = experience_maker
- self.replay_buffer = replay_buffer
- self.sample_replay_buffer = sample_replay_buffer
- self.offload_inference_models = offload_inference_models
+ def __init__(
+ self,
+ strategy: Strategy,
+ actor: Actor,
+ critic: Critic,
+ reward_model: RewardModel,
+ initial_model: Actor,
+ actor_optim: Optimizer,
+ critic_optim: Optimizer,
+ tokenizer: PreTrainedTokenizerBase,
+ kl_coef: float = 0.1,
+ ptx_coef: float = 0.9,
+ train_batch_size: int = 8,
+ buffer_limit: int = 0,
+ buffer_cpu_offload: bool = True,
+ eps_clip: float = 0.2,
+ vf_coef: float = 1.0,
+ value_clip: float = 0.4,
+ sample_buffer: bool = False,
+ dataloader_pin_memory: bool = True,
+ offload_inference_models: bool = True,
+ callbacks: List[Callback] = [],
+ **generate_kwargs,
+ ) -> None:
+ if isinstance(strategy, GeminiStrategy):
+ assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"
+
+ data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
+ super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks)
+
+ self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
+ self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer, kl_coef)
self.actor = actor
self.critic = critic
+ self.tokenizer = tokenizer
self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip)
@@ -91,123 +104,99 @@ class PPOTrainer(Trainer):
self.actor_optim = actor_optim
self.critic_optim = critic_optim
+ self.offload_inference_models = offload_inference_models
self.device = get_current_device()
- def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
- if isinstance(inputs, Tensor):
- return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
- elif isinstance(inputs, dict):
- return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
- else:
- raise ValueError(f'Unsupported input type "{type(inputs)}"')
-
- def _learn(self):
- # replay buffer may be empty at first, we should rebuild at each training
- if not self.sample_replay_buffer:
- dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
- if self.sample_replay_buffer:
- pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
- for _ in pbar:
- experience = self.replay_buffer.sample()
- experience.to_device(self.device)
- metrics = self.training_step(experience)
- pbar.set_postfix(metrics)
- else:
- for epoch in range(self.max_epochs):
- self._on_learn_epoch_start(epoch)
- if isinstance(dataloader.sampler, DistributedSampler):
- dataloader.sampler.set_epoch(epoch)
- pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
- for experience in pbar:
- self._on_learn_batch_start()
- experience.to_device(self.device)
- metrics = self.training_step(experience)
- self._on_learn_batch_end(metrics, experience)
- pbar.set_postfix(metrics)
- self._on_learn_epoch_end(epoch)
-
- def fit(self,
- prompt_dataloader,
- pretrain_dataloader,
- num_episodes: int = 50000,
- max_timesteps: int = 500,
- update_timesteps: int = 5000) -> None:
- time = 0
- self.pretrain_dataloader = pretrain_dataloader
- self.prompt_dataloader = prompt_dataloader
- self._on_fit_start()
- for episode in range(num_episodes):
- self._on_episode_start(episode)
- for timestep in tqdm(range(max_timesteps),
- desc=f'Episode [{episode+1}/{num_episodes}]',
- disable=not is_rank_0()):
- time += 1
- prompts = next(iter(self.prompt_dataloader))
- self._on_make_experience_start()
- if self.offload_inference_models:
- # TODO(ver217): this may be controlled by strategy if they are prepared by strategy
- self.experience_maker.initial_model.to(self.device)
- self.experience_maker.reward_model.to(self.device)
- experience = self._make_experience(prompts)
- self._on_make_experience_end(experience)
- self.replay_buffer.append(experience)
- if time % update_timesteps == 0:
- if self.offload_inference_models:
- self.experience_maker.initial_model.to('cpu')
- self.experience_maker.reward_model.to('cpu')
- self._learn()
- self.replay_buffer.clear()
- self._on_episode_end(episode)
- self._on_fit_end()
-
- def training_step(self, experience: Experience) -> Dict[str, float]:
+ def _before_fit(
+ self,
+ prompt_dataloader: DataLoader,
+ pretrain_dataloader: DataLoader,
+ log_dir: Optional[str] = None,
+ use_wandb: bool = False,
+ ):
+ """
+ Args:
+ prompt_dataloader (DataLoader): the dataloader to use for prompt data
+ pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
+ log_dir (str, optional): the directory to write TensorBoard logs to
+ use_wandb (bool, defaults to False): whether to also log the run with wandb
+ """
+ self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
+ self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
+
+ self.writer = None
+ if use_wandb and is_rank_0():
+ assert log_dir is not None, "log_dir must be provided when use_wandb is True"
+ import wandb
+
+ wandb.init(project="Coati-ppo", sync_tensorboard=True)
+ if log_dir is not None and is_rank_0():
+ import os
+ import time
+
+ from torch.utils.tensorboard import SummaryWriter
+
+ log_dir = os.path.join(log_dir, "ppo")
+ log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
+ self.writer = SummaryWriter(log_dir=log_dir)
+
+ def _make_experience(self, collect_step: int) -> Experience:
+ prompts = self.prompt_dataloader.next()
+ if self.offload_inference_models:
+ # TODO(ver217): this may be controlled by strategy if they are prepared by strategy
+ self.experience_maker.initial_model.to(self.device)
+ self.experience_maker.reward_model.to(self.device)
+ assert isinstance(prompts, dict), f'Unsupported input type "{type(prompts)}"'
+ return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)
+
+ def _training_step(self, experience: Experience):
self.actor.train()
self.critic.train()
# policy loss
- num_actions = experience.action_mask.size(1)
- action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
- actor_loss = self.actor_loss_fn(action_log_probs,
- experience.action_log_probs,
- experience.advantages,
- action_mask=experience.action_mask)
+ num_actions = experience.action_log_probs.size(1)
+ actor_logits = self.actor(experience.sequences, experience.attention_mask)["logits"]
+ action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
+ actor_loss = self.actor_loss_fn(
+ action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
+ )
+ actor_loss = (1 - self.ptx_coef) * actor_loss
+ self.strategy.backward(actor_loss, self.actor, self.actor_optim)
# ptx loss
if self.ptx_coef != 0:
- batch = next(iter(self.pretrain_dataloader))
+ batch = self.pretrain_dataloader.next()
batch = to_device(batch, self.device)
- ptx_log_probs = self.actor.get_base_model()(batch['input_ids'],
- attention_mask=batch['attention_mask'])['logits']
- ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels'])
- actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
+ ptx_log_probs = self.actor(batch["input_ids"], batch["attention_mask"])["logits"]
+ ptx_loss = self.ptx_coef * self.ptx_loss_fn(ptx_log_probs, batch["labels"])
+ self.strategy.backward(ptx_loss, self.actor, self.actor_optim)
- self.strategy.backward(actor_loss, self.actor, self.actor_optim)
self.strategy.optimizer_step(self.actor_optim)
self.actor_optim.zero_grad()
# value loss
- values = self.critic(experience.sequences,
- action_mask=experience.action_mask,
- attention_mask=experience.attention_mask)
- critic_loss = self.critic_loss_fn(values,
- experience.values,
- experience.reward,
- action_mask=experience.action_mask)
+ values = self.critic(experience.sequences, attention_mask=experience.attention_mask)
+ critic_loss = self.critic_loss_fn(values, experience.values, experience.reward)
critic_loss = critic_loss * self.vf_coef
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad()
- return {'reward': experience.reward.mean().item()}
-
-
-def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
- origin_model = strategy.unwrap_model(actor)
- new_kwargs = {**generate_kwargs}
- # use huggingface models method directly
- if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
- new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
-
- if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'):
- new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation
-
- return new_kwargs
+ def _learn(self, update_step: int):
+ if self.offload_inference_models:
+ self.experience_maker.initial_model.to("cpu")
+ self.experience_maker.reward_model.to("cpu")
+
+ # the buffer may be empty at first; the dataloader is rebuilt after each collect phase (see fit)
+ if self.sample_buffer:
+ experience = self.data_buffer.sample()
+ self._on_learn_batch_start()
+ experience.to_device(self.device)
+ self._training_step(experience)
+ self._on_learn_batch_end(experience)
+ else:
+ if isinstance(self.dataloader.sampler, DistributedSampler):
+ self.dataloader.sampler.set_epoch(update_step)
+ pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0())
+ for experience in pbar:
+ self._on_learn_batch_start()
+ experience.to_device(self.device)
+ self._training_step(experience)
+ self._on_learn_batch_end(experience)
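
CycledDataLoader is imported from .utils but defined outside this diff; PPOTrainer only relies on its next() handing out a fresh batch forever. A minimal sketch of that behavior (an assumption, not the actual implementation):

    from torch.utils.data import DataLoader

    class CycledDataLoaderSketch:
        def __init__(self, dataloader: DataLoader) -> None:
            self.dataloader = dataloader
            self._iter = iter(dataloader)

        def next(self):
            # restart the epoch transparently when the loader is exhausted
            try:
                return next(self._iter)
            except StopIteration:
                self._iter = iter(self.dataloader)
                return next(self._iter)
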
diff --git a/applications/Chat/coati/trainer/rm.py b/applications/Chat/coati/trainer/rm.py
index cdae5108ab00e3506449d8c837f62b2d594d8d31..d7f8c21a5a3d2264ce553b09139f73d761c204e4 100644
--- a/applications/Chat/coati/trainer/rm.py
+++ b/applications/Chat/coati/trainer/rm.py
@@ -1,35 +1,27 @@
-from datetime import datetime
-from typing import List, Optional
+from typing import Callable, Optional
-import pandas as pd
import torch
-import torch.distributed as dist
-from torch.optim import Optimizer, lr_scheduler
-from torch.utils.data import DataLoader, Dataset, DistributedSampler
-from tqdm import tqdm
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase
-
-from .base import Trainer
-from .callbacks import Callback
+import tqdm
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader
+
+from .base import SLTrainer
from .strategies import Strategy
from .utils import is_rank_0
-class RewardModelTrainer(Trainer):
+class RewardModelTrainer(SLTrainer):
"""
Trainer to use while training reward model.
Args:
model (torch.nn.Module): the model to train
strategy (Strategy): the strategy to use for training
- optim(Optimizer): the optimizer to use for training
+ optim (Optimizer): the optimizer to use for training
+ lr_scheduler (_LRScheduler): the lr scheduler to use for training
loss_fn (callable): the loss function to use for training
- train_dataloader (DataLoader): the dataloader to use for training
- valid_dataloader (DataLoader): the dataloader to use for validation
- eval_dataloader (DataLoader): the dataloader to use for evaluation
- batch_size (int, defaults to 1): the batch size while training
max_epochs (int, defaults to 1): the number of epochs to train
- callbacks (List[Callback], defaults to []): the callbacks to call during training process
"""
def __init__(
@@ -37,87 +29,95 @@ class RewardModelTrainer(Trainer):
model,
strategy: Strategy,
optim: Optimizer,
- loss_fn,
- train_dataloader: DataLoader,
- valid_dataloader: DataLoader,
- eval_dataloader: DataLoader,
+ lr_scheduler: _LRScheduler,
+ loss_fn: Callable,
max_epochs: int = 1,
- callbacks: List[Callback] = [],
) -> None:
- super().__init__(strategy, max_epochs, callbacks=callbacks)
+ super().__init__(strategy, max_epochs, model, optim)
+
+ self.loss_fn = loss_fn
+ self.scheduler = lr_scheduler
+
+ self.num_train_step = 0
+
+ def _eval(self, epoch):
+ if self.eval_dataloader is not None:
+ self.model.eval()
+ dist, num_correct, num_samples = 0, 0, 0
+ with torch.no_grad():
+ for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
+ chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
+ c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
+ reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
+ r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
+ chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
+ reject_reward = self.model(reject_ids, attention_mask=r_mask)
+ num_samples += chosen_ids.size(0)
+ num_correct += (chosen_reward > reject_reward).sum().item()
+ dist += (chosen_reward - reject_reward).mean().item()
+ self.dist = dist / len(self.eval_dataloader)
+ self.acc = num_correct / num_samples
+
+ if self.writer:
+ self.writer.add_scalar("eval/dist", self.dist, epoch)
+ self.writer.add_scalar("eval/acc", self.acc, epoch)
+
+ def _train(self, epoch):
+ self.model.train()
+ step_bar = tqdm.trange(
+ len(self.train_dataloader), desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0()
+ )
+ for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
+ chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
+ c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
+ reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
+ r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
+ chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
+ reject_reward = self.model(reject_ids, attention_mask=r_mask)
+ loss = self.loss_fn(chosen_reward, reject_reward)
+ self.strategy.backward(loss, self.model, self.optimizer)
+ self.strategy.optimizer_step(self.optimizer)
+ self.optimizer.zero_grad()
+ if self.writer:
+ self.writer.add_scalar("train/loss", loss.item(), self.num_train_step)
+ self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
+ self.writer.add_scalar("train/dist", (chosen_reward - reject_reward).mean().item(), self.num_train_step)
+ self.writer.add_scalar(
+ "train/acc", (chosen_reward > reject_reward).float().mean().item(), self.num_train_step
+ )
+ self.num_train_step += 1
+ if self.num_train_step % 100 == 0:
+ self.scheduler.step()
+ step_bar.update()
+ step_bar.close()
+ def _before_fit(
+ self,
+ train_dataloader: DataLoader,
+ eval_dataloader: DataLoader,
+ log_dir: Optional[str] = None,
+ use_wandb: bool = False,
+ ):
+ """
+ Args:
+ train_dataloader (DataLoader): the dataloader to use for training
+ eval_dataloader (DataLoader): the dataloader to use for evaluation
+ """
self.train_dataloader = train_dataloader
- self.valid_dataloader = valid_dataloader
self.eval_dataloader = eval_dataloader
- self.model = model
- self.loss_fn = loss_fn
- self.optimizer = optim
- self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__() // 100)
-
- def eval_acc(self, dataloader):
- dist = 0
- on = 0
- cnt = 0
- self.model.eval()
- with torch.no_grad():
- for chosen_ids, c_mask, reject_ids, r_mask in dataloader:
- chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
- c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
- reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
- r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
- chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
- reject_reward = self.model(reject_ids, attention_mask=r_mask)
- for i in range(len(chosen_reward)):
- cnt += 1
- if chosen_reward[i] > reject_reward[i]:
- on += 1
- dist += (chosen_reward - reject_reward).mean().item()
- dist_mean = dist / len(dataloader)
- acc = on / cnt
- self.model.train()
- return dist_mean, acc
-
- def fit(self):
- time = datetime.now()
- epoch_bar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
- for epoch in range(self.max_epochs):
- step_bar = tqdm(range(self.train_dataloader.__len__()),
- desc='Train step of epoch %d' % epoch,
- disable=not is_rank_0())
- # train
- self.model.train()
- cnt = 0
- acc = 0
- dist = 0
- for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
- chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
- c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
- reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
- r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
- chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
- reject_reward = self.model(reject_ids, attention_mask=r_mask)
- loss = self.loss_fn(chosen_reward, reject_reward)
- self.strategy.backward(loss, self.model, self.optimizer)
- self.strategy.optimizer_step(self.optimizer)
- self.optimizer.zero_grad()
- cnt += 1
- if cnt == 100:
- self.scheduler.step()
- dist, acc = self.eval_acc(self.valid_dataloader)
- cnt = 0
- if is_rank_0():
- log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]],
- columns=['step', 'loss', 'dist', 'acc'])
- log.to_csv('log_%s.csv' % time, mode='a', header=False, index=False)
- step_bar.update()
- step_bar.set_postfix({'dist': dist, 'acc': acc})
-
- # eval
- dist, acc = self.eval_acc(self.eval_dataloader)
- if is_rank_0():
- log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc'])
- log.to_csv('log.csv', mode='a', header=False, index=False)
- epoch_bar.update()
- step_bar.set_postfix({'dist': dist, 'acc': acc})
- step_bar.close()
+ self.writer = None
+ if use_wandb and is_rank_0():
+ assert log_dir is not None, "log_dir must be provided when use_wandb is True"
+ import wandb
+
+ wandb.init(project="Coati-rm", sync_tensorboard=True)
+ if log_dir is not None and is_rank_0():
+ import os
+ import time
+
+ from torch.utils.tensorboard import SummaryWriter
+
+ log_dir = os.path.join(log_dir, "rm")
+ log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
+ self.writer = SummaryWriter(log_dir=log_dir)
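
The refactored `_train`/`_eval` above treat reward modeling as pairwise ranking: the model scores a chosen and a rejected completion, `loss_fn` turns the score gap into a loss, and `_eval` reports the win rate (`acc`) and the mean score gap (`dist`). For reference, a minimal sketch of the log-sigmoid ranking loss commonly used for this objective; it is an illustrative stand-in, not necessarily the exact `loss_fn` passed in:

```python
import torch
import torch.nn.functional as F


def pairwise_ranking_loss(chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
    # Bradley-Terry style objective: maximize the probability that the chosen
    # completion outranks the rejected one, i.e. sigmoid(r_chosen - r_reject).
    return -F.logsigmoid(chosen_reward - reject_reward).mean()


# Toy batch of 4 preference pairs: widening the score gap drives the loss toward 0.
chosen = torch.tensor([2.0, 1.5, 0.3, 1.0])
rejected = torch.tensor([0.5, 1.0, 0.2, -1.0])
print(pairwise_ranking_loss(chosen, rejected))
```
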
diff --git a/applications/Chat/coati/trainer/sft.py b/applications/Chat/coati/trainer/sft.py
index 63fde53956ccd387296cdfde6068f93d76d4fd3f..7d0eeec897e573daa2324a3f63f1d83988946425 100644
--- a/applications/Chat/coati/trainer/sft.py
+++ b/applications/Chat/coati/trainer/sft.py
@@ -1,23 +1,20 @@
-import math
-import time
-from typing import List, Optional
+from typing import Optional
import torch
import torch.distributed as dist
-import wandb
+import tqdm
from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
-from tqdm import tqdm
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase
-from transformers.trainer import get_scheduler
-from .base import Trainer
-from .callbacks import Callback
-from .strategies import ColossalAIStrategy, Strategy
+from colossalai.logging import DistributedLogger
+
+from .base import SLTrainer
+from .strategies import GeminiStrategy, Strategy
from .utils import is_rank_0, to_device
-class SFTTrainer(Trainer):
+class SFTTrainer(SLTrainer):
"""
Trainer to use while training an SFT model.
@@ -25,12 +22,9 @@ class SFTTrainer(Trainer):
model (torch.nn.Module): the model to train
strategy (Strategy): the strategy to use for training
optim (Optimizer): the optimizer to use for training
-        train_dataloader: the dataloader to use for training
-        eval_dataloader: the dataloader to use for evaluation
-        batch_size (int, defaults to 1): the batch size while training
+        lr_scheduler (_LRScheduler): the lr scheduler to use for training
max_epochs (int, defaults to 2): the number of epochs to train
-        callbacks (List[Callback], defaults to []): the callbacks to call during training process
-        optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
+        accumulation_steps (int, defaults to 8): the number of micro-batches over which to accumulate gradients before each optimizer step
"""
def __init__(
@@ -38,98 +32,99 @@ class SFTTrainer(Trainer):
model,
strategy: Strategy,
optim: Optimizer,
- train_dataloader: DataLoader,
- eval_dataloader: DataLoader = None,
+ lr_scheduler: _LRScheduler,
max_epochs: int = 2,
accumulation_steps: int = 8,
- callbacks: List[Callback] = [],
) -> None:
- if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3:
- raise ValueError("Accumulation steps are not supported in stage 3 of ColossalAI")
- super().__init__(strategy, max_epochs, callbacks=callbacks)
+ if accumulation_steps > 1:
+ assert not isinstance(
+ strategy, GeminiStrategy
+ ), "Accumulation steps are not supported in stage 3 of ColossalAI"
+
+ super().__init__(strategy, max_epochs, model, optim)
+
+ self.accumulation_steps = accumulation_steps
+ self.scheduler = lr_scheduler
+
+ self.num_train_step = 0
+ self.num_eval_step = 0
+
+ def _train(self, epoch: int):
+ self.model.train()
+ step_bar = tqdm.trange(
+ len(self.train_dataloader) // self.accumulation_steps,
+ desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+ disable=not is_rank_0(),
+ )
+ for i, batch in enumerate(self.train_dataloader):
+ batch = to_device(batch, torch.cuda.current_device())
+ outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
+ loss = outputs.loss / self.accumulation_steps
+ self.total_loss += loss.item()
+ self.strategy.backward(loss, self.model, self.optimizer)
+ # gradient accumulation
+ if (i + 1) % self.accumulation_steps == 0:
+ self.strategy.optimizer_step(self.optimizer)
+ self.optimizer.zero_grad()
+ self.scheduler.step()
+ if self.writer:
+ self.writer.add_scalar("train/loss", self.total_loss, self.num_train_step)
+ self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
+ self.num_train_step += 1
+ self.total_loss = 0
+ step_bar.update()
+ step_bar.close()
+
+ def _eval(self, epoch: int):
+ if self.eval_dataloader is not None:
+ self.model.eval()
+ with torch.no_grad():
+ loss_sum, num_seen = 0, 0
+ for batch in self.eval_dataloader:
+ batch = to_device(batch, torch.cuda.current_device())
+ outputs = self.model(
+ batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]
+ )
+ loss_sum += outputs.loss.item()
+ num_seen += batch["input_ids"].size(0)
+ loss_mean = loss_sum / num_seen
+ if dist.get_rank() == 0:
+ self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}")
+ if self.writer:
+ self.writer.add_scalar("eval/loss", loss_mean, self.num_eval_step)
+ self.num_eval_step += 1
+
+ def _before_fit(
+ self,
+ train_dataloader: DataLoader,
+ eval_dataloader: Optional[DataLoader] = None,
+ logger: Optional[DistributedLogger] = None,
+ log_dir: Optional[str] = None,
+ use_wandb: bool = False,
+ ):
+ """
+ Args:
+ train_dataloader: the dataloader to use for training
+ eval_dataloader: the dataloader to use for evaluation
+ """
self.train_dataloader = train_dataloader
self.eval_dataloader = eval_dataloader
- self.model = model
- self.optimizer = optim
- self.accumulation_steps = accumulation_steps
- num_update_steps_per_epoch = len(train_dataloader) // self.accumulation_steps
- max_steps = math.ceil(self.max_epochs * num_update_steps_per_epoch)
-
- self.scheduler = get_scheduler("cosine",
- self.optimizer,
- num_warmup_steps=math.ceil(max_steps * 0.03),
- num_training_steps=max_steps)
-
- def fit(self, logger, use_wandb: bool = False):
- if use_wandb:
- wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
- wandb.watch(self.model)
- total_loss = 0
- # epoch_bar = tqdm(range(self.epochs), desc='Epochs', disable=not is_rank_0())
- step_bar = tqdm(range(len(self.train_dataloader) // self.accumulation_steps * self.max_epochs),
- desc=f'steps',
- disable=not is_rank_0())
- for epoch in range(self.max_epochs):
-
- # process_bar = tqdm(range(len(self.train_dataloader)), desc=f'Train process for{epoch}', disable=not is_rank_0())
- # train
- self.model.train()
- for batch_id, batch in enumerate(self.train_dataloader):
-
- batch = to_device(batch, torch.cuda.current_device())
- outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
-
- loss = outputs.loss
-
- if loss >= 2.5 and is_rank_0():
- logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}")
-
- loss = loss / self.accumulation_steps
-
- self.strategy.backward(loss, self.model, self.optimizer)
-
- total_loss += loss.item()
-
- # gradient accumulation
- if (batch_id + 1) % self.accumulation_steps == 0:
- self.strategy.optimizer_step(self.optimizer)
- self.optimizer.zero_grad()
- self.scheduler.step()
- if is_rank_0() and use_wandb:
- wandb.log({
- "loss": total_loss / self.accumulation_steps,
- "lr": self.scheduler.get_last_lr()[0],
- "epoch": epoch,
- "batch_id": batch_id
- })
- total_loss = 0
- step_bar.update()
-
- # if batch_id % log_interval == 0:
- # logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')
- # wandb.log({"loss": loss.item()})
-
- # process_bar.update()
-
- # eval
- if self.eval_dataloader is not None:
- self.model.eval()
- with torch.no_grad():
- loss_sum = 0
- num_seen = 0
- for batch in self.eval_dataloader:
- batch = to_device(batch, torch.cuda.current_device())
- outputs = self.model(batch["input_ids"],
- attention_mask=batch["attention_mask"],
- labels=batch["labels"])
- loss = outputs.loss
-
- loss_sum += loss.item()
- num_seen += batch["input_ids"].size(0)
-
- loss_mean = loss_sum / num_seen
- if dist.get_rank() == 0:
- logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')
-
- # epoch_bar.update()
+ self.logger = logger
+ self.writer = None
+ if use_wandb and is_rank_0():
+ assert log_dir is not None, "log_dir must be provided when use_wandb is True"
+ import wandb
+
+ wandb.init(project="Coati-sft", sync_tensorboard=True)
+ if log_dir is not None and is_rank_0():
+ import os
+ import time
+
+ from torch.utils.tensorboard import SummaryWriter
+
+ log_dir = os.path.join(log_dir, "sft")
+ log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
+ self.writer = SummaryWriter(log_dir=log_dir)
+
+ self.total_loss = 0
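
`_train` above scales every micro-batch loss by `1 / accumulation_steps` and only steps the optimizer (and the LR scheduler) once per `accumulation_steps` micro-batches. A framework-free sketch of the same pattern in plain PyTorch, with illustrative names and an HF-style `outputs.loss`:

```python
def train_epoch(model, dataloader, optimizer, scheduler, accumulation_steps: int = 8):
    model.train()
    optimizer.zero_grad()
    for i, batch in enumerate(dataloader):
        loss = model(**batch).loss / accumulation_steps  # scale so summed grads average out
        loss.backward()                                  # gradients accumulate in param.grad
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()        # one real update per accumulation window
            optimizer.zero_grad()
            scheduler.step()        # the schedule advances per update, not per micro-batch
```

Note that the assertion in `__init__` above rejects `GeminiStrategy` for exactly this pattern, since ZeRO stage 3 does not support it.
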
diff --git a/applications/Chat/coati/trainer/strategies/__init__.py b/applications/Chat/coati/trainer/strategies/__init__.py
index f258c9b8a87324d28d03a22af94f485dcb416c24..521dcb5855b1a645e107ebecfea1ad093b08fb34 100644
--- a/applications/Chat/coati/trainer/strategies/__init__.py
+++ b/applications/Chat/coati/trainer/strategies/__init__.py
@@ -1,6 +1,5 @@
from .base import Strategy
-from .colossalai import ColossalAIStrategy
+from .colossalai import GeminiStrategy, LowLevelZeroStrategy
from .ddp import DDPStrategy
-from .naive import NaiveStrategy
-__all__ = ['Strategy', 'NaiveStrategy', 'DDPStrategy', 'ColossalAIStrategy']
+__all__ = ["Strategy", "DDPStrategy", "LowLevelZeroStrategy", "GeminiStrategy"]
diff --git a/applications/Chat/coati/trainer/strategies/base.py b/applications/Chat/coati/trainer/strategies/base.py
index b1452869179ebffdbbb05a42e64777ba3e20ca83..a78716216ae02a9b58ad3f314e45e1aeb432b72b 100644
--- a/applications/Chat/coati/trainer/strategies/base.py
+++ b/applications/Chat/coati/trainer/strategies/base.py
@@ -1,63 +1,63 @@
from abc import ABC, abstractmethod
from contextlib import nullcontext
-from typing import Any, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
-from coati.models.base import Actor, get_base_model
-from coati.replay_buffer import ReplayBuffer
+from coati.experience_buffer import ExperienceBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from colossalai.booster import Booster
+from colossalai.booster.plugin import Plugin
+
from .sampler import DistributedSampler
-ModelOptimPair = Tuple[nn.Module, Optimizer]
-ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair]
+_BoostArgSpec = Union[nn.Module, Tuple[nn.Module, Optimizer], Dict]
class Strategy(ABC):
"""
- Base class for training strategies.
+ Base class for training strategies.
"""
- def __init__(self) -> None:
+ def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None:
super().__init__()
+ # NOTE: dist must be initialized before Booster
self.setup_distributed()
+ self.plugin = plugin_initializer()
+ self.booster = Booster(plugin=self.plugin)
+ self._post_init()
@abstractmethod
- def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
+ def _post_init(self) -> None:
pass
- @abstractmethod
+ def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
+ self.booster.backward(loss, optimizer)
+
def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
- pass
+ optimizer.step()
@abstractmethod
def setup_distributed(self) -> None:
pass
@abstractmethod
- def setup_model(self, model: nn.Module) -> nn.Module:
- pass
-
- @abstractmethod
- def setup_optimizer(self, optimizer: Optimizer, model: nn.Module) -> Optimizer:
- pass
-
- @abstractmethod
- def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
+ def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
pass
def model_init_context(self):
return nullcontext()
- def prepare(
- self, *models_or_model_optim_pairs: ModelOrModelOptimPair
- ) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]:
- """Prepare models or model-optimizer-pairs based on each strategy.
+ def prepare(self, *boost_args: _BoostArgSpec) -> Union[List[_BoostArgSpec], _BoostArgSpec]:
+ """Prepare [model | (model, optimizer) | Dict] based on each strategy.
+ NOTE: the keys of Dict must be a subset of `self.booster.boost`'s arguments.
Example::
+ >>> # e.g., include lr_scheduler
+ >>> result_dict = strategy.prepare(dict(model=model, lr_scheduler=lr_scheduler))
>>> # when fine-tuning actor and critic
>>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
>>> # or when training reward model
@@ -66,67 +66,72 @@ class Strategy(ABC):
>>> actor, critic = strategy.prepare(actor, critic)
Returns:
- Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
+ Union[List[_BoostArgSpec], _BoostArgSpec]: [model | (model, optimizer) | Dict] in the original order.
"""
- def prepare_model(model: nn.Module):
- if isinstance(model, Actor):
- return Actor(self.setup_model(model.get_base_model()))
- return self.setup_model(model)
-
rets = []
- for arg in models_or_model_optim_pairs:
- if isinstance(arg, tuple):
- assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
- model, optimizer = arg
- model = prepare_model(model)
- optimizer = self.setup_optimizer(optimizer, get_base_model(model))
+ for arg in boost_args:
+ if isinstance(arg, nn.Module):
+ model, *_ = self.booster.boost(arg)
+ rets.append(model)
+ elif isinstance(arg, tuple):
+ try:
+ model, optimizer = arg
+ except ValueError:
+ raise RuntimeError(f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"')
+ model, optimizer, *_ = self.booster.boost(model=model, optimizer=optimizer)
rets.append((model, optimizer))
- elif isinstance(arg, nn.Module):
- rets.append(prepare_model(arg))
+ elif isinstance(arg, Dict):
+ model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
+ boost_result = dict(
+ model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ dataloader=dataloader,
+ lr_scheduler=lr_scheduler,
+ )
+ # remove None values
+ boost_result = {key: value for key, value in boost_result.items() if value is not None}
+ rets.append(boost_result)
else:
- raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')
+ raise RuntimeError(f"Type {type(arg)} is not supported")
- if len(rets) == 1:
- return rets[0]
- return rets
+ return rets[0] if len(rets) == 1 else rets
@staticmethod
def unwrap_model(model: nn.Module) -> nn.Module:
- """Get the unwrapped model from a wrapped model. Useful for getting original huggingface model.
- For Actor, it will unwrap `actor.model`.
+ """Get the unwrapped model from a wrapped model made by Strategy.prepare.
Args:
model (nn.Module): the model to unwrap
Returns:
- nn.Module: the original model (usually a huggingface model)
+ nn.Module: the original model
"""
- return get_base_model(model)
+ return model
- @abstractmethod
- def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
- pass
+ def save_model(self, model: nn.Module, path: str, shard: bool = False, **kwargs) -> None:
+ self.booster.save_model(model, path, shard=shard, **kwargs)
- @abstractmethod
- def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
- pass
+ def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
+ self.booster.load_model(model, path, strict)
- @abstractmethod
- def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
- pass
+ def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False, **kwargs) -> None:
+ self.booster.save_optimizer(optimizer, path, shard=not only_rank0, **kwargs)
- @abstractmethod
- def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
- pass
+ def load_optimizer(self, optimizer: Optimizer, path: str) -> None:
+ self.booster.load_optimizer(optimizer, path)
def setup_sampler(self, dataset) -> DistributedSampler:
+ # FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
return DistributedSampler(dataset, 1, 0)
@abstractmethod
- def save_pretrained(self,
- model: nn.Module,
- path: str,
- only_rank0: bool = True,
- tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+ def save_pretrained(
+ self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
+ ) -> None:
+ pass
+
+ @abstractmethod
+ def get_model_state_dict_shard(self, model: nn.Module, **config):
pass
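
`prepare` now funnels everything through `Booster.boost`, dispatching on argument type: a bare `nn.Module`, a `(model, optimizer)` tuple, or a dict whose keys match `boost`'s parameters. A self-contained sketch of that dispatch with a stub booster; `StubBooster` merely stands in for `colossalai.booster.Booster`, and only the control flow is the point:

```python
import torch.nn as nn
from torch.optim import SGD


class StubBooster:
    """Stands in for colossalai.booster.Booster; a real Booster wraps what it returns."""

    def boost(self, model=None, optimizer=None, criterion=None, dataloader=None, lr_scheduler=None):
        return model, optimizer, criterion, dataloader, lr_scheduler


def prepare(booster, *boost_args):
    rets = []
    for arg in boost_args:
        if isinstance(arg, nn.Module):        # bare model
            model, *_ = booster.boost(arg)
            rets.append(model)
        elif isinstance(arg, tuple):          # (model, optimizer) pair
            model, optimizer = arg
            model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
            rets.append((model, optimizer))
        elif isinstance(arg, dict):           # kwargs for boost(); None entries are dropped
            keys = ("model", "optimizer", "criterion", "dataloader", "lr_scheduler")
            rets.append({k: v for k, v in zip(keys, booster.boost(**arg)) if v is not None})
        else:
            raise RuntimeError(f"Type {type(arg)} is not supported")
    return rets[0] if len(rets) == 1 else rets


net = nn.Linear(4, 4)
reward_model, (net, optim) = prepare(StubBooster(), nn.Linear(4, 4), (net, SGD(net.parameters(), lr=0.1)))
```
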
diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py
index 8aa302c77eeec2efa5fb7e3e789b172e2f9578a2..7129edb060efff00724873151e1dc5f506da978c 100644
--- a/applications/Chat/coati/trainer/strategies/colossalai.py
+++ b/applications/Chat/coati/trainer/strategies/colossalai.py
@@ -1,47 +1,117 @@
import warnings
-from typing import Optional, Union
+from typing import Optional
-import torch
-import torch.distributed as dist
import torch.nn as nn
-import torch.optim as optim
-from coati.models.base import get_base_model
-from torch.optim import Optimizer
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase
import colossalai
-from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import CPUAdam, HybridAdam
-from colossalai.tensor import ProcessGroup, ShardSpec
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
+from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext, ZeroDDP, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero.gemini.gemini_ddp import GeminiDDP
from .ddp import DDPStrategy
-logger = get_dist_logger(__name__)
+class LowLevelZeroStrategy(DDPStrategy):
+ """
+    The strategy for training with ColossalAI's LowLevelZeroPlugin (ZeRO stage 1 or 2).
+
+ Args:
+ stage(int): The stage to use in ZeRO. Choose in (1, 2)
+ precision(str): The precision to use. Choose in ('fp32', 'fp16').
+ seed(int): The seed for the random number generator.
+        placement_policy(str): Whether to offload training states to CPU. Choose in ('cpu', 'cuda').
+            If it is 'cpu', gradients and optimizer states will be offloaded to CPU;
+            if it is 'cuda', they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
+ reduce_bucket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
+ overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
+ initial_scale(float): The initial scale for the optimizer.
+ growth_factor(float): The growth factor for the optimizer.
+ backoff_factor(float): The backoff factor for the optimizer.
+ growth_interval(int): The growth interval for the optimizer.
+ hysteresis(int): The hysteresis for the optimizer.
+ min_scale(float): The minimum scale for the optimizer.
+ max_scale(float): The maximum scale for the optimizer.
+ max_norm(float): The maximum norm for the optimizer.
+ norm_type(float): The norm type for the optimizer.
+
+ """
+
+ def __init__(
+ self,
+ stage: int = 2,
+ precision: str = "fp16",
+ seed: int = 42,
+ placement_policy: str = "cuda",
+ reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
+ overlap_communication: bool = True, # only for stage 1&2
+ initial_scale: float = 2**16,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ min_scale: float = 1,
+ max_scale: float = 2**32,
+ max_norm: float = 0.0,
+ norm_type: float = 2.0,
+ ) -> None:
+ assert stage in (1, 2), f'Unsupported stage "{stage}"'
+ assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'
+ assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"'
+
+ plugin_initializer = lambda: LowLevelZeroPlugin(
+ stage=stage,
+ precision=precision,
+ reduce_bucket_size_in_m=reduce_bucket_size,
+ overlap_communication=overlap_communication,
+ cpu_offload=(placement_policy == "cpu"),
+ initial_scale=initial_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ min_scale=min_scale,
+ max_scale=max_scale,
+ max_norm=max_norm,
+ norm_type=norm_type,
+ )
+
+ super().__init__(seed, plugin_initializer)
+
+ def _post_init(self) -> None:
+ assert isinstance(
+ self.plugin, LowLevelZeroPlugin
+ ), f"{type(self).__name__}'s plugin is not initialized properly."
+
+ def setup_distributed(self) -> None:
+ colossalai.launch_from_torch({}, seed=self.seed)
+
+ def unwrap_model(self, model: nn.Module) -> nn.Module:
+ assert isinstance(model, LowLevelZeroModel)
+ return model.module
+
+ def get_model_state_dict_shard(self, model: nn.Module, **config):
+ assert isinstance(model, LowLevelZeroModel)
+ yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
-class ColossalAIStrategy(DDPStrategy):
+
+class GeminiStrategy(DDPStrategy):
"""
The strategy for training with ColossalAI's GeminiPlugin (ZeRO stage 3).
Args:
- stage(int): The stage to use in ZeRO. Choose in (1, 2, 3)
- precision(str): The precision to use. Choose in ('fp32', 'fp16'). Stage 3 only supports fp16.
seed(int): The seed for the random number generator.
shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
- This is not compativle with `from_pretrained()`. We temporarily disable this and will support it in the future.
+ This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future.
placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda')
If it is “cpu”, parameters, gradients and optimizer states will be offloaded to CPU,
If it is “cuda”, they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3.
force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3.
- search_range_mb(int): The search range in MB for the chunk size. Only for ZeRO-3.
+        search_range_m(int): The search range for the chunk size, divided by 2^20. Only for ZeRO-3.
hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3.
- min_chunk_size_mb(float): The minimum chunk size in MB. Only for ZeRO-3.
+ min_chunk_size_m(float): The minimum chunk size divided by 2^20. Only for ZeRO-3.
gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
- reduce_bugket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
- overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
initial_scale(float): The initial scale for the optimizer.
growth_factor(float): The growth factor for the optimizer.
backoff_factor(float): The backoff factor for the optimizer.
@@ -55,134 +125,76 @@ class ColossalAIStrategy(DDPStrategy):
"""
def __init__(
- self,
- stage: int = 3,
- precision: str = 'fp16',
- seed: int = 42,
- shard_init: bool = False, # only for stage 3
- placement_policy: str = 'cuda',
- pin_memory: bool = True, # only for stage 3
- force_outputs_fp32: bool = False, # only for stage 3
- scatter_after_inference: bool = False, # only for stage 3
- search_range_mb: int = 32, # only for stage 3
- hidden_dim: Optional[int] = None, # only for stage 3
- min_chunk_size_mb: float = 32, # only for stage 3
- gpu_margin_mem_ratio: float = 0.0, # only for stage 3
- reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
- overlap_communication: bool = True, # only for stage 1&2
- initial_scale: float = 2**16,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- hysteresis: int = 2,
- min_scale: float = 1,
- max_scale: float = 2**32,
- max_norm: float = 0.0,
- norm_type: float = 2.0) -> None:
- super().__init__(seed)
- assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
- assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
- self.stage = stage
+ self,
+ seed: int = 42,
+ shard_init: bool = False, # only for stage 3
+ placement_policy: str = "auto",
+ shard_param_frac: float = 1.0, # only for static placement
+ offload_optim_frac: float = 0.0, # only for static placement
+ offload_param_frac: float = 0.0, # only for static placement
+ pin_memory: bool = True, # only for stage 3
+ force_outputs_fp32: bool = False, # only for stage 3
+ search_range_m: int = 32, # only for stage 3
+ hidden_dim: Optional[int] = None, # only for stage 3
+ min_chunk_size_m: float = 32, # only for stage 3
+ gpu_margin_mem_ratio: float = 0.0, # only for stage 3
+ initial_scale: float = 2**16,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ min_scale: float = 1,
+ max_scale: float = 2**32,
+ max_norm: float = 0.0,
+ norm_type: float = 2.0,
+ ) -> None:
# TODO(ver217): support shard_init when using from_pretrained()
if shard_init:
warnings.warn(
- f'Shard init is not supported model.from_pretrained() yet. Please load weights after strategy.prepare()'
+ f"Shard init is not supported model.from_pretrained() yet. "
+ "Please load weights after strategy.prepare()"
)
- if stage == 3 and precision == 'fp32':
- warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
- precision = 'fp16'
- self.precision = precision
self.shard_init = shard_init
- self.gemini_config = dict(device=get_current_device(),
- placement_policy=placement_policy,
- pin_memory=pin_memory,
- force_outputs_fp32=force_outputs_fp32,
- strict_ddp_mode=shard_init,
- search_range_mb=search_range_mb,
- hidden_dim=hidden_dim,
- min_chunk_size_mb=min_chunk_size_mb,
- scatter_after_inference=scatter_after_inference)
- if stage == 3:
- self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio)
- else:
- self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size,
- overlap_communication=overlap_communication,
- cpu_offload=(placement_policy == 'cpu'))
- self.optim_kwargs = dict(initial_scale=initial_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval,
- hysteresis=hysteresis,
- min_scale=min_scale,
- max_scale=max_scale,
- max_norm=max_norm,
- norm_type=norm_type)
+
+ warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.")
+
+ # NOTE: dist should be initialized before calling get_current_device()
+ plugin_initializer = lambda: GeminiPlugin(
+ chunk_init_device=get_current_device(),
+ placement_policy=placement_policy,
+ shard_param_frac=shard_param_frac,
+ offload_optim_frac=offload_optim_frac,
+ offload_param_frac=offload_param_frac,
+ precision="fp16",
+ pin_memory=pin_memory,
+ force_outputs_fp32=force_outputs_fp32,
+ strict_ddp_mode=shard_init,
+ search_range_m=search_range_m,
+ hidden_dim=hidden_dim,
+ min_chunk_size_m=min_chunk_size_m,
+ gpu_margin_mem_ratio=gpu_margin_mem_ratio,
+ initial_scale=initial_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ min_scale=min_scale,
+ max_scale=max_scale,
+ max_norm=max_norm,
+ norm_type=norm_type,
+ )
+
+ super().__init__(seed, plugin_initializer)
+
+ def _post_init(self) -> None:
+ assert isinstance(self.plugin, GeminiPlugin), f"{type(self).__name__}'s plugin is not initialized properly."
def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)
def model_init_context(self):
- if self.stage == 3:
- world_size = dist.get_world_size()
- shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
- default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
- return ColoInitContext(device=get_current_device(),
- dtype=torch.half,
- default_pg=shard_pg,
- default_dist_spec=default_dist_spec)
return super().model_init_context()
- def setup_model(self, model: nn.Module) -> nn.Module:
-
- model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
-
- if self.stage != 3 and self.precision == 'fp16':
- model = model.half().cuda()
- return model
-
- def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
- assert isinstance(optimizer, (CPUAdam, HybridAdam)), f'Unsupported optimizer {type(optimizer)}'
- return zero_optim_wrapper(model, optimizer, optim_config=self.zero_optim_config, **self.optim_kwargs)
-
- def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
- optimizer.backward(loss)
-
- def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
- optimizer.step()
-
- def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
- if only_rank0 and dist.get_rank() != 0 and self.stage != 3:
- return
- base_model = get_base_model(model)
- if self.stage == 3:
- assert isinstance(base_model, ZeroDDP)
- # for stage 3, state_dict() method should be called on every rank
- state_dict = base_model.state_dict(only_rank_0=only_rank0)
- else:
- # only_rank0 is false or rank == 0
- state_dict = base_model.state_dict()
- if only_rank0 and dist.get_rank() != 0:
- return
- torch.save(state_dict, path)
-
- def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
- if only_rank0:
- raise RuntimeError(
- f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.')
- torch.save(optimizer.state_dict(), path)
-
def unwrap_model(self, model: nn.Module) -> nn.Module:
- base_model: Union[nn.Module, ZeroDDP] = get_base_model(model)
- if self.stage == 3:
- assert isinstance(base_model, ZeroDDP)
- return base_model.module
- return base_model
-
- def save_pretrained(self,
- model: nn.Module,
- path: str,
- only_rank0: bool = True,
- tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
- if self.stage == 3:
- raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now')
- super().save_pretrained(model, path, only_rank0, tokenizer)
+ assert isinstance(model, GeminiDDP)
+ return model.module
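
With `ColossalAIStrategy` split into `LowLevelZeroStrategy` (ZeRO stage 1/2) and `GeminiStrategy` (ZeRO stage 3), a training script typically selects a strategy by name. A hedged sketch of such a mapping; the strategies initialize `torch.distributed`, so this only runs under a `torchrun` launch, and the argument values are illustrative:

```python
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy


def build_strategy(name: str):
    if name == "ddp":
        return DDPStrategy(seed=42)
    if name == "colossalai_zero2":
        # ZeRO-2: shard gradients and optimizer states, keep parameters replicated.
        return LowLevelZeroStrategy(stage=2, placement_policy="cuda")
    if name == "colossalai_zero2_cpu":
        # Same, but offload states to CPU to trade speed for GPU memory.
        return LowLevelZeroStrategy(stage=2, placement_policy="cpu")
    if name == "colossalai_gemini":
        # ZeRO-3 via Gemini: parameters are chunked and placed automatically.
        return GeminiStrategy(placement_policy="auto")
    raise ValueError(f'Unsupported strategy "{name}"')
```
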
diff --git a/applications/Chat/coati/trainer/strategies/ddp.py b/applications/Chat/coati/trainer/strategies/ddp.py
index 7910b57878f86c5a8ed2781ab570e8302c3c2567..f2a44aeb096138a9c4848bb1b4531a199ed55e9c 100644
--- a/applications/Chat/coati/trainer/strategies/ddp.py
+++ b/applications/Chat/coati/trainer/strategies/ddp.py
@@ -1,93 +1,136 @@
import os
import random
-from typing import Optional
+from collections import OrderedDict
+from typing import Callable, Optional
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
-from coati.replay_buffer import ReplayBuffer
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import Optimizer
+from coati.experience_buffer import ExperienceBuffer
+from coati.models import Actor, Critic, RewardModel
from torch.utils.data import DataLoader
+from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
-from .naive import NaiveStrategy
+from colossalai.booster.plugin import TorchDDPPlugin
+from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPModel
+
+from .base import Strategy
from .sampler import DistributedSampler
-class DDPStrategy(NaiveStrategy):
+# TODO: Move this to a util.py (moving it to ray.util introduces a circular import)
+def get_grad_required_state_dict(model: nn.Module):
+ state_dict = OrderedDict()
+ for name, parameter in model.named_parameters():
+ if parameter.requires_grad:
+ state_dict[name] = parameter.detach()
+ return state_dict
+
+
+class DDPStrategy(Strategy):
"""
- Strategy for distributed training using torch.distributed.
+ Strategy for distributed training using torch.distributed.
"""
- def __init__(self, seed: int = 42) -> None:
+ def __init__(self, seed: int = 42, plugin_initializer: Callable = TorchDDPPlugin) -> None:
self.seed = seed
- super().__init__()
+ super().__init__(plugin_initializer)
- def setup_distributed(self) -> None:
+ def _try_init_dist(self, force: bool = False) -> None:
try:
- rank = int(os.environ['RANK'])
- local_rank = int(os.environ['LOCAL_RANK'])
- world_size = int(os.environ['WORLD_SIZE'])
- host = os.environ['MASTER_ADDR']
- port = int(os.environ['MASTER_PORT'])
+ rank = int(os.environ["RANK"])
+ local_rank = int(os.environ["LOCAL_RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+ host = os.environ["MASTER_ADDR"]
+ port = int(os.environ["MASTER_PORT"])
+ dist.init_process_group("nccl", init_method=f"tcp://[{host}]:{port}", world_size=world_size, rank=rank)
+ torch.cuda.set_device(local_rank)
except KeyError as e:
- raise RuntimeError(
- f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
- )
- dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
+ if force:
+ raise RuntimeError(
+ f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
+ )
+ except Exception as e:
+ if force:
+ raise e
+
+ def _post_init(self) -> None:
+ assert isinstance(self.plugin, TorchDDPPlugin), f"{type(self).__name__}'s plugin is not initialized properly."
+
+ def setup_distributed(self) -> None:
+ self._try_init_dist(force=True)
self.set_seed(self.seed)
- torch.cuda.set_device(local_rank)
def set_seed(self, seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
- def setup_model(self, model: nn.Module) -> nn.Module:
- device = torch.cuda.current_device()
- return DDP(model, device_ids=[device])
-
- def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
- # DDP only mode, replay buffers on each rank are different.
- # sampler = DistributedSampler(replay_buffer,
- # num_replicas=dist.get_world_size(),
- # rank=dist.get_rank(),
- # shuffle=True,
- # seed=self.seed,
- # drop_last=True)
- return DataLoader(
- replay_buffer,
- batch_size=replay_buffer.sample_batch_size,
- # sampler=sampler,
+ def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
+ return self.plugin.prepare_dataloader(
+ data_buffer,
+ batch_size=data_buffer.sample_batch_size,
shuffle=True,
drop_last=True,
pin_memory=pin_memory,
- collate_fn=replay_buffer.collate_fn)
-
- def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
- if only_rank0 and dist.get_rank() != 0:
- return
- super().save_model(model, path, only_rank0)
-
- def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
- if only_rank0 and dist.get_rank() != 0:
- return
- super().save_optimizer(optimizer, path, only_rank0)
+ collate_fn=data_buffer.collate_fn,
+ )
def setup_sampler(self, dataset) -> DistributedSampler:
+ # FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
def unwrap_model(self, model: nn.Module) -> nn.Module:
- base_model: DDP = super().unwrap_model(model)
- return base_model.module
-
- def save_pretrained(self,
- model: nn.Module,
- path: str,
- only_rank0: bool = True,
- tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
- if only_rank0 and dist.get_rank() != 0:
- return
- super().save_pretrained(model, path, only_rank0, tokenizer)
+ assert isinstance(model, TorchDDPModel), "model is not wrapped by TorchDDPModel."
+ return model.unwrap()
+
+ def save_pretrained(
+ self, model: nn.Module, path: str, shard: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None
+ ) -> None:
+ if dist.get_rank() == 0:
+ unwrapped_model = self.unwrap_model(model)
+ assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
+ pretrained_model = unwrapped_model.model
+ assert isinstance(pretrained_model, PreTrainedModel)
+ # HACK: only use hf save_pretrained to save config
+ pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None)
+ if tokenizer is not None:
+ tokenizer.save_pretrained(path)
+
+ model_path = os.path.join(path, "pytorch_model.bin")
+ self.save_model(model, model_path, shard=shard)
+ def _replace_keys(model_path: str, replace_fn: Callable):
+ state_dict = torch.load(model_path, map_location="cpu")
+ state_dict = {replace_fn(k): v for k, v in state_dict.items()}
+ torch.save(state_dict, model_path)
+ # FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
+ # HACK: rename keys of pytorch_model.bin
+ if dist.get_rank() == 0:
+ _replace_keys(model_path, lambda k: k.replace("model.", "", 1))
+
+ def get_model_state_dict_shard(self, model: nn.Module, **config):
+ # TODO: implement sharding on naive strategy
+ model = self.unwrap_model(model)
+ if "requires_grad_only" in config and config["requires_grad_only"] == True:
+ state_dict = get_grad_required_state_dict(model)
+ else:
+ state_dict = model.state_dict()
+
+ if "shard_size" in config:
+ shard_size = config["shard_size"]
+ accumulate_size = 0
+ state_dict_shard = OrderedDict()
+ for name, param in state_dict.items():
+ state_dict_shard[name] = param
+ accumulate_size += param.numel() * param.element_size()
+ if accumulate_size >= shard_size:
+ accumulate_size = 0
+ yield state_dict_shard
+ state_dict_shard = OrderedDict()
+ if accumulate_size > 0:
+ yield state_dict_shard
+ else:
+ yield state_dict
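
`get_model_state_dict_shard` above buckets tensors until the accumulated byte size reaches `shard_size`, then yields the bucket and starts a new one. A self-contained toy version of that bucketing over a plain `state_dict` (the model and `shard_size` are illustrative):

```python
from collections import OrderedDict

import torch.nn as nn


def iter_state_dict_shards(state_dict, shard_size: int):
    shard, accumulated = OrderedDict(), 0
    for name, param in state_dict.items():
        shard[name] = param
        accumulated += param.numel() * param.element_size()
        if accumulated >= shard_size:  # bucket is full: flush and start over
            yield shard
            shard, accumulated = OrderedDict(), 0
    if shard:                          # remainder smaller than shard_size
        yield shard


model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 64))
for i, shard in enumerate(iter_state_dict_shards(model.state_dict(), shard_size=128 * 1024)):
    num_bytes = sum(p.numel() * p.element_size() for p in shard.values())
    print(f"shard {i}: {len(shard)} tensors, {num_bytes} bytes")
```
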
diff --git a/applications/Chat/coati/trainer/strategies/naive.py b/applications/Chat/coati/trainer/strategies/naive.py
deleted file mode 100644
index 4d94026ce9320ef754632680d5b44de6baae4862..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/trainer/strategies/naive.py
+++ /dev/null
@@ -1,70 +0,0 @@
-from typing import Any, Optional
-
-import torch
-import torch.nn as nn
-import torch.optim as optim
-from coati.models.base import get_base_model
-from coati.replay_buffer import ReplayBuffer
-from torch.optim import Optimizer
-from torch.utils.data import DataLoader
-from transformers.modeling_utils import PreTrainedModel
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase
-
-from .base import Strategy
-
-
-class NaiveStrategy(Strategy):
- """
- Strategy for single GPU. No parallelism is used.
- """
-
- def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
- loss.backward()
-
- def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
- optimizer.step()
-
- def setup_distributed(self) -> None:
- pass
-
- def setup_model(self, model: nn.Module) -> nn.Module:
- return model
-
- def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
- return optimizer
-
- def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
- return DataLoader(replay_buffer,
- batch_size=replay_buffer.sample_batch_size,
- shuffle=True,
- drop_last=True,
- pin_memory=pin_memory,
- collate_fn=replay_buffer.collate_fn)
-
- def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
- base_model = get_base_model(model)
- state_dict = base_model.state_dict()
- torch.save(state_dict, path)
-
- def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
- base_model = get_base_model(model)
- state_dict = torch.load(path, map_location=map_location)
- base_model.load_state_dict(state_dict, strict=strict)
-
- def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
- torch.save(optimizer.state_dict(), path)
-
- def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
- state_dict = torch.load(path, map_location=map_location)
- optimizer.load_state_dict(state_dict)
-
- def save_pretrained(self,
- model: nn.Module,
- path: str,
- only_rank0: bool = True,
- tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
- unwrapped_model = self.unwrap_model(model)
- assert isinstance(unwrapped_model, PreTrainedModel)
- unwrapped_model.save_pretrained(path)
- if tokenizer is not None:
- tokenizer.save_pretrained(path)
diff --git a/applications/Chat/coati/trainer/strategies/sampler.py b/applications/Chat/coati/trainer/strategies/sampler.py
index d726fa640fa201b8bdec5c7601cc2895c4357316..6e811bef11a59303ea9873d8df3ebea331576a54 100644
--- a/applications/Chat/coati/trainer/strategies/sampler.py
+++ b/applications/Chat/coati/trainer/strategies/sampler.py
@@ -4,7 +4,6 @@ import numpy as np
class DistributedSampler:
-
def __init__(self, dataset, num_replicas: int, rank: int) -> None:
self.dataset = dataset
self.num_replicas = num_replicas
@@ -12,7 +11,7 @@ class DistributedSampler:
if len(self.dataset) % self.num_replicas != 0:
self.num_samples = math.ceil(
- (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
+ (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
@@ -20,10 +19,10 @@ class DistributedSampler:
self.total_size = self.num_samples * self.num_replicas
indices = list(range(len(self.dataset)))
- indices = indices[:self.total_size]
+ indices = indices[: self.total_size]
assert len(indices) == self.total_size
# subsample
- indices = indices[self.rank:self.total_size:self.num_replicas]
+ indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
self.indices = indices
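
The subsampling above intentionally drops the remainder when the dataset size is not divisible by `num_replicas`, so every rank draws the same number of strided indices. A worked example with illustrative numbers:

```python
import math

dataset_len, num_replicas, rank = 10, 4, 1
num_samples = math.ceil((dataset_len - num_replicas) / num_replicas)  # ceil(6 / 4) = 2
total_size = num_samples * num_replicas                               # 8, so 2 samples are dropped
indices = list(range(dataset_len))[:total_size]                       # [0, 1, ..., 7]
print(indices[rank:total_size:num_replicas])                          # rank 1 sees [1, 5]
```
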
diff --git a/applications/Chat/coati/trainer/utils.py b/applications/Chat/coati/trainer/utils.py
index 9cccb5c9260395859611c55207f7dbaf817ef76f..7811e7365eeb08d3fcefc1161075773ff3d18683 100644
--- a/applications/Chat/coati/trainer/utils.py
+++ b/applications/Chat/coati/trainer/utils.py
@@ -3,6 +3,38 @@ from typing import Any
import torch
import torch.distributed as dist
from torch.utils._pytree import tree_map
+from torch.utils.data import DataLoader
+
+
+class CycledDataLoader:
+ """
+ Why do we need this class?
+    In version 4da324cd60, "prompts = next(iter(self.prompt_dataloader))" was used to sample a batch of prompts/pretrain data.
+    However, this may be inefficient because it re-creates the dataloader iterator (and its worker processes) on every call.
+    NOTE: next(iter(dataloader)) is not equivalent to `for batch in dataloader: break`; the two behave slightly differently.
+ """
+
+ def __init__(
+ self,
+ dataloader: DataLoader,
+ ) -> None:
+ self.dataloader = dataloader
+
+ self.count = 0
+ self.dataloader_iter = None
+
+ def next(self):
+ # defer initialization
+ if self.dataloader_iter is None:
+ self.dataloader_iter = iter(self.dataloader)
+
+ self.count += 1
+ try:
+ return next(self.dataloader_iter)
+ except StopIteration:
+ self.count = 0
+ self.dataloader_iter = iter(self.dataloader)
+ return next(self.dataloader_iter)
def is_rank_0() -> bool:
@@ -10,7 +42,6 @@ def is_rank_0() -> bool:
def to_device(x: Any, device: torch.device) -> Any:
-
def _to(t: Any):
if isinstance(t, torch.Tensor):
return t.to(device)
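
A short usage sketch for `CycledDataLoader` (the toy dataset is illustrative): each `next()` pulls one batch from a persistent iterator and transparently starts a new pass once the loader is exhausted, avoiding the worker re-initialization cost of `next(iter(dataloader))`:

```python
from coati.trainer.utils import CycledDataLoader
from torch.utils.data import DataLoader

prompts = list(range(10))  # stand-in for a prompt dataset
cycled = CycledDataLoader(DataLoader(prompts, batch_size=4, shuffle=True))

for step in range(6):      # more steps than one pass over the data holds
    batch = cycled.next()  # wraps around to a fresh iterator after 3 batches
    print(step, batch)
```
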
diff --git a/applications/Chat/coati/utils/__init__.py b/applications/Chat/coati/utils/__init__.py
deleted file mode 100644
index 112b82b9706444013013777d12798b4f6b62e52a..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/utils/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .tokenizer_utils import prepare_llama_tokenizer_and_embedding, smart_tokenizer_and_embedding_resize
-
-__all__ = ['smart_tokenizer_and_embedding_resize', 'prepare_llama_tokenizer_and_embedding']
\ No newline at end of file
diff --git a/applications/Chat/coati/utils/tokenizer_utils.py b/applications/Chat/coati/utils/tokenizer_utils.py
deleted file mode 100644
index e0d96cfc8be2711397d23bcfed725bd0ba10a2bf..0000000000000000000000000000000000000000
--- a/applications/Chat/coati/utils/tokenizer_utils.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Dict
-
-import transformers
-
-DEFAULT_PAD_TOKEN = "[PAD]"
-DEFAULT_EOS_TOKEN = ""
-DEFAULT_BOS_TOKEN = ""
-DEFAULT_UNK_TOKEN = ""
-
-
-def prepare_llama_tokenizer_and_embedding(
- tokenizer: transformers.PreTrainedTokenizer,
- model: transformers.PreTrainedModel,
- special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
-):
- """prepare llama tokenizer and embedding.
-
- """
-
- if tokenizer.pad_token is None:
- smart_tokenizer_and_embedding_resize(
- special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
- tokenizer=tokenizer,
- model=model,
- )
-
- tokenizer.add_special_tokens({
- "eos_token": DEFAULT_EOS_TOKEN,
- "bos_token": DEFAULT_BOS_TOKEN,
- "unk_token": DEFAULT_UNK_TOKEN,
- })
-
- return tokenizer
-
-
-def smart_tokenizer_and_embedding_resize(
- tokenizer: transformers.PreTrainedTokenizer,
- model: transformers.PreTrainedModel,
- special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
-):
- """Resize tokenizer and embedding.
-
- Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
- """
-
- if tokenizer.pad_token is None:
- num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
-
- model.resize_token_embeddings(len(tokenizer))
-
- if num_new_tokens > 0:
- input_embeddings = model.get_input_embeddings().weight.data
- output_embeddings = model.get_output_embeddings().weight.data
-
- input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
- output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
-
- input_embeddings[-num_new_tokens:] = input_embeddings_avg
- output_embeddings[-num_new_tokens:] = output_embeddings_avg
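
The deleted `smart_tokenizer_and_embedding_resize` initializes the embedding rows of newly added special tokens with the mean of the pre-existing rows, which is usually a gentler starting point than random initialization. A self-contained sketch of just that step, with toy sizes:

```python
import torch

old_vocab, dim, num_new = 8, 4, 2
embeddings = torch.randn(old_vocab + num_new, dim)     # weight matrix after resizing
avg = embeddings[:-num_new].mean(dim=0, keepdim=True)  # mean of the original rows
embeddings[-num_new:] = avg                            # new tokens start at the average
assert torch.allclose(embeddings[-1], embeddings[:-num_new].mean(dim=0))
```
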
diff --git a/applications/Chat/evaluate/README.md b/applications/Chat/evaluate/README.md
deleted file mode 100644
index 7ace4bfe6d1871bb97e3e522dc35128d692a7cb7..0000000000000000000000000000000000000000
--- a/applications/Chat/evaluate/README.md
+++ /dev/null
@@ -1,182 +0,0 @@
-# Evaluation
-
-In this directory, we introduce how you can evaluate your model with GPT-4.
-
-## Evaluation Pipeline
-
-The whole evaluation process undergoes the following three steps:
-1. Prepare the questions following the internal data structure in the data format section (described below).
-2. Generate answers from different models:
- * Generate answers using GPT-3.5: [`generate_gpt35_answers.py`](generate_gpt35_answers.py).
- * Generate answers using your own models: [`generate_answers.py`](generate_answers.py).
-3. Evaluate models using GPT-4: [`evaluate.py`](evaluate.py).
-
-### Generate Answers
-#### Generate Answers Using GPT-3.5
-You can provide your own OpenAI key to generate answers from GPT-3.5 using [`generate_gpt35_answers.py`](./generate_gpt35_answers.py).
-
-An example script is provided as follows:
-```shell
-python generate_gpt35_answers.py \
- --dataset "path to the question dataset" \
- --answer_path "path to answer folder" \
- --num_workers 4 \
- --openai_key "your openai key" \
- --max_tokens 512 \
-```
-
-#### Generate Answers Using our Own Model
-You can also generate answers using your own models. The generation process is divided into two stages:
-1. Generate answers using multiple GPUs (optional) with batch processing: [`generate_answers.py`](./generate_answers.py).
-2. Merge multiple shards and output a single file: [`merge.py`](./merge.py).
-
-An example script is given as follows:
-
-```shell
-device_number=number of your devices
-model_name="name of your model"
-model_path="path to your model"
-dataset="path to the question dataset"
-answer_path="path to save the model answers"
-
-torchrun --standalone --nproc_per_node=$device_number generate_answers.py \
- --model 'llama' \
- --strategy ddp \
- --model_path $model_path \
- --model_name $model_name \
- --dataset $dataset \
- --batch_size 8 \
- --max_datasets_size 80 \
- --answer_path $answer_path \
- --max_length 512
-
-python merge.py \
- --model_name $model_name \
- --shards $device_number \
- --answer_path $answer_path \
-
-for (( i=0; i
-    if scores[0] > scores[1]:
- worse_count += 1
- worse_file.append(review_jsons[idx])
- elif scores[0] < scores[1]:
- better_count += 1
- better_file.append(review_jsons[idx])
- else:
- tie_count += 1
- tie_file.append(review_jsons[idx])
- ans1_score += scores[0]
- ans2_score += scores[1]
-
- output_review_file.append(review_jsons[idx])
-
- better_file.sort(key=lambda x: x['id'])
- worse_file.sort(key=lambda x: x['id'])
- tie_file.sort(key=lambda x: x['id'])
- invalid_file.sort(key=lambda x: x['id'])
- output_review_file.sort(key=lambda x: x['id'])
-
- name1 = os.path.basename(args.answer_file_list[0]).split("_answers")[0]
- name2 = os.path.basename(args.answer_file_list[1]).split("_answers")[0]
- prefix = f"{name1}_vs_{name2}"
-
- jdump(better_file, os.path.join(
- args.output_folder, prefix, f"{prefix}_better.json"))
- jdump(worse_file, os.path.join(
- args.output_folder, prefix, f"{prefix}_worse.json"))
- jdump(tie_file, os.path.join(
- args.output_folder, prefix, f"{prefix}_tie.json"))
- jdump(invalid_file, os.path.join(
- args.output_folder, prefix, f"{prefix}_invalid.json"))
- jdump(output_review_file, os.path.join(
- args.output_folder, prefix, f"{prefix}_review.json"))
-
- if os.path.exists(os.path.join(args.output_folder, "results.json")):
- results = jload(os.path.join(args.output_folder, "results.json"))
- else:
- results = {}
- results[prefix] = {'model': [name1, name2], 'better': better_count, 'worse': worse_count, 'tie': tie_count, 'win_rate': better_count /
- (len(reviews)-invalid_count), 'score': [ans1_score/(len(reviews)-invalid_count), ans2_score/(len(reviews)-invalid_count)]}
- jdump(results, os.path.join(args.output_folder, "results.json"))
-
- logger.info(f' Total {invalid_count} invalid score pair(s).')
- logger.info(f' Model {name2} has {better_count} better answer(s).')
- logger.info(f' Model {name2} has {worse_count} worse answer(s).')
- logger.info(f' {tie_count} answer(s) play(s) to a tie.')
- logger.info(
- f' Win rate of model {name2}: {better_count/(len(reviews)-invalid_count):.2f}')
- logger.info(
- f' Model {name1} average score: {ans1_score/(len(reviews)-invalid_count):.2f}')
- logger.info(
- f' Model {name2} average score: {ans2_score/(len(reviews)-invalid_count):.2f}')
-
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser(
- description='Model evaluation.')
- parser.add_argument('--answer_file_list', nargs='+', default=[])
- parser.add_argument('--prompt_file')
- parser.add_argument('--reviewer_file')
- parser.add_argument('--output_folder', type=str, default="./output")
- parser.add_argument('--openai_key', type=str, default=None)
- parser.add_argument('--model', type=str, default="gpt-4")
- parser.add_argument('--num_workers', type=int, default=8)
- parser.add_argument('--max_tokens', type=int, default=512,
- help='maximum number of tokens produced in the output')
- args = parser.parse_args()
-
- if args.openai_key is not None:
- os.environ["OPENAI_API_KEY"] = args.openai_key
- openai.api_key = os.getenv("OPENAI_API_KEY")
-
- evaluate(args)
diff --git a/applications/Chat/evaluate/evaluate.sh b/applications/Chat/evaluate/evaluate.sh
deleted file mode 100755
index c51aa941019e55e38f57a459f88dc6ae85264c0e..0000000000000000000000000000000000000000
--- a/applications/Chat/evaluate/evaluate.sh
+++ /dev/null
@@ -1,9 +0,0 @@
-python evaluate.py \
- --answer_file_list "path to answers of model 1" "path to answers of model 2" \
- --prompt_file "path to prompt file" \
- --reviewer_file "path to reviewer file" \
- --output_folder "path to output folder" \
- --openai_key "your openai key" \
- --model "gpt-4" \
- --num_workers 8 \
- --max_tokens 512 \
diff --git a/applications/Chat/evaluate/generate_answers.py b/applications/Chat/evaluate/generate_answers.py
deleted file mode 100644
index fbebf5c5e6f6835575ba6d9733604924d9e91deb..0000000000000000000000000000000000000000
--- a/applications/Chat/evaluate/generate_answers.py
+++ /dev/null
@@ -1,173 +0,0 @@
-import argparse
-import os
-import random
-import copy
-import math
-from tqdm import tqdm
-
-import torch
-import torch.distributed as dist
-import transformers
-
-from coati.models.bloom import BLOOMActor
-from coati.models.gpt import GPTActor
-from coati.models.opt import OPTActor
-from coati.models.roberta import RoBERTaActor
-from coati.models.llama import LlamaActor
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
-from transformers import AutoTokenizer, RobertaTokenizer
-from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
-
-from colossalai.logging import get_dist_logger
-
-from utils import jload, jdump, is_rank_0
-
-
-logger = get_dist_logger()
-
-PROMPT_DICT = {
- "prompt_input":
- ("Below is an instruction that describes a task, paired with an input that provides further context. "
- "Write a response that appropriately completes the request.\n\n"
- "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
- "prompt_no_input": ("Below is an instruction that describes a task. "
- "Write a response that appropriately completes the request.\n\n"
- "### Instruction:\n{instruction}\n\n### Response:"),
-}
-
-
-def generate(args):
- # torch.cuda.set_per_process_memory_fraction(0.4)
- if args.strategy == 'naive':
- strategy = NaiveStrategy()
- elif args.strategy == 'ddp':
- strategy = DDPStrategy()
- elif args.strategy == 'colossalai_gemini':
- strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
- elif args.strategy == 'colossalai_zero2':
- strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
- elif args.strategy == 'colossalai_zero2_cpu':
- strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
- else:
- raise ValueError(f'Unsupported strategy "{args.strategy}"')
-
- world_size = dist.get_world_size()
- rank = dist.get_rank()
-
- with strategy.model_init_context():
- if args.model == 'gpt2':
- actor = GPTActor(pretrained=args.model_path).to(
- torch.cuda.current_device())
- elif args.model == 'bloom':
- actor = BLOOMActor(pretrained=args.model_path).to(
- torch.cuda.current_device())
- elif args.model == 'opt':
- actor = OPTActor(pretrained=args.model_path).to(
- torch.cuda.current_device())
- elif args.model == 'roberta':
- actor = RoBERTaActor(pretrained=args.model_path).to(
- torch.cuda.current_device())
- elif args.model == 'llama':
- actor = LlamaActor(pretrained=args.model_path).to(
- torch.float16).to(torch.cuda.current_device())
- else:
- raise ValueError(f'Unsupported model "{args.model}"')
-
- if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'bloom':
- tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
- tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'opt':
- tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
- elif args.model == 'roberta':
- tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
- elif args.model == 'llama':
- tokenizer = AutoTokenizer.from_pretrained(args.model_path,
- padding_side="right",
- use_fast=False,
- )
- tokenizer.eos_token = '<\s>'
- else:
- raise ValueError(f'Unsupported model "{args.model}"')
-
- questions = []
- if args.max_datasets_size is not None:
- questions = random.sample(jload(args.dataset), args.max_datasets_size)
- if is_rank_0():
- logger.info(
- f"Limiting dataset to {args.max_datasets_size} examples.")
- questions = questions[rank:args.max_datasets_size:world_size]
-
- answers = copy.deepcopy(questions)
-
- prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
- sources = [
- prompt_input.format_map(example) if example.get(
- "input", "") != "" else prompt_no_input.format_map(example)
- for example in questions
- ]
-
- if is_rank_0():
- logger.info("Tokenizing inputs... This may take some time...")
-
- input_ids_list = []
-
- for string in sources:
- input_ids = tokenizer.encode(string, return_tensors='pt').squeeze(0)
- input_ids_list.append(input_ids)
-
- bar = tqdm(range(math.ceil(len(input_ids_list)/args.batch_size)),
- desc=f'steps', disable=not is_rank_0())
-
- actor.eval()
- with torch.no_grad():
- for i in range(0, len(input_ids_list), args.batch_size):
- batch = input_ids_list[i:i+args.batch_size]
- batch = [i.flip(dims=[0]) for i in batch]
- batch = torch.nn.utils.rnn.pad_sequence(batch,
- batch_first=True,
- padding_value=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0).to(torch.cuda.current_device())
- batch = batch.flip(dims=[1])
- attention_mask = batch.ne(tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0)
-
- outputs = actor.model.generate(batch, attention_mask=attention_mask,
- max_length=args.max_length,
- do_sample=True,
- top_k=50,
- top_p=0.95,
- num_return_sequences=1)
-
- outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
- for j in range(batch.size(0)):
- answers[i +
- j]['output'] = outputs[j].split("### Response:")[1].strip()
-
- bar.update()
-
- jdump(answers, os.path.join(args.answer_path,
- f'{args.model_name}_answers_rank{rank}.json'))
-
- if is_rank_0():
- logger.info(
- f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB')
-
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument('--strategy',
- choices=['naive', 'ddp', 'colossalai_gemini',
- 'colossalai_zero2', 'colossalai_zero2_cpu'],
- default='naive')
- parser.add_argument('--model', default='gpt2',
- choices=['gpt2', 'bloom', 'opt', 'roberta', 'llama'])
- parser.add_argument('--model_path', type=str, default=None)
- parser.add_argument('--model_name', type=str, default='model')
- parser.add_argument('--dataset', type=str, default=None)
- parser.add_argument('--batch_size', type=int, default=1)
- parser.add_argument('--max_datasets_size', type=int, default=None)
- parser.add_argument('--answer_path', type=str, default="answer")
- parser.add_argument('--max_length', type=int, default=1024)
- args = parser.parse_args()
- generate(args)
diff --git a/applications/Chat/evaluate/generate_answers.sh b/applications/Chat/evaluate/generate_answers.sh
deleted file mode 100755
index 36881f5f4f292885a153e4558eb1240823c14bb5..0000000000000000000000000000000000000000
--- a/applications/Chat/evaluate/generate_answers.sh
+++ /dev/null
@@ -1,25 +0,0 @@
-device_number=number of your devices
-model_name="name of your model"
-model_path="path to your model"
-dataset="path to the question dataset"
-answer_path="path to save the model answers"
-
-torchrun --standalone --nproc_per_node=$device_number generate_answers.py \
- --model 'llama' \
- --strategy ddp \
- --model_path $model_path \
- --model_name $model_name \
- --dataset $dataset \
- --batch_size 8 \
- --max_datasets_size 80 \
- --answer_path $answer_path \
- --max_length 512
-
-python merge.py \
- --model_name $model_name \
- --shards $device_number \
- --answer_path $answer_path \
-
-for (( i=0; i<device_number; i++ )) do
-    rm -rf ${answer_path}/${model_name}_answers_rank${i}.json
-done
diff --git a/applications/Chat/evaluate/utils.py b/applications/Chat/evaluate/utils.py
deleted file mode 100644
--- a/applications/Chat/evaluate/utils.py
+++ /dev/null
-import io
-import json
-import os
-
-import torch.distributed as dist
-
-
-def is_rank_0() -> bool:
- return not dist.is_initialized() or dist.get_rank() == 0
-
-def _make_w_io_base(f, mode: str):
- if not isinstance(f, io.IOBase):
- f_dirname = os.path.dirname(f)
- if f_dirname != "":
- os.makedirs(f_dirname, exist_ok=True)
- f = open(f, mode=mode)
- return f
-
-def _make_r_io_base(f, mode: str):
- if not isinstance(f, io.IOBase):
- f = open(f, mode=mode)
- return f
-
-def jdump(obj, f, mode="w", indent=4, default=str):
- """Dump a str or dictionary to a file in json format.
- Args:
- obj: An object to be written.
- f: A string path to the location on disk.
- mode: Mode for opening the file.
- indent: Indent for storing json dictionaries.
- default: A function to handle non-serializable entries; defaults to `str`.
- """
- f = _make_w_io_base(f, mode)
- if isinstance(obj, (dict, list)):
- json.dump(obj, f, indent=indent, default=default)
- elif isinstance(obj, str):
- f.write(obj)
- else:
- raise ValueError(f"Unexpected type: {type(obj)}")
- f.close()
-
-def jload(f, mode="r"):
- """Load a .json file into a dictionary."""
- f = _make_r_io_base(f, mode)
- jdict = json.load(f)
- f.close()
- return jdict
-
-def get_json_list(file_path):
- with open(file_path, 'r') as f:
- json_list = []
- for line in f:
- json_list.append(json.loads(line))
- return json_list
diff --git a/applications/Chat/examples/README.md b/applications/Chat/examples/README.md
index 561ace2205ed559846560e9352a13852eddd44bd..9438aafd126811cdfd967334be6168637db4f941 100644
--- a/applications/Chat/examples/README.md
+++ b/applications/Chat/examples/README.md
@@ -6,6 +6,7 @@
- [Table of Contents](#table-of-contents)
- [Install requirements](#install-requirements)
- [Supervised datasets collection](#supervised-datasets-collection)
+ - [Conversation dataset generation](#conversation-dataset-generation)
- [Stage1 - Supervised instructs tuning](#stage1---supervised-instructs-tuning)
- [Arg List](#arg-list)
- [Stage2 - Training reward model](#stage2---training-reward-model)
@@ -16,7 +17,7 @@
- [Arg List](#arg-list-2)
- [Inference example - After Stage3](#inference-example---after-stage3)
- [Attention](#attention)
- - [data](#data)
+ - [data](#data)
- [Support Model](#support-model)
- [GPT](#gpt)
- [BLOOM](#bloom)
@@ -24,12 +25,11 @@
- [LLaMA](#llama)
- [Add your own models](#add-your-own-models)
- [Actor model](#actor-model)
- - [LM model](#lm-model)
- [Reward model](#reward-model)
- [Critic model](#critic-model)
-
---
+
## Install requirements
```shell
@@ -38,27 +38,74 @@ pip install -r requirements.txt
## Supervised datasets collection
-We collected 104K bilingual dataset of Chinese and English, and you can find the datasets in this repo
-[InstructionWild](https://github.com/XueFuzhao/InstructionWild).
+We collected a dataset of 104K bilingual (Chinese and English) instructions. You can find the datasets in the repo
+[InstructionWild](https://github.com/XueFuzhao/InstructionWild) and in this [file](https://github.com/XueFuzhao/InstructionWild/blob/main/data/README.md).
+
+Here is how we collected the data:
-The following pic shows how we collected the data.
+### Conversation dataset generation
+
+To further improve the model's ability to handle multi-turn conversations, we need to include samples with multi-turn conversations in the dataset. However, the samples in the InstructWild and Alpaca datasets currently consist of only single-turn conversations, and their organization is not suitable for storing multi-turn conversations. In addition to converting these datasets, we also include multi-turn conversation datasets such as ShareGPT and transform them into the training format supported by ColossalChat.
+
+A sample of conversation dataset should have the following fields:
+
+- `type` (str, optional): The type of the data sample.
+- `language` (str, optional): The language of the data sample.
+- `dataset` (str, optional): The dataset the data sample originates from.
+- `conversations` (list, compulsory): Conversation content of the data sample.
+- `id` (int, optional): The ID of the data sample.
+
+A simple example:
+
+```json
+{
+ "type": "instruction",
+ "language": "English",
+ "dataset": "Alpaca",
+ "conversations": [
+ {
+ "from": "human",
+ "value": "Give three tips for staying healthy."
+ },
+ {
+ "from": "gpt",
+ "value": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."
+ }
+ ],
+ "id": 1
+}
+```
+
+> **NOTE:** Only the key `conversations` is compulsory for training; the other keys serve as metadata. The length of `conversations` varies.
+
+You can run `examples/generate_conversation_dataset.py` to generate a conversation dataset supported by ColossalChat.
+
+You can use the following command to generate the conversation dataset.
+
+```bash
+python generate_conversation_dataset.py \
+    --dataset "All" \
+    --save_path "/path/to/dataset"
+```
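
For reference, here is a minimal sketch of the conversion step for Alpaca-style samples. The helper below is illustrative only (it is not the actual `generate_conversation_dataset.py`), and the file names are assumptions:

```python
import json

def alpaca_to_conversation(sample: dict, idx: int) -> dict:
    # Fold the optional "input" field into the human turn, Alpaca-style.
    query = sample["instruction"]
    if sample.get("input"):
        query += "\n" + sample["input"]
    return {
        "type": "instruction",
        "language": "English",
        "dataset": "Alpaca",
        "conversations": [
            {"from": "human", "value": query},
            {"from": "gpt", "value": sample["output"]},
        ],
        "id": idx,
    }

with open("alpaca_data.json") as f:  # assumed input path
    converted = [alpaca_to_conversation(s, i) for i, s in enumerate(json.load(f))]
with open("conversation_dataset.json", "w") as f:  # assumed output path
    json.dump(converted, f, indent=4, ensure_ascii=False)
```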
+
## Stage1 - Supervised instructs tuning
Stage1 is supervised instructs fine-tuning, which uses the datasets mentioned earlier to fine-tune the model.
+[[Stage1 tutorial video]](https://www.youtube.com/watch?v=-qFBZFmOJfg)
You can run the `examples/train_sft.sh` to start a supervised instructs fine-tuning.
You can also use the following cmd to start a supervised instructs fine-tuning with your own settings.
-```
+
+```bash
torchrun --standalone --nproc_per_node=4 train_sft.py \
--pretrain "/path/to/LLaMa-7B/" \
--model 'llama' \
--strategy colossalai_zero2 \
- --log_interval 10 \
--save_path /path/to/Coati-7B \
--dataset /path/to/data.json \
--batch_size 4 \
@@ -68,27 +115,44 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
--max_epochs 1 \
--grad_checkpoint
```
+
+**Note**: the supervised dataset must follow the format below:
+
+```json
+[
+ {
+ "instruction": "Provide a list of the top 10 most popular mobile games in Asia",
+ "input": "",
+ "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
+ "id": 0
+ },
+ ...
+]
+```
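
As a sketch of how such a sample becomes a training string: the Alpaca-style prompt templates below appear verbatim elsewhere in these example scripts, while the `render` helper itself is illustrative.

```python
PROMPT_WITH_INPUT = (
    "Below is an instruction that describes a task, paired with an input that provides further context. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
)
PROMPT_NO_INPUT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:"
)

def render(sample: dict) -> str:
    # Choose the template based on whether the sample carries a non-empty "input" field.
    template = PROMPT_WITH_INPUT if sample.get("input") else PROMPT_NO_INPUT
    return template.format_map(sample) + sample["output"]
```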
+
### Arg List
-- --strategy: the strategy using for training, choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
-- --model: model type, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom'
-- --pretrain: pretrain model, type=str, default=None
-- --max_datasets_size: the max size of dataset, type=int, default=None
-- --save_path: path to save the model, type=str, default='output'
-- --need_optim_ckpt: whether to save optim ckpt, type=bool, default=False
-- --max_epochs: max epochs for training, type=int, default=3
-- --batch_size: batch size while training, type=int, default=4
-- --lora_rank: low-rank adaptation matrices rank, type=int, default=0
-- --log_interval: how many steps to log, type=int, default=100
-- --grad_checkpoint: enable gradient checkpointing, type=bool, default=False
+
+- `--strategy`: the strategy to use for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
+- `--model`: model type, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom'
+- `--pretrain`: pretrain model, type=str, default=None
+- `--max_datasets_size`: the max size of dataset, type=int, default=None
+- `--save_path`: path to save the model, type=str, default='output'
+- `--need_optim_ckpt`: whether to save optim ckpt, type=bool, default=False
+- `--max_epochs`: max epochs for training, type=int, default=3
+- `--batch_size`: batch size while training, type=int, default=4
+- `--lora_rank`: low-rank adaptation matrices rank, type=int, default=0
+- `--grad_checkpoint`: enable gradient checkpointing, type=bool, default=False
## Stage2 - Training reward model
We train a reward model in stage 2: human annotators rank different outputs for the same prompt, and these rankings supervise the reward model to assign corresponding scores.
+[[Stage2 tutorial video]](https://www.youtube.com/watch?v=gMx2CApKhuo)
You can run the `examples/train_rm.sh` to start a reward model training.
You can also use the following cmd to start training a reward model.
-```
+
+```bash
torchrun --standalone --nproc_per_node=4 train_reward_model.py \
--pretrain "/path/to/LLaMa-7B/" \
--model 'llama' \
@@ -96,16 +160,19 @@ torchrun --standalone --nproc_per_node=4 train_reward_model.py \
--loss_fn 'log_exp'\
--save_path 'rmstatic.pt' \
```
+
### Features and tricks in RM training
+
+- We support the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) and [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) datasets.
-- We support 2 kinds of loss_function named 'log_sig'(used by OpenAI) and 'log_exp'(used by Anthropic).
-- We change the loss to valid_acc and pair_dist to monitor progress during training.
+- We support 2 kinds of loss functions named `log_sig` (used by OpenAI) and `log_exp` (used by Anthropic); see the sketch after this list.
+- We report `valid_acc` and `pair_dist` instead of the loss to monitor progress during training.
- We add a special token to the end of the sequence to get better results.
- We use a cosine-reducing lr-scheduler for RM training.
- We set value_head as one linear layer and initialize its weight using the N(0, 1/(d_model + 1)) distribution.
- We train a Bloom-560m reward model for 1 epoch and find that the test accuracy of the model achieves the performance mentioned in [Anthropic's paper](https://arxiv.org/abs/2204.05862).
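
For concreteness, a minimal sketch of the two pairwise ranking losses and the value-head initialization described above. The function names are illustrative; note that the two losses are mathematically equivalent and differ only in how they are computed:

```python
import torch
import torch.nn.functional as F

def log_sig_loss(chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
    # -log sigmoid(r_chosen - r_rejected), the formulation used by OpenAI.
    return -F.logsigmoid(chosen_reward - reject_reward).mean()

def log_exp_loss(chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
    # log(1 + exp(r_rejected - r_chosen)), the formulation used by Anthropic.
    return torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()

# Value head: one linear layer with weights drawn from N(0, 1/(d_model + 1)).
d_model = 1024  # illustrative hidden size
value_head = torch.nn.Linear(d_model, 1)
torch.nn.init.normal_(value_head.weight, std=(1 / (d_model + 1)) ** 0.5)
```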
### Experiment result
+
Model performance in [Anthropics paper](https://arxiv.org/abs/2204.05862):

@@ -117,20 +184,20 @@ Model performance in [Anthropics paper](https://arxiv.org/abs/2204.05862):
We also train a reward model based on LLaMA-7B, which reaches an accuracy of 72.06% after 1 epoch, performing almost the same as Anthropic's best RM.
### Arg List
-- --strategy: the strategy using for training, choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
-- --model: model type, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom'
-- --pretrain: pretrain model, type=str, default=None
-- --model_path: the path of rm model(if continue to train), type=str, default=None
-- --save_path: path to save the model, type=str, default='output'
-- --need_optim_ckpt: whether to save optim ckpt, type=bool, default=False
-- --max_epochs: max epochs for training, type=int, default=3
-- --dataset: dataset name, type=str, choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static']
-- --subset: subset of the dataset, type=str, default=None
-- --batch_size: batch size while training, type=int, default=4
-- --lora_rank: low-rank adaptation matrices rank, type=int, default=0
-- --loss_func: which kind of loss function, choices=['log_sig', 'log_exp']
-- --max_len: max sentence length for generation, type=int, default=512
-- --test: whether is only testing, if it's true, the dataset will be small
+
+- `--strategy`: the strategy to use for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
+- `--model`: model type, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom'
+- `--pretrain`: pretrain model, type=str, default=None
+- `--model_path`: the path of the rm model (if continuing to train), type=str, default=None
+- `--save_path`: path to save the model, type=str, default='output'
+- `--need_optim_ckpt`: whether to save optim ckpt, type=bool, default=False
+- `--max_epochs`: max epochs for training, type=int, default=3
+- `--dataset`: dataset name, type=str, choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static']
+- `--subset`: subset of the dataset, type=str, default=None
+- `--batch_size`: batch size while training, type=int, default=4
+- `--lora_rank`: low-rank adaptation matrices rank, type=int, default=0
+- `--loss_func`: which kind of loss function, choices=['log_sig', 'log_exp']
+- `--max_len`: max sentence length for generation, type=int, default=512
## Stage3 - Training model using prompts with RL
@@ -141,53 +208,89 @@ Stage3 uses reinforcement learning algorithm, which is the most complex part of
You can run the `examples/train_prompts.sh` to start PPO training.
+
You can also use the following cmd to start PPO training.
+[[Stage3 tutorial video]](https://www.youtube.com/watch?v=Z8wwSHxPL9g)
-```
+```bash
torchrun --standalone --nproc_per_node=4 train_prompts.py \
- --pretrain "/path/to/LLaMa-7B/" \
- --model 'llama' \
- --strategy colossalai_zero2 \
- --prompt_dataset /path/to/your/prompt_dataset \
- --pretrain_dataset /path/to/your/pretrain_dataset \
- --rm_pretrain /your/pretrain/rm/defination \
- --rm_path /your/rm/model/path
+ --pretrain "/path/to/LLaMa-7B/" \
+ --model 'llama' \
+ --strategy colossalai_zero2 \
+ --prompt_dataset /path/to/your/prompt_dataset \
+ --pretrain_dataset /path/to/your/pretrain_dataset \
+ --rm_pretrain /your/pretrain/rm/definition \
+ --rm_path /your/rm/model/path
```
-Prompt dataset: the instruction dataset mentioned in the above figure which includes the instructions, e.g. you can use the [script](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/example_data_reformat.py) to reformat [seed_prompts_ch.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_ch.jsonl) or [seed_prompts_en.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_en.jsonl) in InstructionWild.
+Prompt dataset: the instruction dataset mentioned in the above figure, which contains the instructions; e.g. you can use the [script](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/generate_prompt_dataset.py), which samples `instinwild_en.json` or `instinwild_ch.json` in [InstructionWild](https://github.com/XueFuzhao/InstructionWild/tree/main/data#instructwild-data), to generate the prompt dataset (a sampling sketch follows the format examples below).
Pretrain dataset: the pretrain dataset including the instruction and corresponding response, e.g. you can use the [InstructWild Data](https://github.com/XueFuzhao/InstructionWild/tree/main/data) in stage 1 supervised instructs tuning.
+**Note**: the required datasets must follow the formats below:
+
+- `pretrain dataset`
+
+ ```json
+ [
+ {
+ "instruction": "Provide a list of the top 10 most popular mobile games in Asia",
+ "input": "",
+ "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
+ "id": 0
+ },
+ ...
+ ]
+ ```
+
+- `prompt dataset`
+
+ ```json
+ [
+ {
+ "instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"",
+ "id": 0
+ },
+ {
+ "instruction": "Write a descriptive paragraph about a memorable vacation you went on",
+ "id": 1
+ },
+ ...
+ ]
+ ```
+
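As promised above, a minimal sketch of sampling a prompt dataset from `instinwild_en.json`. This is illustrative only (see `generate_prompt_dataset.py` for the actual script); the file names and sample count are assumptions:

```python
import json
import random

with open("instinwild_en.json") as f:  # assumed input path
    samples = json.load(f)

# Keep instructions only; during PPO the actor generates the responses itself.
random.seed(42)
subset = random.sample(samples, k=min(1000, len(samples)))
prompts = [{"instruction": s["instruction"], "id": i} for i, s in enumerate(subset)]

with open("prompt_dataset.json", "w") as f:  # assumed output path
    json.dump(prompts, f, indent=4, ensure_ascii=False)
```
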
### Arg List
-- --strategy: the strategy using for training, choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
-- --model: model type of actor, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom'
-- --pretrain: pretrain model, type=str, default=None
-- --rm_model: reward model type, type=str, choices=['gpt2', 'bloom', 'opt', 'llama'], default=None
-- --rm_pretrain: pretrain model for reward model, type=str, default=None
-- --rm_path: the path of rm model, type=str, default=None
-- --save_path: path to save the model, type=str, default='output'
-- --prompt_dataset: path of the prompt dataset, type=str, default=None
-- --pretrain_dataset: path of the ptx dataset, type=str, default=None
-- --need_optim_ckpt: whether to save optim ckpt, type=bool, default=False
-- --num_episodes: num of episodes for training, type=int, default=10
-- --max_epochs: max epochs for training in one episode, type=int, default=5
-- --max_timesteps: max episodes in one batch, type=int, default=10
-- --update_timesteps: timesteps to update, type=int, default=10
-- --train_batch_size: batch size while training, type=int, default=8
-- --ptx_batch_size: batch size to compute ptx loss, type=int, default=1
-- --experience_batch_size: batch size to make experience, type=int, default=8
-- --lora_rank: low-rank adaptation matrices rank, type=int, default=0
-- --kl_coef: kl_coef using for computing reward, type=float, default=0.1
-- --ptx_coef: ptx_coef using for computing policy loss, type=float, default=0.9
+
+- `--strategy`: the strategy to use for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
+- `--model`: model type of actor, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom'
+- `--pretrain`: pretrain model, type=str, default=None
+- `--rm_model`: reward model type, type=str, choices=['gpt2', 'bloom', 'opt', 'llama'], default=None
+- `--rm_pretrain`: pretrain model for reward model, type=str, default=None
+- `--rm_path`: the path of rm model, type=str, default=None
+- `--save_path`: path to save the model, type=str, default='output'
+- `--prompt_dataset`: path of the prompt dataset, type=str, default=None
+- `--pretrain_dataset`: path of the ptx dataset, type=str, default=None
+- `--need_optim_ckpt`: whether to save optim ckpt, type=bool, default=False
+- `--num_episodes`: num of episodes for training, type=int, default=10
+- `--num_update_steps`: number of steps to update policy per episode, type=int
+- `--num_collect_steps`: number of steps to collect experience per episode, type=int
+- `--train_batch_size`: batch size while training, type=int, default=8
+- `--ptx_batch_size`: batch size to compute ptx loss, type=int, default=1
+- `--experience_batch_size`: batch size to make experience, type=int, default=8
+- `--lora_rank`: low-rank adaptation matrices rank, type=int, default=0
+- `--kl_coef`: coefficient of the KL penalty used when computing the reward, type=float, default=0.1
+- `--ptx_coef`: coefficient of the pretraining (ptx) loss used when computing the policy loss, type=float, default=0.9 (see the sketch after this list)
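
A minimal sketch of how these two coefficients enter PPO training. The function names are illustrative, and the formulas follow the usual RLHF recipe rather than the exact trainer internals:

```python
import torch

def reward_with_kl_penalty(
    rm_reward: torch.Tensor,        # scalar reward per sequence from the RM
    actor_log_probs: torch.Tensor,  # per-token log probs under the actor
    initial_log_probs: torch.Tensor,  # per-token log probs under the initial (SFT) policy
    kl_coef: float = 0.1,
) -> torch.Tensor:
    # Penalize divergence of the actor from the initial policy.
    approx_kl = actor_log_probs - initial_log_probs
    return rm_reward - kl_coef * approx_kl.sum(dim=-1)

def mixed_policy_loss(ppo_loss: torch.Tensor, ptx_loss: torch.Tensor, ptx_coef: float = 0.9) -> torch.Tensor:
    # Blend the pretraining (ptx) language-modeling loss into the policy loss,
    # as in InstructGPT, to preserve language quality during RL.
    return ptx_coef * ptx_loss + (1 - ptx_coef) * ppo_loss
```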
## Inference example - After Stage3
+
We support different inference options, including int8 and int4 quantization.
For details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
-
## Attention
+
The examples are demos for the whole training process. You need to tune the hyper-parameters to achieve good performance.
#### data
+
- [x] [rm-static](https://huggingface.co/datasets/Dahoas/rm-static)
- [x] [hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [ ] [openai/summarize_from_feedback](https://huggingface.co/datasets/openai/summarize_from_feedback)
@@ -197,14 +300,16 @@ The examples are demos for the whole training process.You need to change the hyp
## Support Model
### GPT
-- [x] GPT2-S (s)
-- [x] GPT2-M (m)
-- [x] GPT2-L (l)
-- [x] GPT2-XL (xl)
-- [x] GPT2-4B (4b)
-- [ ] GPT2-6B (6b)
+
+- [x] GPT2-S (s)
+- [x] GPT2-M (m)
+- [x] GPT2-L (l)
+- [x] GPT2-XL (xl)
+- [x] GPT2-4B (4b)
+- [ ] GPT2-6B (6b)
### BLOOM
+
- [x] [BLOOM-560m](https://huggingface.co/bigscience/bloom-560m)
- [x] [BLOOM-1b1](https://huggingface.co/bigscience/bloom-1b1)
- [x] [BLOOM-3b](https://huggingface.co/bigscience/bloom-3b)
@@ -212,6 +317,7 @@ The examples are demos for the whole training process.You need to change the hyp
- [ ] [BLOOM-175b](https://huggingface.co/bigscience/bloom)
### OPT
+
- [x] [OPT-125M](https://huggingface.co/facebook/opt-125m)
- [x] [OPT-350M](https://huggingface.co/facebook/opt-350m)
- [x] [OPT-1.3B](https://huggingface.co/facebook/opt-1.3b)
@@ -221,10 +327,11 @@ The examples are demos for the whole training process.You need to change the hyp
- [ ] [OPT-30B](https://huggingface.co/facebook/opt-30b)
### [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)
-- [x] LLaMA-7B
-- [x] LLaMA-13B
-- [ ] LLaMA-33B
-- [ ] LLaMA-65B
+
+- [x] LLaMA-7B
+- [x] LLaMA-13B
+- [ ] LLaMA-33B
+- [ ] LLaMA-65B
## Add your own models
@@ -237,12 +344,12 @@ if it is supported in huggingface [transformers](https://github.com/huggingface/
or you can build your own model by yourself.
### Actor model
-```
+
+```python
from ..base import Actor
from transformers.models.coati import CoatiModel
class CoatiActor(Actor):
-
def __init__(self,
pretrained: Optional[str] = None,
checkpoint: bool = False,
@@ -257,7 +364,8 @@ class CoatiActor(Actor):
```
### Reward model
-```
+
+```python
from ..base import RewardModel
from transformers.models.coati import CoatiModel
@@ -280,12 +388,11 @@ class CoatiRM(RewardModel):
### Critic model
-```
+```python
from ..base import Critic
from transformers.models.coati import CoatiModel
class CoatiCritic(Critic):
-
def __init__(self,
pretrained: Optional[str] = None,
checkpoint: bool = False,
diff --git a/applications/Chat/examples/community/README.md b/applications/Chat/examples/community/README.md
index c9c645032288f5beccd6b45f2d53f6ee5677c4e7..e14ac1767fc12aca65e0570827f960d6bfc0294b 100644
--- a/applications/Chat/examples/community/README.md
+++ b/applications/Chat/examples/community/README.md
@@ -1,5 +1,9 @@
+:warning: **This content may be outdated since the major update of Colossal Chat. We will update this content soon.**
+
# Community Examples
+
---
+
We are thrilled to announce the latest updates to ColossalChat, an open-source solution for cloning ChatGPT with a complete RLHF (Reinforcement Learning with Human Feedback) pipeline.
As Colossal-AI undergoes major updates, we are actively maintaining ColossalChat to stay aligned with the project's progress. With the introduction of community-driven examples, we aim to create a collaborative platform for developers to contribute exotic features built on top of ColossalChat.
@@ -14,11 +18,12 @@ For more information about community pipelines, please have a look at this [issu
Community examples consist of both inference and training examples that have been added by the community. Please have a look at the following table to get an overview of all community examples. Click on the Code Example to get a copy-and-paste ready code example that you can try out. If a community example doesn't work as expected, please open an issue and ping the author on it.
-| Example | Description | Code Example | Colab | Author |
-|:---------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------:|
-| Peft | Adding Peft support for SFT and Prompts model training | [Huggingface Peft](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/community/peft) | - | [YY Lin](https://github.com/yynil) |
-| Train prompts on Ray | A Ray based implementation of Train prompts example | [Huggingface Peft](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/community/ray) | - | [MisterLin1995](https://github.com/MisterLin1995) |
-|...|...|...|...|...|
+| Example | Description | Code Example | Colab | Author |
+| :------------------- | :----------------------------------------------------- | :-------------------------------------------------------------------------------------------------------------- | :---- | ------------------------------------------------: |
+| Peft | Adding Peft support for SFT and Prompts model training | [Huggingface Peft](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/community/peft) | - | [YY Lin](https://github.com/yynil) |
+| Train prompts on Ray | A Ray based implementation of Train prompts example | [Training On Ray](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/community/ray) | - | [MisterLin1995](https://github.com/MisterLin1995) |
+| ... | ... | ... | ... | ... |
### How to get involved
+
To join our community-driven initiative, please visit the [ColossalChat GitHub repository](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples), review the provided information, and explore the codebase. To contribute, create a new issue outlining your proposed feature or enhancement, and our team will review and provide feedback. We look forward to collaborating with you on this exciting project!
diff --git a/applications/Chat/examples/community/peft/README.md b/applications/Chat/examples/community/peft/README.md
index eabb56fd8294ea89ad7632bd73ffece758e29e56..ada3a16296afab4facd94d92ab6ef0765f572b28 100644
--- a/applications/Chat/examples/community/peft/README.md
+++ b/applications/Chat/examples/community/peft/README.md
@@ -1,3 +1,5 @@
+:warning: **This content may be outdated since the major update of Colossal Chat. We will update this content soon.**
+
# Add Peft support for SFT and Prompts model training
The original implementation just adopts loralib and merges the layers into the final model. The huggingface peft is a better LoRA implementation and can be easily trained and distributed.
@@ -5,7 +7,9 @@ The original implementation just adopts the loralib and merges the layers into t
Since the reward model is relatively small, I just keep it as the original one. I suggest training the full model to get a proper reward/critic model.
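
For orientation, a minimal sketch of wrapping a causal LM with peft LoRA. The base model name and hyper-parameters are placeholders, not the values used by these scripts:

```python
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")  # placeholder model
config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=32, lora_dropout=0.1)
model = get_peft_model(base, config)
model.print_trainable_parameters()  # only the LoRA matrices are trainable
```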
# Preliminary installation
+
Since the current PyPI peft package (0.2) has some bugs, please install the peft package from source.
+
```
git clone https://github.com/huggingface/peft
cd peft
@@ -13,12 +17,14 @@ pip install .
```
# Usage
+
For SFT training, just call train_peft_sft.py
-Its arguments are almost identical to train_sft.py instead adding a new eval_dataset if you have a eval_dataset file. The data file is just a plain datafile, please check the format in the easy_dataset.py.
+Its arguments are almost identical to train_sft.py, except that it adds a new eval_dataset argument if you have an eval_dataset file. The data file is just a plain text file; please check the format in easy_dataset.py.
For stage-3 rlhf training, call train_peft_prompts.py.
-Its arguments are almost idential to train_prompts.py. The only difference is that I use text files to indicate the prompt and pretrained data file. The models are included in easy_models.py. Currently only bloom models are tested, but technically gpt2/opt/llama should be supported.
+Its arguments are almost identical to train_prompts.py. The only difference is that I use text files to indicate the prompt and pretrained data file. The models are included in easy_models.py. Currently only bloom models are tested, but technically gpt2/opt/llama should be supported.
# Dataformat
+
Please refer to the formats in test_sft.txt, test_prompts.txt and test_pretrained.txt.
diff --git a/applications/Chat/examples/community/peft/easy_dataset.py b/applications/Chat/examples/community/peft/easy_dataset.py
index 24ea4f0a86186c4789062ab148c83c916ec86edc..d4b17689e9cb38dea77573cc3fb9d4912c47d8bd 100644
--- a/applications/Chat/examples/community/peft/easy_dataset.py
+++ b/applications/Chat/examples/community/peft/easy_dataset.py
@@ -3,7 +3,6 @@ import json
from typing import Dict, Sequence
import torch
-from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
@@ -20,7 +19,8 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: i
padding="longest",
max_length=max_length,
truncation=True,
- ) for text in strings
+ )
+ for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
@@ -48,18 +48,17 @@ def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTo
class EasySupervisedDataset(Dataset):
-
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
super(EasySupervisedDataset, self).__init__()
with open(data_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
- #split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:"
+ # split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:"
sources, targets = [], []
for line in all_lines:
if "回答:" in line:
sep_index = line.index("回答:")
- sources.append(line[:sep_index + 3])
- targets.append(line[sep_index + 3:] + tokenizer.eos_token)
+ sources.append(line[: sep_index + 3])
+ targets.append(line[sep_index + 3 :] + tokenizer.eos_token)
else:
sources.append(line)
targets.append("" + tokenizer.eos_token)
@@ -83,15 +82,17 @@ class EasySupervisedDataset(Dataset):
class EasyPromptsDataset(Dataset):
-
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
super(EasyPromptsDataset, self).__init__()
with open(data_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
- all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines]
+ all_lines = [line if "回答:" not in line else line[: line.index("回答:") + 3] for line in all_lines]
self.prompts = [
- tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length',
- truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0)
+ tokenizer(line, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True)[
+ "input_ids"
+ ]
+ .to(torch.cuda.current_device())
+ .squeeze(0)
for line in tqdm(all_lines)
]
self.data_file = data_file
@@ -110,7 +111,6 @@ class EasyPromptsDataset(Dataset):
class EasyRewardDataset(Dataset):
-
def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
super(EasyRewardDataset, self).__init__()
self.chosen = []
@@ -120,44 +120,42 @@ class EasyRewardDataset(Dataset):
else:
self.end_token = special_token
print(self.end_token)
- #read all lines in the train_file to a list
+ # read all lines in the train_file to a list
with open(train_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
for line in tqdm(all_lines):
data = json.loads(line)
- prompt = "提问:" + data['prompt'] + " 回答:"
-
- chosen = prompt + data['chosen'] + self.end_token
- chosen_token = tokenizer(chosen,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.chosen.append({
- "input_ids": chosen_token['input_ids'],
- "attention_mask": chosen_token['attention_mask']
- })
-
- reject = prompt + data['rejected'] + self.end_token
- reject_token = tokenizer(reject,
- max_length=max_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt")
- self.reject.append({
- "input_ids": reject_token['input_ids'],
- "attention_mask": reject_token['attention_mask']
- })
+ prompt = "提问:" + data["prompt"] + " 回答:"
+
+ chosen = prompt + data["chosen"] + self.end_token
+ chosen_token = tokenizer(
+ chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ self.chosen.append(
+ {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
+ )
+
+ reject = prompt + data["rejected"] + self.end_token
+ reject_token = tokenizer(
+ reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ self.reject.append(
+ {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
+ )
def __len__(self):
length = len(self.chosen)
return length
def __getitem__(self, idx):
- return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
- "input_ids"], self.reject[idx]["attention_mask"]
-
- #python representation of the object and the string representation of the object
+ return (
+ self.chosen[idx]["input_ids"],
+ self.chosen[idx]["attention_mask"],
+ self.reject[idx]["input_ids"],
+ self.reject[idx]["attention_mask"],
+ )
+
+ # python representation of the object and the string representation of the object
def __repr__(self):
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
@@ -165,30 +163,29 @@ class EasyRewardDataset(Dataset):
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
-'''
+"""
Easy SFT just accepts a text file which can be read line by line. However, the dataset will group texts together up to max_length so the LLM will learn the texts' meaning better.
If individual lines are not related, just set is_group_texts to False.
-'''
+"""
class EasySFTDataset(Dataset):
-
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
super().__init__()
- #read the data_file line by line
+ # read the data_file line by line
with open(data_file, "r", encoding="UTF-8") as f:
- #encode the text data line by line and put raw python list input_ids only to raw_input_ids list
+ # encode the text data line by line and put raw python list input_ids only to raw_input_ids list
raw_input_ids = []
for line in f:
encoded_ids = tokenizer.encode(line)
- #if the encoded_ids is longer than max_length, then split it into several parts
+ # if the encoded_ids is longer than max_length, then split it into several parts
if len(encoded_ids) > max_length:
for i in range(0, len(encoded_ids), max_length):
- raw_input_ids.append(encoded_ids[i:i + max_length])
+ raw_input_ids.append(encoded_ids[i : i + max_length])
else:
raw_input_ids.append(encoded_ids)
- grouped_inpup_ids = []
+ grouped_input_ids = []
current_input_ids = []
attention_mask = []
if tokenizer.pad_token_id is None:
@@ -196,30 +193,33 @@ class EasySFTDataset(Dataset):
if is_group_texts:
for input_ids in raw_input_ids:
if len(current_input_ids) + len(input_ids) > max_length:
- #pad the current_input_ids to max_length with tokenizer.pad_token_id
+ # pad the current_input_ids to max_length with tokenizer.pad_token_id
padded_length = max_length - len(current_input_ids)
current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
- grouped_inpup_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
+ grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
attention_mask.append(
- torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
+ torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
+ )
current_input_ids = []
else:
current_input_ids.extend(input_ids)
if len(current_input_ids) > 0:
padded_length = max_length - len(current_input_ids)
current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
- grouped_inpup_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
+ grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
attention_mask.append(
- torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
+ torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
+ )
else:
- #just append the raw_input_ids to max_length
+ # just append the raw_input_ids to max_length
for input_ids in raw_input_ids:
padded_length = max_length - len(input_ids)
input_ids.extend([tokenizer.pad_token_id] * padded_length)
attention_mask.append(
- torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
- grouped_inpup_ids.append(torch.tensor(input_ids, dtype=torch.long))
- self.input_ids = grouped_inpup_ids
+ torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
+ )
+ grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
+ self.input_ids = grouped_input_ids
self.labels = copy.deepcopy(self.input_ids)
self.file_name = data_file
self.attention_mask = attention_mask
@@ -227,14 +227,14 @@ class EasySFTDataset(Dataset):
def __len__(self):
return len(self.input_ids)
- #get item from dataset
+ # get item from dataset
def __getitem__(self, idx):
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
- #generate the dataset description to be printed by print in python
+ # generate the dataset description to be printed by print in python
def __repr__(self):
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
- #generate the dataset description to be printed by print in python
+ # generate the dataset description to be printed by print in python
def __str__(self):
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
diff --git a/applications/Chat/examples/community/peft/easy_models.py b/applications/Chat/examples/community/peft/easy_models.py
index fe294868159dde227cae9757da41ee71b5778a25..db629e50ed94f6ba13b3d70b8646053c0dcc94b3 100644
--- a/applications/Chat/examples/community/peft/easy_models.py
+++ b/applications/Chat/examples/community/peft/easy_models.py
@@ -4,7 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from coati.models.generation import generate
-from coati.models.utils import log_probs_from_logits, masked_mean
+from coati.models.utils import log_probs_from_logits
from peft import PeftModel
from torch.nn.modules import Module
from transformers import BloomConfig, BloomForCausalLM
@@ -24,38 +24,33 @@ class Actor(Module):
@torch.no_grad()
def generate(
- self,
- input_ids: torch.Tensor,
- return_action_mask: bool = True,
- **kwargs
+ self, input_ids: torch.Tensor, return_action_mask: bool = True, **kwargs
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
sequences = generate(self.model, input_ids, **kwargs)
attention_mask = None
- pad_token_id = kwargs.get('pad_token_id', None)
+ pad_token_id = kwargs.get("pad_token_id", None)
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
if not return_action_mask:
return sequences, attention_mask, None
input_len = input_ids.size(1)
- eos_token_id = kwargs.get('eos_token_id', None)
+ eos_token_id = kwargs.get("eos_token_id", None)
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
- action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
+ action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
- return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
+ return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len) :]
- def forward(self,
- sequences: torch.LongTensor,
- num_actions: int,
- attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
- """Returns action log probs
- """
+ def forward(
+ self, sequences: torch.LongTensor, num_actions: int, attention_mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """Returns action log probs"""
output = self.model(sequences, attention_mask=attention_mask)
- logits = output['logits']
+ logits = output["logits"]
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]
@@ -75,11 +70,13 @@ class BLOOMActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""
- def __init__(self,
- pretrained: str = None,
- config: Optional[BloomConfig] = None,
- checkpoint: bool = False,
- lora_path: str = None) -> None:
+ def __init__(
+ self,
+ pretrained: str = None,
+ config: Optional[BloomConfig] = None,
+ checkpoint: bool = False,
+ lora_path: str = None,
+ ) -> None:
if pretrained is not None:
model = BloomForCausalLM.from_pretrained(pretrained)
elif config is not None:
diff --git a/applications/Chat/examples/community/peft/train_peft_prompts.py b/applications/Chat/examples/community/peft/train_peft_prompts.py
index 0e277021e917a7da8ce5c3df18b0165e22b9a0e2..99a024f1463c34c8c77b9bea062ce12326be0ace 100644
--- a/applications/Chat/examples/community/peft/train_peft_prompts.py
+++ b/applications/Chat/examples/community/peft/train_peft_prompts.py
@@ -1,19 +1,16 @@
import argparse
-import pandas as pd
import torch
import torch.distributed as dist
-from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
+from coati.dataset import DataCollatorForSupervisedDataset
from coati.models.bloom import BLOOMRM, BLOOMCritic
-from coati.models.gpt import GPTRM, GPTActor, GPTCritic
-from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
-from coati.models.opt import OPTRM, OPTActor, OPTCritic
+from coati.models.gpt import GPTRM, GPTCritic
+from coati.models.llama import LlamaCritic, LlamaRM
+from coati.models.opt import OPTRM, OPTCritic
from coati.trainer import PPOTrainer
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
-from coati.utils import prepare_llama_tokenizer_and_embedding
+from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from easy_dataset import EasyPromptsDataset, EasySupervisedDataset
from easy_models import BLOOMActor
-from peft import PeftModel
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
@@ -24,26 +21,24 @@ from colossalai.nn.optimizer import HybridAdam
def main(args):
# configure strategy
- if args.strategy == 'naive':
- strategy = NaiveStrategy()
- elif args.strategy == 'ddp':
+ if args.strategy == "ddp":
strategy = DDPStrategy()
- elif args.strategy == 'colossalai_gemini':
- strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
- elif args.strategy == 'colossalai_zero2':
- strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
+ elif args.strategy == "colossalai_gemini":
+ strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
+ elif args.strategy == "colossalai_zero2":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
if args.rm_path is not None:
- state_dict = torch.load(args.rm_path, map_location='cpu')
+ state_dict = torch.load(args.rm_path, map_location="cpu")
# configure model
- if args.model == 'bloom':
+ if args.model == "bloom":
# initial_model = BLOOMActor(pretrained=args.pretrain)
- print('Using peft lora to load Bloom model as inital_model')
+ print("Using peft lora to load Bloom model as initial_model")
initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
- print('Using peft lora to load Bloom model as initial_model (Done)')
+ print("Using peft lora to load Bloom model as initial_model (Done)")
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
@@ -52,59 +47,59 @@ def main(args):
else:
rm_model_name = args.rm_model
- if rm_model_name == 'gpt2':
+ if rm_model_name == "gpt2":
reward_model = GPTRM(pretrained=args.rm_pretrain)
- elif rm_model_name == 'bloom':
+ elif rm_model_name == "bloom":
print("load bloom reward model ", args.rm_pretrain)
reward_model = BLOOMRM(pretrained=args.rm_pretrain)
- elif rm_model_name == 'opt':
+ elif rm_model_name == "opt":
reward_model = OPTRM(pretrained=args.rm_pretrain)
- elif rm_model_name == 'llama':
+ elif rm_model_name == "llama":
reward_model = LlamaRM(pretrained=args.rm_pretrain)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None:
- print('Loading reward model from', args.rm_path)
+ print("Loading reward model from", args.rm_path)
reward_model.load_state_dict(state_dict)
- if args.strategy != 'colossalai_gemini':
+ if args.strategy != "colossalai_gemini":
initial_model.to(torch.float16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device())
with strategy.model_init_context():
- if args.model == 'bloom':
+ if args.model == "bloom":
# actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
- print('Using peft lora to load Bloom model as Actor')
+ print("Using peft lora to load Bloom model as Actor")
actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
- print('Using peft lora to load Bloom model as Actor (Done)')
+ print("Using peft lora to load Bloom model as Actor (Done)")
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
- if rm_model_name == 'gpt2':
+ if rm_model_name == "gpt2":
critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
- elif rm_model_name == 'bloom':
+ elif rm_model_name == "bloom":
print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True)
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
print("load bloom critic (Done) ")
- elif rm_model_name == 'opt':
+ elif rm_model_name == "opt":
critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
- elif rm_model_name == 'llama':
+ elif rm_model_name == "llama":
critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None:
- print('Loading reward model from', args.rm_path)
+ print("Loading reward model from", args.rm_path)
critic.load_state_dict(state_dict)
del state_dict
- if args.strategy != 'colossalai_gemini':
+ if args.strategy != "colossalai_gemini":
critic.to(torch.float16).to(torch.cuda.current_device())
actor.to(torch.float16).to(torch.cuda.current_device())
# configure optimizer
- if args.strategy.startswith('colossalai'):
+ if args.strategy.startswith("colossalai"):
actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
else:
@@ -112,23 +107,22 @@ def main(args):
critic_optim = Adam(critic.parameters(), lr=1e-7)
# configure tokenizer
- if args.model == 'gpt2':
+ if args.model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain)
- elif args.model == 'bloom':
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain)
- elif args.model == 'opt':
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain)
- elif args.model == 'llama':
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
- tokenizer.eos_token = '<\s>'
+        tokenizer.eos_token = "</s>"
+ tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
- if args.model == 'llama':
- tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, actor)
- else:
- tokenizer.pad_token = tokenizer.eos_token
-
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
prompt_dataset = EasyPromptsDataset(args.prompt_path, tokenizer)
@@ -136,26 +130,27 @@ def main(args):
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
else:
prompt_sampler = None
- prompt_dataloader = DataLoader(prompt_dataset,
- shuffle=(prompt_sampler is None),
- sampler=prompt_sampler,
- batch_size=args.train_batch_size)
+ prompt_dataloader = DataLoader(
+ prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.train_batch_size
+ )
pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer)
if dist.is_initialized() and dist.get_world_size() > 1:
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
else:
pretrain_sampler = None
- pretrain_dataloader = DataLoader(pretrain_dataset,
- shuffle=(pretrain_sampler is None),
- sampler=pretrain_sampler,
- batch_size=args.ptx_batch_size,
- collate_fn=data_collator)
+ pretrain_dataloader = DataLoader(
+ pretrain_dataset,
+ shuffle=(pretrain_sampler is None),
+ sampler=pretrain_sampler,
+ batch_size=args.ptx_batch_size,
+ collate_fn=data_collator,
+ )
def tokenize_fn(texts):
# MUST pad to max length to ensure inputs of all ranks have the same length
# Different lengths may lead to a hang when using gemini, as the number of generation steps would differ
- batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
+ batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
@@ -171,7 +166,6 @@ def main(args):
critic_optim,
kl_coef=args.kl_coef,
ptx_coef=args.ptx_coef,
- max_epochs=args.max_epochs,
train_batch_size=args.train_batch_size,
experience_batch_size=args.experience_batch_size,
tokenizer=tokenize_fn,
@@ -183,46 +177,46 @@ def main(args):
eos_token_id=tokenizer.eos_token_id,
)
- trainer.fit(prompt_dataloader=prompt_dataloader,
- pretrain_dataloader=pretrain_dataloader,
- num_episodes=args.num_episodes,
- max_timesteps=args.max_timesteps,
- update_timesteps=args.update_timesteps)
+ trainer.fit(
+ prompt_dataloader=prompt_dataloader,
+ pretrain_dataloader=pretrain_dataloader,
+ num_episodes=args.num_episodes,
+ num_update_steps=args.num_update_steps,
+ num_collect_steps=args.num_collect_steps,
+ )
# save model checkpoint after fitting
trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
- strategy.save_optimizer(actor_optim,
- 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
+ strategy.save_optimizer(
+ actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--prompt_path', type=str, default=None, help='path to the prompt dataset')
- parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
- parser.add_argument('--strategy',
- choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='naive',
- help='strategy to use')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--sft_lora_path', type=str, default=None)
- parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
- parser.add_argument('--rm_path', type=str, default=None)
- parser.add_argument('--rm_pretrain', type=str, default=None)
- parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--num_episodes', type=int, default=10)
- parser.add_argument('--max_timesteps', type=int, default=10)
- parser.add_argument('--update_timesteps', type=int, default=10)
- parser.add_argument('--max_epochs', type=int, default=5)
- parser.add_argument('--train_batch_size', type=int, default=2)
- parser.add_argument('--ptx_batch_size', type=int, default=1)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument('--kl_coef', type=float, default=0.1)
- parser.add_argument('--ptx_coef', type=float, default=0.9)
+ parser.add_argument("--prompt_path", type=str, default=None, help="path to the prompt dataset")
+ parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
+ parser.add_argument(
+ "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp", help="strategy to use"
+ )
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--sft_lora_path", type=str, default=None)
+ parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--rm_path", type=str, default=None)
+ parser.add_argument("--rm_pretrain", type=str, default=None)
+ parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
+ parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+ parser.add_argument("--num_episodes", type=int, default=10)
+ parser.add_argument("--num_collect_steps", type=int, default=10)
+ parser.add_argument("--num_update_steps", type=int, default=5)
+ parser.add_argument("--train_batch_size", type=int, default=2)
+ parser.add_argument("--ptx_batch_size", type=int, default=1)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument("--kl_coef", type=float, default=0.1)
+ parser.add_argument("--ptx_coef", type=float, default=0.9)
args = parser.parse_args()
main(args)
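
The `fit()` call above swaps the old `max_timesteps`/`update_timesteps` pair for `num_collect_steps`/`num_update_steps`. A minimal sketch of the collect/update loop those arguments imply, with hypothetical `make_experience` and `training_step` callables standing in for the coati internals:

```python
from typing import Callable, Iterable, List


def fit_sketch(
    prompt_batches: Iterable,
    make_experience: Callable,  # hypothetical: rollout + reward scoring
    training_step: Callable,    # hypothetical: one PPO update on an experience
    num_episodes: int,
    num_collect_steps: int,
    num_update_steps: int,
) -> None:
    prompt_iter = iter(prompt_batches)
    for _ in range(num_episodes):
        buffer: List = []
        # collect phase: roll out the actor num_collect_steps times
        for _ in range(num_collect_steps):
            buffer.append(make_experience(next(prompt_iter)))
        # update phase: num_update_steps optimization passes over the buffer
        for _ in range(num_update_steps):
            for experience in buffer:
                training_step(experience)
```
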
diff --git a/applications/Chat/examples/community/peft/train_peft_sft.py b/applications/Chat/examples/community/peft/train_peft_sft.py
index 9bd0ebc12a836d6c699c90ce30538a74c65858f6..3bbef7208374c8ba831bb13fb6430ccb509889ab 100644
--- a/applications/Chat/examples/community/peft/train_peft_sft.py
+++ b/applications/Chat/examples/community/peft/train_peft_sft.py
@@ -1,19 +1,10 @@
import argparse
import os
-import loralib as lora
import torch
import torch.distributed as dist
-from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
-from coati.models.base import RewardModel
-from coati.models.bloom import BLOOMLM
-from coati.models.gpt import GPTLM
-from coati.models.llama import LlamaLM
-from coati.models.opt import OPTLM
from coati.trainer import SFTTrainer
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
-from coati.utils import prepare_llama_tokenizer_and_embedding
-from datasets import load_dataset
+from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from easy_dataset import EasyDataset
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from torch.optim import Adam
@@ -30,80 +21,76 @@ from colossalai.tensor import ColoParameter
def train(args):
# configure strategy
- if args.strategy == 'naive':
- strategy = NaiveStrategy()
- elif args.strategy == 'ddp':
+ if args.strategy == "ddp":
strategy = DDPStrategy()
- elif args.strategy == 'colossalai_gemini':
- strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
- elif args.strategy == 'colossalai_zero2':
- strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
+ elif args.strategy == "colossalai_gemini":
+ strategy = GeminiStrategy(placement_policy="static")
+ elif args.strategy == "colossalai_zero2":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model
with strategy.model_init_context():
- print('Warning: currently only bloom is tested, gpt2,llama and opt are not tested')
+ print("Warning: currently only bloom is tested, gpt2,llama and opt are not tested")
model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device())
- #if the args.save_path exists and args.save_path+'/adapter_config.json' exists, we'll load the adapter_config.json
- if os.path.exists(args.save_path) and os.path.exists(args.save_path+'/adapter_config.json') \
- and os.path.exists(args.save_path+'/adapter_model.bin'):
+    # if args.save_path exists and contains both adapter_config.json and adapter_model.bin, load the saved adapter
+ if (
+ os.path.exists(args.save_path)
+ and os.path.exists(args.save_path + "/adapter_config.json")
+ and os.path.exists(args.save_path + "/adapter_model.bin")
+ ):
print("loading from saved peft model ", args.save_path)
model = PeftModel.from_pretrained(model, args.save_path)
else:
- #we'll use peft lora library to do the lora
+        # use the peft library to apply LoRA
lora_rank = args.lora_rank if args.lora_rank > 0 else 32
- #config lora with rank of lora_rank
- lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
- inference_mode=False,
- r=lora_rank,
- lora_alpha=32,
- lora_dropout=0.1)
+        # configure LoRA with rank lora_rank
+ lora_config = LoraConfig(
+ task_type=TaskType.CAUSAL_LM, inference_mode=False, r=lora_rank, lora_alpha=32, lora_dropout=0.1
+ )
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# configure tokenizer
- if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ if args.model == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'bloom':
- tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
+ elif args.model == "bloom":
+ tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'opt':
+ elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
- elif args.model == 'llama':
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == "llama":
tokenizer = AutoTokenizer.from_pretrained(
args.pretrain,
padding_side="right",
use_fast=False,
)
- tokenizer.eos_token = '<\s>'
+        tokenizer.eos_token = "</s>"
+ tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
- tokenizer.pad_token = tokenizer.eos_token
- if args.model == 'llama':
- tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)
-
- if args.strategy == 'colossalai_gemini':
- # this is a hack to deal with the resized embedding
- # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatiblity
- for name, param in model.named_parameters():
- if not isinstance(param, ColoParameter):
- sub_module_name = '.'.join(name.split('.')[:-1])
- weight_name = name.split('.')[-1]
- sub_module = model.get_submodule(sub_module_name)
- setattr(sub_module, weight_name, ColoParameter(param))
- else:
- tokenizer.pad_token = tokenizer.eos_token
+
+ if args.model == "llama" and args.strategy == "colossalai_gemini":
+ # this is a hack to deal with the resized embedding
+ # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
+ for name, param in model.named_parameters():
+ if not isinstance(param, ColoParameter):
+ sub_module_name = ".".join(name.split(".")[:-1])
+ weight_name = name.split(".")[-1]
+ sub_module = model.get_submodule(sub_module_name)
+ setattr(sub_module, weight_name, ColoParameter(param))
# configure optimizer
- if args.strategy.startswith('colossalai'):
+ if args.strategy.startswith("colossalai"):
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
else:
optim = Adam(model.parameters(), lr=args.lr)
logger = get_dist_logger()
- logger.set_level('WARNING')
+ logger.set_level("WARNING")
# configure dataset
law_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
@@ -114,47 +101,57 @@ def train(args):
eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
data_collator = default_collate
if dist.is_initialized() and dist.get_world_size() > 1:
- train_sampler = DistributedSampler(train_dataset,
- shuffle=True,
- seed=42,
- drop_last=True,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
+ train_sampler = DistributedSampler(
+ train_dataset,
+ shuffle=True,
+ seed=42,
+ drop_last=True,
+ rank=dist.get_rank(),
+ num_replicas=dist.get_world_size(),
+ )
if eval_dataset is not None:
- eval_sampler = DistributedSampler(eval_dataset,
- shuffle=False,
- seed=42,
- drop_last=False,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
+ eval_sampler = DistributedSampler(
+ eval_dataset,
+ shuffle=False,
+ seed=42,
+ drop_last=False,
+ rank=dist.get_rank(),
+ num_replicas=dist.get_world_size(),
+ )
else:
train_sampler = None
eval_sampler = None
- train_dataloader = DataLoader(train_dataset,
- shuffle=(train_sampler is None),
- sampler=train_sampler,
- batch_size=args.batch_size,
- collate_fn=data_collator,
- pin_memory=True)
+ train_dataloader = DataLoader(
+ train_dataset,
+ shuffle=(train_sampler is None),
+ sampler=train_sampler,
+ batch_size=args.batch_size,
+ collate_fn=data_collator,
+ pin_memory=True,
+ )
if eval_dataset is not None:
- eval_dataloader = DataLoader(eval_dataset,
- shuffle=(eval_sampler is None),
- sampler=eval_sampler,
- batch_size=args.batch_size,
- collate_fn=data_collator,
- pin_memory=True)
+ eval_dataloader = DataLoader(
+ eval_dataset,
+ shuffle=(eval_sampler is None),
+ sampler=eval_sampler,
+ batch_size=args.batch_size,
+ collate_fn=data_collator,
+ pin_memory=True,
+ )
else:
eval_dataloader = None
- trainer = SFTTrainer(model=model,
- strategy=strategy,
- optim=optim,
- train_dataloader=train_dataloader,
- eval_dataloader=eval_dataloader,
- batch_size=args.batch_size,
- max_epochs=args.max_epochs,
- accumulation_steps=args.accumulation_steps)
+ trainer = SFTTrainer(
+ model=model,
+ strategy=strategy,
+ optim=optim,
+ train_dataloader=train_dataloader,
+ eval_dataloader=eval_dataloader,
+ batch_size=args.batch_size,
+ max_epochs=args.max_epochs,
+ accumulation_steps=args.accumulation_steps,
+ )
trainer.fit(logger=logger, log_interval=args.log_interval)
@@ -162,29 +159,27 @@ def train(args):
trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
- strategy.save_optimizer(trainer.optimizer,
- 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
+ strategy.save_optimizer(
+ trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--strategy',
- choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='naive')
- parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--dataset', type=str, default=None)
- parser.add_argument('--eval_dataset', type=str, default=None)
- parser.add_argument('--save_path', type=str, default='output')
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--max_epochs', type=int, default=3)
- parser.add_argument('--batch_size', type=int, default=4)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
- parser.add_argument('--lr', type=float, default=5e-6)
- parser.add_argument('--accumulation_steps', type=int, default=8)
- parser.add_argument('--enable_peft_lora', action='store_true', default=False)
- parser.add_argument("--is_short_text", action='store_true', default=False)
+ parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp")
+ parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--dataset", type=str, default=None)
+ parser.add_argument("--eval_dataset", type=str, default=None)
+ parser.add_argument("--save_path", type=str, default="output")
+ parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+ parser.add_argument("--max_epochs", type=int, default=3)
+ parser.add_argument("--batch_size", type=int, default=4)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log")
+ parser.add_argument("--lr", type=float, default=5e-6)
+ parser.add_argument("--accumulation_steps", type=int, default=8)
+ parser.add_argument("--enable_peft_lora", action="store_true", default=False)
+ parser.add_argument("--is_short_text", action="store_true", default=False)
args = parser.parse_args()
train(args)
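
The branch above relies on PEFT writing `adapter_config.json` and `adapter_model.bin` into the save directory. A minimal sketch of that save/reload round trip, assuming a small GPT-2 base model and a hypothetical `output` directory:

```python
import os

from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

save_path = "output"  # hypothetical adapter directory
model = AutoModelForCausalLM.from_pretrained("gpt2")

if os.path.exists(os.path.join(save_path, "adapter_config.json")):
    # resume: wrap the base model with the previously saved adapter
    model = PeftModel.from_pretrained(model, save_path)
else:
    # fresh run: attach new LoRA matrices to the base model
    lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=32, lora_alpha=32, lora_dropout=0.1)
    model = get_peft_model(model, lora_config)

model.print_trainable_parameters()  # only the LoRA weights require grad
# model.save_pretrained(save_path) would write the two adapter files checked above
```
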
diff --git a/applications/Chat/examples/community/ray/README.md b/applications/Chat/examples/community/ray/README.md
index 64360bd73ddc8d5627594332f85a7f701c26e1b9..a679a58336a7a661a9cecb21da1c5e8c58b4c80e 100644
--- a/applications/Chat/examples/community/ray/README.md
+++ b/applications/Chat/examples/community/ray/README.md
@@ -1,17 +1,31 @@
+:warning: **This content may be outdated due to the major update of Colossal Chat. We will update it soon.**
+
# ColossalAI on Ray
+
## Abstract
+
This is an experimental effort to run ColossalAI Chat training on Ray.
+
## How to use?
+
### 1. Setup Ray clusters
+
Please follow the official [Ray cluster setup instructions](https://docs.ray.io/en/latest/cluster/getting-started.html) to set up a cluster with GPU support. Record the cluster's API server endpoint; it should look something like http://your.head.node.address:8265
+
### 2. Clone repo
+
Clone this project:
+
```shell
git clone https://github.com/hpcaitech/ColossalAI.git
```
+
### 3. Submit the ray job
+
```shell
python applications/Chat/examples/community/ray/ray_job_script.py http://your.head.node.address:8265
```
+
### 4. View your job on the Ray Dashboard
+
Open your Ray cluster dashboard at http://your.head.node.address:8265 to view your submitted training job.
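
The dashboard is not the only way to watch the job; it can also be polled from Python. A minimal sketch, assuming Ray's job-submission API and the same endpoint placeholder as above:

```python
import time

from ray.job_submission import JobStatus, JobSubmissionClient

client = JobSubmissionClient("http://your.head.node.address:8265")
# ordering of list_jobs() is not guaranteed; in practice, keep the
# submission_id returned by submit_job() instead
job_id = client.list_jobs()[0].submission_id
terminal = {JobStatus.SUCCEEDED, JobStatus.FAILED, JobStatus.STOPPED}
while client.get_job_status(job_id) not in terminal:
    time.sleep(10)
print(client.get_job_logs(job_id))
```
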
diff --git a/applications/Chat/examples/community/ray/ray_job_script.py b/applications/Chat/examples/community/ray/ray_job_script.py
index 53f304d379fec54d82d3775552863e73f8dfcbc4..e8a1175a9c32ae2f87b4a3425dd4d13182e41f52 100644
--- a/applications/Chat/examples/community/ray/ray_job_script.py
+++ b/applications/Chat/examples/community/ray/ray_job_script.py
@@ -6,16 +6,25 @@ from ray.job_submission import JobSubmissionClient
def main(api_server_endpoint="http://127.0.0.1:8265"):
client = JobSubmissionClient(api_server_endpoint)
client.submit_job(
- entrypoint=
- "python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2 --prompt_csv_url https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv",
+ entrypoint="python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2 --prompt_csv_url https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv",
runtime_env={
- "working_dir":
- "applications/Chat",
+ "working_dir": "applications/Chat",
"pip": [
- "torch==1.13.1", "transformers>=4.20.1", "datasets", "loralib", "colossalai>=0.2.4", "langchain",
- "tokenizers", "fastapi", "sse_starlette", "wandb", "sentencepiece", "gpustat"
- ]
- })
+ "torch==1.13.1",
+ "transformers>=4.20.1",
+ "datasets",
+ "loralib",
+ "colossalai>=0.2.4",
+ "langchain",
+ "tokenizers",
+ "fastapi",
+ "sse_starlette",
+ "wandb",
+ "sentencepiece",
+ "gpustat",
+ ],
+ },
+ )
if __name__ == "__main__":
diff --git a/applications/Chat/examples/community/ray/train_prompts_on_ray.py b/applications/Chat/examples/community/ray/train_prompts_on_ray.py
index 289330ad841516a8bbc17ce80cf022d21f30643a..8abd83a8b249e52cbd14a01110b511d05a17a629 100644
--- a/applications/Chat/examples/community/ray/train_prompts_on_ray.py
+++ b/applications/Chat/examples/community/ray/train_prompts_on_ray.py
@@ -15,7 +15,7 @@ from coati.models.lora import LoRAModule
from coati.models.loss import PolicyLoss, ValueLoss
from coati.models.opt import OPTActor, OPTCritic
from coati.models.utils import compute_reward
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
+from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from torch.optim import Adam
@@ -26,9 +26,14 @@ from colossalai.nn.optimizer import HybridAdam
class ExperienceCompositionRefs:
-
- def __init__(self, sequences_attention_mask_action_mask_ref: ray.ObjectRef, action_log_probs_ref: ray.ObjectRef,
- base_action_log_probs_ref: ray.ObjectRef, value_ref: ray.ObjectRef, r_ref: ray.ObjectRef) -> None:
+ def __init__(
+ self,
+ sequences_attention_mask_action_mask_ref: ray.ObjectRef,
+ action_log_probs_ref: ray.ObjectRef,
+ base_action_log_probs_ref: ray.ObjectRef,
+ value_ref: ray.ObjectRef,
+ r_ref: ray.ObjectRef,
+ ) -> None:
self.sequences_attention_mask_action_mask_ref = sequences_attention_mask_action_mask_ref
self.action_log_probs_ref = action_log_probs_ref
self.base_action_log_probs_ref = base_action_log_probs_ref
@@ -37,14 +42,14 @@ class ExperienceCompositionRefs:
class ExperienceMaker:
-
def __init__(self, kl_coef) -> None:
self.kl_coef = kl_coef
@torch.no_grad()
def make_experience(self, experiment_computation_refs: ExperienceCompositionRefs):
sequences, attention_mask, action_mask = ray.get(
- experiment_computation_refs.sequences_attention_mask_action_mask_ref)
+ experiment_computation_refs.sequences_attention_mask_action_mask_ref
+ )
action_log_probs = ray.get(experiment_computation_refs.action_log_probs_ref)
base_action_log_probs = ray.get(experiment_computation_refs.base_action_log_probs_ref)
r = ray.get(experiment_computation_refs.r_ref)
@@ -58,11 +63,10 @@ class ExperienceMaker:
class DistributedTorchRayActor:
-
def __init__(self, world_size, rank, local_rank, master_addr, master_port):
- logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
- level=logging.INFO,
- datefmt='%Y-%m-%d %H:%M:%S')
+ logging.basicConfig(
+ format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
+ )
self._model = None
self._world_size = world_size
self._rank = rank
@@ -82,7 +86,7 @@ class DistributedTorchRayActor:
@staticmethod
def _get_free_port():
with socket.socket() as sock:
- sock.bind(('', 0))
+ sock.bind(("", 0))
return sock.getsockname()[1]
def get_master_addr_port(self):
@@ -90,7 +94,6 @@ class DistributedTorchRayActor:
class BasePPORole(DistributedTorchRayActor):
-
def add_experience_maker(self, kl_coef: float = 0.1):
self._experience_maker = ExperienceMaker(kl_coef)
@@ -99,19 +102,17 @@ class BasePPORole(DistributedTorchRayActor):
def _init_strategy(self, strategy: str):
# configure strategy
- if strategy == 'naive':
- self._strategy = NaiveStrategy()
- elif strategy == 'ddp':
+ if strategy == "ddp":
self._strategy = DDPStrategy()
- elif strategy == 'colossalai_gemini':
- self._strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
- elif strategy == 'colossalai_zero2':
- self._strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
+ elif strategy == "colossalai_gemini":
+ self._strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
+ elif strategy == "colossalai_zero2":
+ self._strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
def _init_optimizer(self):
- if isinstance(self._strategy, ColossalAIStrategy):
+ if isinstance(self._strategy, (GeminiStrategy, LowLevelZeroStrategy)):
self._optimizer = HybridAdam(self._model.parameters(), lr=5e-6)
else:
self._optimizer = Adam(self._model.parameters(), lr=5e-6)
@@ -126,11 +127,9 @@ class BasePPORole(DistributedTorchRayActor):
def _load_model_from_pretrained(self, model_class: Type[LoRAModule], pretrain: str):
raise NotImplementedError()
- def init_model_from_pretrained(self,
- strategy: str,
- model_class: Type[LoRAModule],
- pretrain: str,
- has_optimizer=False):
+ def init_model_from_pretrained(
+ self, strategy: str, model_class: Type[LoRAModule], pretrain: str, has_optimizer=False
+ ):
self._init_strategy(strategy)
self._load_model_from_pretrained(model_class, pretrain)
self._prepare_model_with_strategy(has_optimizer)
@@ -140,7 +139,6 @@ class BasePPORole(DistributedTorchRayActor):
class TrainablePPORole(BasePPORole):
-
def _load_model_from_pretrained(self, model_class, pretrain):
with self._strategy.model_init_context():
self._model = model_class(pretrain).to(torch.cuda.current_device())
@@ -163,38 +161,39 @@ class TrainablePPORole(BasePPORole):
@ray.remote(num_gpus=1)
class RayPPOActor(TrainablePPORole):
-
def set_loss_function(self, eps_clip: float):
self._actor_loss_fn = PolicyLoss(eps_clip)
def load_tokenizer_from_pretrained(self, model_type: str, pretrained):
- if model_type == 'gpt2':
+ if model_type == "gpt2":
self._model_tokenizer = GPT2Tokenizer.from_pretrained(pretrained)
self._model_tokenizer.pad_token = self._model_tokenizer.eos_token
- elif model_type == 'bloom':
+ elif model_type == "bloom":
self._model_tokenizer = BloomTokenizerFast.from_pretrained(pretrained)
self._model_tokenizer.pad_token = self._model_tokenizer.eos_token
- elif model_type == 'opt':
+ elif model_type == "opt":
self._model_tokenizer = AutoTokenizer.from_pretrained(pretrained)
else:
raise ValueError(f'Unsupported model "{model_type}"')
# Set tokenize function for sequence generation
def _text_input_tokenize_fn(texts):
- batch = self._model_tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
+ batch = self._model_tokenizer(texts, return_tensors="pt", max_length=96, padding=True, truncation=True)
return {k: v.cuda() for k, v in batch.items()}
self._sample_tokenize_function = _text_input_tokenize_fn
def setup_generate_kwargs(self, generate_kwargs: dict):
from coati.trainer.ppo import _set_default_generate_kwargs
+
self._generate_kwargs = _set_default_generate_kwargs(self._strategy, generate_kwargs, self._model)
- self._generate_kwargs['pad_token_id'] = self._model_tokenizer.pad_token_id
- self._generate_kwargs['eos_token_id'] = self._model_tokenizer.eos_token_id
+ self._generate_kwargs["pad_token_id"] = self._model_tokenizer.pad_token_id
+ self._generate_kwargs["eos_token_id"] = self._model_tokenizer.eos_token_id
def load_csv_prompt_file_from_url_to_sampler(self, prompt_url):
import pandas as pd
- prompts = pd.read_csv(prompt_url)['prompt']
+
+ prompts = pd.read_csv(prompt_url)["prompt"]
self._sampler = self._strategy.setup_sampler(prompts)
def _generate(self, input_ids, **generate_kwargs):
@@ -216,10 +215,9 @@ class RayPPOActor(TrainablePPORole):
def _training_step(self, experience):
num_actions = experience.action_mask.size(1)
action_log_probs = self._model(experience.sequences, num_actions, attention_mask=experience.attention_mask)
- actor_loss = self._actor_loss_fn(action_log_probs,
- experience.action_log_probs,
- experience.advantages,
- action_mask=experience.action_mask)
+ actor_loss = self._actor_loss_fn(
+ action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
+ )
self._strategy.backward(actor_loss, self._model, self._optimizer)
self._strategy.optimizer_step(self._optimizer)
self._optimizer.zero_grad()
@@ -231,17 +229,18 @@ class RayPPOActor(TrainablePPORole):
self._strategy.save_model(self._model, save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
if should_save_optimizer:
- self._strategy.save_optimizer(self._optimizer,
- 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
+ self._strategy.save_optimizer(
+ self._optimizer,
+ "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()),
+ only_rank0=False,
+ )
def generate_answer(self, prompt, max_length=30, num_return_sequences=5):
- encoded_input = self._model_tokenizer(prompt, return_tensors='pt')
+ encoded_input = self._model_tokenizer(prompt, return_tensors="pt")
input_ids = {k: v.cuda() for k, v in encoded_input.items()}
- sequence, _ = self._model.generate(**input_ids,
- max_length=max_length,
- return_action_mask=False,
- num_return_sequences=num_return_sequences)
+ sequence, _ = self._model.generate(
+ **input_ids, max_length=max_length, return_action_mask=False, num_return_sequences=num_return_sequences
+ )
token_list = list(sequence.data[0])
output = " ".join([self._model_tokenizer.decode(token) for token in token_list])
return output
@@ -249,18 +248,16 @@ class RayPPOActor(TrainablePPORole):
@ray.remote(num_gpus=1)
class RayPPOCritic(TrainablePPORole):
-
def set_loss_function(self, value_clip: float):
self._critic_loss_fn = ValueLoss(value_clip)
def _training_step(self, experience):
- values = self._model(experience.sequences,
- action_mask=experience.action_mask,
- attention_mask=experience.attention_mask)
- critic_loss = self._critic_loss_fn(values,
- experience.values,
- experience.reward,
- action_mask=experience.action_mask)
+ values = self._model(
+ experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
+ )
+ critic_loss = self._critic_loss_fn(
+ values, experience.values, experience.reward, action_mask=experience.action_mask
+ )
self._strategy.backward(critic_loss, self._model, self._optimizer)
self._strategy.optimizer_step(self._optimizer)
self._optimizer.zero_grad()
@@ -274,12 +271,12 @@ class RayPPOCritic(TrainablePPORole):
@ray.remote(num_gpus=1)
class RayPPORewardModel(BasePPORole):
-
def _load_model_from_pretrained(self, model_class, pretrain):
with self._strategy.model_init_context():
critic = model_class(pretrained=pretrain).to(torch.cuda.current_device())
- self._model = RewardModel(deepcopy(critic.model),
- deepcopy(critic.value_head)).to(torch.cuda.current_device())
+ self._model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(
+ torch.cuda.current_device()
+ )
@torch.no_grad()
def calculate_r(self, sequence_attention_action_mask):
@@ -289,7 +286,6 @@ class RayPPORewardModel(BasePPORole):
@ray.remote(num_gpus=1)
class RayPPOInitialModel(BasePPORole):
-
def _load_model_from_pretrained(self, model_class, pretrain):
with self._strategy.model_init_context():
self._model = model_class(pretrain).to(torch.cuda.current_device())
@@ -302,8 +298,8 @@ class RayPPOInitialModel(BasePPORole):
class PPORayActorGroup:
"""
- A group of ray actors
- Functions start with 'async' should return list of object refs
+    A group of Ray actors.
+    Functions starting with 'async' should return a list of object refs.
"""
def __init__(self, num_nodes, num_gpus_per_node, ray_actor_type: Type[BasePPORole]) -> None:
@@ -321,8 +317,9 @@ class PPORayActorGroup:
pg = placement_group(bundles, strategy="STRICT_SPREAD")
ray.get(pg.ready())
if pg:
- master_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy(
- placement_group=pg, placement_group_bundle_index=0)).remote(world_size, 0, 0, None, None)
+ master_actor = self.ray_actor_type.options(
+ scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg, placement_group_bundle_index=0)
+ ).remote(world_size, 0, 0, None, None)
else:
master_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, 0, 0, None, None)
self._actor_handlers = [master_actor]
@@ -333,16 +330,20 @@ class PPORayActorGroup:
for rank in range(1, world_size):
local_rank = rank % self._num_gpus_per_node
if pg:
- worker_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy(
- placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node)).remote(
- world_size, rank, local_rank, master_addr, master_port)
+ worker_actor = self.ray_actor_type.options(
+ scheduling_strategy=PlacementGroupSchedulingStrategy(
+ placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node
+ )
+ ).remote(world_size, rank, local_rank, master_addr, master_port)
else:
- worker_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, rank, local_rank,
- master_addr, master_port)
+ worker_actor = self.ray_actor_type.options(num_gpus=1).remote(
+ world_size, rank, local_rank, master_addr, master_port
+ )
self._actor_handlers.append(worker_actor)
- def async_init_model_from_pretrained(self, strategy: str, model_class: Type[LoRAModule], pretrain: str,
- has_optimizer: bool):
+ def async_init_model_from_pretrained(
+ self, strategy: str, model_class: Type[LoRAModule], pretrain: str, has_optimizer: bool
+ ):
return [
actor.init_model_from_pretrained.remote(strategy, model_class, pretrain, has_optimizer)
for actor in self._actor_handlers
@@ -350,7 +351,6 @@ class PPORayActorGroup:
class TrainableModelRayActorGroup(PPORayActorGroup):
-
def async_learn_on_experiences(self, experience_refs):
num_actors = len(self._actor_handlers)
learn_result_refs = []
@@ -361,7 +361,6 @@ class TrainableModelRayActorGroup(PPORayActorGroup):
class PPOActorRayActorGroup(TrainableModelRayActorGroup):
-
def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPOActor)
@@ -383,7 +382,8 @@ class PPOActorRayActorGroup(TrainableModelRayActorGroup):
action_log_probs_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)):
action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_action_log_probs.remote(
- sequences_attention_mask_action_mask_refs[i])
+ sequences_attention_mask_action_mask_refs[i]
+ )
action_log_probs_refs.append(action_log_probs_ref)
return action_log_probs_refs
@@ -395,7 +395,6 @@ class PPOActorRayActorGroup(TrainableModelRayActorGroup):
class PPOCriticRayActorGroup(TrainableModelRayActorGroup):
-
def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPOCritic)
@@ -404,7 +403,8 @@ class PPOCriticRayActorGroup(TrainableModelRayActorGroup):
value_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)):
value_ref = self._actor_handlers[i % num_actors].calculate_value.remote(
- sequences_attention_mask_action_mask_refs[i])
+ sequences_attention_mask_action_mask_refs[i]
+ )
value_refs.append(value_ref)
return value_refs
@@ -413,7 +413,6 @@ class PPOCriticRayActorGroup(TrainableModelRayActorGroup):
class PPOInitialRayActorGroup(PPORayActorGroup):
-
def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPOInitialModel)
@@ -422,13 +421,13 @@ class PPOInitialRayActorGroup(PPORayActorGroup):
base_action_log_probs_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)):
base_action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_base_action_log_probs.remote(
- sequences_attention_mask_action_mask_refs[i])
+ sequences_attention_mask_action_mask_refs[i]
+ )
base_action_log_probs_refs.append(base_action_log_probs_ref)
return base_action_log_probs_refs
class PPORewardRayActorGroup(PPORayActorGroup):
-
def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPORewardModel)
@@ -437,20 +436,21 @@ class PPORewardRayActorGroup(PPORayActorGroup):
r_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)):
r_ref = self._actor_handlers[i % num_actors].calculate_r.remote(
- sequences_attention_mask_action_mask_refs[i])
+ sequences_attention_mask_action_mask_refs[i]
+ )
r_refs.append(r_ref)
return r_refs
def main(args):
- logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
- level=logging.INFO,
- datefmt='%Y-%m-%d %H:%M:%S')
- if args.model == 'gpt2':
+ logging.basicConfig(
+ format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
+ )
+ if args.model == "gpt2":
actor_model_class, critic_model_class = GPTActor, GPTCritic
- elif args.model == 'bloom':
+ elif args.model == "bloom":
actor_model_class, critic_model_class = BLOOMActor, BLOOMCritic
- elif args.model == 'opt':
+ elif args.model == "opt":
actor_model_class, critic_model_class = OPTActor, OPTCritic
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -464,13 +464,14 @@ def main(args):
logging.info("Actors created")
# Prepare model for training
- generate_kwargs = {'max_length': 128, 'do_sample': True, 'temperature': 1.0, 'top_k': 50}
+ generate_kwargs = {"max_length": 128, "do_sample": True, "temperature": 1.0, "top_k": 50}
ray.get(
- actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True) +
- critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True) +
- initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False) +
- reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False) +
- actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs))
+ actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True)
+ + critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True)
+ + initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False)
+ + reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False)
+ + actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs)
+ )
logging.info("Models prepared for training")
# Prepare models for training
@@ -485,8 +486,12 @@ def main(args):
# Start training
logging.info("Training start")
# Set all models to eval and add experience maker
- all_ray_actors = actor_group._actor_handlers + critic_group._actor_handlers + \
- initial_group._actor_handlers + reward_group._actor_handlers
+ all_ray_actors = (
+ actor_group._actor_handlers
+ + critic_group._actor_handlers
+ + initial_group._actor_handlers
+ + reward_group._actor_handlers
+ )
num_ray_actors = len(all_ray_actors)
ray.get([ray_actor.eval.remote() for ray_actor in all_ray_actors])
ray.get([ray_actor.add_experience_maker.remote() for ray_actor in all_ray_actors])
@@ -499,18 +504,28 @@ def main(args):
time += 1
# Experience queueing stage
sequences_attention_mask_action_mask_refs = actor_group.async_sample_prompts_and_make_sequence(
- experience_batch_size)
+ experience_batch_size
+ )
base_action_log_probs_refs = initial_group.async_calculate_base_action_log_probs(
- sequences_attention_mask_action_mask_refs)
+ sequences_attention_mask_action_mask_refs
+ )
values_refs = critic_group.async_calculate_value(sequences_attention_mask_action_mask_refs)
r_refs = reward_group.async_calculate_r(sequences_attention_mask_action_mask_refs)
action_log_probs_refs = actor_group.async_calculate_action_log_probs(
- sequences_attention_mask_action_mask_refs)
- experience_composition_refs.extend([
- ExperienceCompositionRefs(sequences_attention_mask_action_mask_refs[i], action_log_probs_refs[i],
- base_action_log_probs_refs[i], values_refs[i], r_refs[i])
- for i in range(len(sequences_attention_mask_action_mask_refs))
- ])
+ sequences_attention_mask_action_mask_refs
+ )
+ experience_composition_refs.extend(
+ [
+ ExperienceCompositionRefs(
+ sequences_attention_mask_action_mask_refs[i],
+ action_log_probs_refs[i],
+ base_action_log_probs_refs[i],
+ values_refs[i],
+ r_refs[i],
+ )
+ for i in range(len(sequences_attention_mask_action_mask_refs))
+ ]
+ )
# Learning stage
if time % update_timesteps == 0:
experience_refs = []
@@ -521,8 +536,9 @@ def main(args):
experience_refs.append(selected_ray_actor.make_experience.remote(exp_composition_ref))
# backward
ray.get(
- actor_group.async_learn_on_experiences(experience_refs) +
- critic_group.async_learn_on_experiences(experience_refs))
+ actor_group.async_learn_on_experiences(experience_refs)
+ + critic_group.async_learn_on_experiences(experience_refs)
+ )
# clear refs queue
experience_composition_refs.clear()
logging.info("Training finished")
@@ -530,26 +546,24 @@ def main(args):
actor_group.save_checkpoint(args.save_path, args.need_optim_ckpt)
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--prompt_csv_url', type=str)
- parser.add_argument('--strategy',
- choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='naive')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
- parser.add_argument('--pretrain', type=str, default='gpt2')
- parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--num_episodes', type=int, default=10)
- parser.add_argument('--max_timesteps', type=int, default=10)
- parser.add_argument('--update_timesteps', type=int, default=10)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--num_actor_nodes', type=int, help='num of nodes to use to host actor model', default=1)
- parser.add_argument('--num_critic_nodes', type=int, help='num of nodes to use to host critic model', default=1)
- parser.add_argument('--num_initial_nodes', type=int, help='num of nodes to use to host initial model', default=1)
- parser.add_argument('--num_reward_nodes', type=int, help='num of nodes to use to host reward model', default=1)
- parser.add_argument('--num_gpus_per_node', type=int, help='num of gpus on a ray node', default=1)
+ parser.add_argument("--prompt_csv_url", type=str)
+ parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp")
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt"])
+ parser.add_argument("--pretrain", type=str, default="gpt2")
+ parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts.pt")
+ parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+ parser.add_argument("--num_episodes", type=int, default=10)
+ parser.add_argument("--max_timesteps", type=int, default=10)
+ parser.add_argument("--update_timesteps", type=int, default=10)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--num_actor_nodes", type=int, help="num of nodes to use to host actor model", default=1)
+ parser.add_argument("--num_critic_nodes", type=int, help="num of nodes to use to host critic model", default=1)
+ parser.add_argument("--num_initial_nodes", type=int, help="num of nodes to use to host initial model", default=1)
+ parser.add_argument("--num_reward_nodes", type=int, help="num of nodes to use to host reward model", default=1)
+ parser.add_argument("--num_gpus_per_node", type=int, help="num of gpus on a ray node", default=1)
args = parser.parse_args()
ray.init()
main(args)
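
The `async_*` methods above pass `ray.ObjectRef`s between actor groups, so sequences, values, and rewards move through Ray's object store rather than through the driver. A minimal sketch of that fan-out/fan-in pattern, with toy payloads standing in for the PPO tensors:

```python
import ray

ray.init()


@ray.remote
class Producer:
    def make(self, n):
        return list(range(n))  # stands in for (sequences, masks, ...)


@ray.remote
class Consumer:
    def score(self, batch):
        # `batch` arrives already dereferenced: Ray resolves top-level
        # ObjectRef arguments before invoking the method
        return sum(batch)


producer, consumer = Producer.remote(), Consumer.remote()
refs = [producer.make.remote(n) for n in range(4)]
# hand the refs straight to the consumer; the payloads travel through the
# object store between the two actors, never through the driver
print(ray.get([consumer.score.remote(r) for r in refs]))  # [0, 0, 1, 3]
```
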
diff --git a/applications/Chat/examples/download_model.py b/applications/Chat/examples/download_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec3482b5f7895aa4e06b0924880dadeb3f0413e7
--- /dev/null
+++ b/applications/Chat/examples/download_model.py
@@ -0,0 +1,79 @@
+import argparse
+import dataclasses
+import os
+from typing import List
+
+import tqdm
+from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
+from coati.models.gpt import GPTRM, GPTActor, GPTCritic
+from coati.models.opt import OPTRM, OPTActor, OPTCritic
+from huggingface_hub import hf_hub_download, snapshot_download
+from transformers import AutoConfig, AutoTokenizer, BloomConfig, BloomTokenizerFast, GPT2Config, GPT2Tokenizer
+
+
+@dataclasses.dataclass
+class HFRepoFiles:
+ repo_id: str
+ files: List[str]
+
+ def download(self, dir_path: str):
+ for file in self.files:
+ file_path = hf_hub_download(self.repo_id, file, local_dir=dir_path)
+
+ def download_all(self):
+ snapshot_download(self.repo_id)
+
+
+def test_init(model: str, dir_path: str):
+ if model == "gpt2":
+ config = GPT2Config.from_pretrained(dir_path)
+ actor = GPTActor(config=config)
+ critic = GPTCritic(config=config)
+ reward_model = GPTRM(config=config)
+ GPT2Tokenizer.from_pretrained(dir_path)
+ elif model == "bloom":
+ config = BloomConfig.from_pretrained(dir_path)
+ actor = BLOOMActor(config=config)
+ critic = BLOOMCritic(config=config)
+ reward_model = BLOOMRM(config=config)
+ BloomTokenizerFast.from_pretrained(dir_path)
+ elif model == "opt":
+ config = AutoConfig.from_pretrained(dir_path)
+ actor = OPTActor(config=config)
+ critic = OPTCritic(config=config)
+ reward_model = OPTRM(config=config)
+ AutoTokenizer.from_pretrained(dir_path)
+ else:
+ raise NotImplementedError(f"Model {model} not implemented")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-dir", type=str, default="test_models")
+ parser.add_argument("--config-only", default=False, action="store_true")
+ args = parser.parse_args()
+
+ if os.path.exists(args.model_dir):
+ print(f"[INFO]: {args.model_dir} already exists")
+ exit(0)
+
+ repo_list = {
+ "gpt2": HFRepoFiles(repo_id="gpt2", files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"]),
+ "bloom": HFRepoFiles(
+ repo_id="bigscience/bloom-560m", files=["config.json", "tokenizer.json", "tokenizer_config.json"]
+ ),
+ "opt": HFRepoFiles(
+ repo_id="facebook/opt-350m", files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"]
+ ),
+ }
+
+ os.mkdir(args.model_dir)
+ for model_name in tqdm.tqdm(repo_list):
+ dir_path = os.path.join(args.model_dir, model_name)
+ if args.config_only:
+ os.mkdir(dir_path)
+ repo_list[model_name].download(dir_path)
+ else:
+ repo_list[model_name].download_all()
+ test_init(model_name, dir_path)
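
In the `--config-only` branch above, each repo file is fetched individually. For reference, a minimal single-file download with `huggingface_hub` (paths illustrative):

```python
from huggingface_hub import hf_hub_download

# fetch just gpt2's config.json into a local directory
path = hf_hub_download(repo_id="gpt2", filename="config.json", local_dir="test_models/gpt2")
print(path)  # .../test_models/gpt2/config.json
```
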
diff --git a/applications/Chat/examples/example_data_reformat.py b/applications/Chat/examples/example_data_reformat.py
deleted file mode 100644
index dc83b29b525b16ff322126b63042eb32f32ed21e..0000000000000000000000000000000000000000
--- a/applications/Chat/examples/example_data_reformat.py
+++ /dev/null
@@ -1,12 +0,0 @@
-jsonl_file = 'seed_prompts_xx.jsonl' # seed_prompts_en.jsonl or seed_prompts_ch.json from InstructionWild
-reformat_file = 'prompts_xx.jsonl' # reformat jsonl file used as Prompt dataset in Stage3
-
-data = ''
-with open(jsonl_file, 'r', encoding="utf-8") as f1:
- for jsonstr in f1.readlines():
- jsonstr = '\t' + jsonstr.strip('\n') + ',\n'
- data = data + jsonstr
- data = '[\n' + data + ']'
-
-with open(reformat_file, 'w') as f2:
- f2.write(data)
\ No newline at end of file
diff --git a/applications/Chat/examples/generate_conversation_dataset.py b/applications/Chat/examples/generate_conversation_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e03b2d54260b3a0dc20324b596df0f561b2a581
--- /dev/null
+++ b/applications/Chat/examples/generate_conversation_dataset.py
@@ -0,0 +1,82 @@
+import argparse
+import json
+
+from datasets import load_dataset
+
+
+def generate_alpaca():
+    # Any dataset with the same ("instruction", "input", "output") format as Alpaca can be converted into one-round conversations.
+ conversation_dataset = []
+ dataset = load_dataset("tatsu-lab/alpaca", split="train")
+
+ instructions = dataset["instruction"]
+ inputs = dataset["input"]
+ outputs = dataset["output"]
+
+ assert len(instructions) == len(inputs) == len(outputs)
+
+ for idx in range(len(instructions)):
+ human_utterance = instructions[idx] + "\n\n" + inputs[idx] if inputs[idx] else instructions[idx]
+ human = {"from": "human", "value": human_utterance}
+
+ gpt_utterance = outputs[idx]
+ gpt = {"from": "gpt", "value": gpt_utterance}
+
+ conversation = dict(type="instruction", language="English", dataset="Alpaca", conversations=[human, gpt])
+ conversation_dataset.append(conversation)
+
+ return conversation_dataset
+
+
+def generate_sharegpt():
+ # ShareGPT data requires less processing.
+ conversation_dataset = []
+ dataset = load_dataset(
+ "anon8231489123/ShareGPT_Vicuna_unfiltered",
+ data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
+ split="train",
+ )
+
+ conversations = dataset["conversations"]
+
+ for idx in range(len(conversations)):
+ for conv in conversations[idx]:
+            # We don't need the markdown and text fields.
+ del conv["markdown"]
+ del conv["text"]
+
+ conversation = dict(
+ type="conversation", language="Multilingual", dataset="ShareGPT", conversations=conversations[idx]
+ )
+ conversation_dataset.append(conversation)
+
+ return conversation_dataset
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default="All",
+ choices=["Alpaca", "ShareGPT", "All"],
+ help="which dataset to convert, All will combine Alpaca and ShareGPT",
+ )
+ parser.add_argument("--save_path", type=str, default="dataset.json", help="path to save the converted dataset")
+ args = parser.parse_args()
+
+ conversation_dataset = []
+
+ if args.dataset == "Alpaca":
+ conversation_dataset.extend(generate_alpaca())
+ elif args.dataset == "ShareGPT":
+ conversation_dataset.extend(generate_sharegpt())
+ else:
+ conversation_dataset.extend(generate_alpaca())
+ conversation_dataset.extend(generate_sharegpt())
+
+ for idx, sample in enumerate(conversation_dataset):
+ sample["id"] = idx + 1
+
+ with open(args.save_path, mode="w") as f:
+ json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False)
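
For reference, a single converted Alpaca record has the following shape (field values illustrative); ShareGPT records differ only in the metadata fields and the number of turns:

```python
record = {
    "type": "instruction",
    "language": "English",
    "dataset": "Alpaca",
    "conversations": [
        {"from": "human", "value": "Give three tips for staying healthy."},
        {"from": "gpt", "value": "1. Eat a balanced diet. 2. Exercise regularly. 3. Get enough sleep."},
    ],
    "id": 1,  # assigned after all records are collected
}
```
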
diff --git a/applications/Chat/examples/generate_prompt_dataset.py b/applications/Chat/examples/generate_prompt_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4eec6feae505887331d88c204b5ad514438a6a31
--- /dev/null
+++ b/applications/Chat/examples/generate_prompt_dataset.py
@@ -0,0 +1,27 @@
+import argparse
+import json
+import random
+
+random.seed(42)
+
+
+def sample(args):
+ with open(args.dataset_path, mode="r") as f:
+ dataset_list = json.load(f)
+
+ sampled_dataset = [
+ {"instruction": sample["instruction"], "id": idx}
+ for idx, sample in enumerate(random.sample(dataset_list, args.sample_size))
+ ]
+
+ with open(args.save_path, mode="w") as f:
+ json.dump(sampled_dataset, f, indent=4, default=str, ensure_ascii=False)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--dataset_path", type=str, default=None, required=True, help="path to the pretrain dataset")
+ parser.add_argument("--save_path", type=str, default="prompt.json", help="path to save the prompt dataset")
+ parser.add_argument("--sample_size", type=int, default=16384, help="size of the prompt dataset")
+ args = parser.parse_args()
+ sample(args)
diff --git a/applications/Chat/examples/inference.py b/applications/Chat/examples/inference.py
index ae59d91c1822825924e87401ee5f5064cf41fbb6..62e06bf7b3bb75640fc4875ae984ef3b248faacd 100644
--- a/applications/Chat/examples/inference.py
+++ b/applications/Chat/examples/inference.py
@@ -2,63 +2,72 @@ import argparse
import torch
from coati.models.bloom import BLOOMActor
+from coati.models.generation import generate
from coati.models.gpt import GPTActor
+from coati.models.llama import LlamaActor
from coati.models.opt import OPTActor
-from coati.models.roberta import RoBERTaActor
-from transformers import AutoTokenizer, RobertaTokenizer
-from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
+from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
def eval(args):
# configure model
- if args.model == 'gpt2':
- actor = GPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
- elif args.model == 'bloom':
- actor = BLOOMActor(pretrained=args.pretrain).to(torch.cuda.current_device())
- elif args.model == 'opt':
- actor = OPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
- elif args.model == 'roberta':
- actor = RoBERTaActor(pretrained=args.pretrain).to(torch.cuda.current_device())
+ if args.model == "gpt2":
+ actor = GPTActor(pretrained=args.pretrain)
+ elif args.model == "bloom":
+ actor = BLOOMActor(pretrained=args.pretrain)
+ elif args.model == "opt":
+ actor = OPTActor(pretrained=args.pretrain)
+ elif args.model == "llama":
+ actor = LlamaActor(pretrained=args.pretrain)
else:
raise ValueError(f'Unsupported model "{args.model}"')
- state_dict = torch.load(args.model_path)
- actor.model.load_state_dict(state_dict)
+ actor.to(torch.cuda.current_device())
+ if args.model_path is not None:
+ state_dict = torch.load(args.model_path)
+ actor.load_state_dict(state_dict)
# configure tokenizer
- if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ if args.model == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'bloom':
- tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
+ elif args.model == "bloom":
+ tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'opt':
- tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
- elif args.model == 'roberta':
- tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+ elif args.model == "opt":
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == "llama":
+ tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+        tokenizer.eos_token = "</s>"
+ tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
actor.eval()
- input = args.input
- input_ids = tokenizer.encode(input, return_tensors='pt').to(torch.cuda.current_device())
- outputs = actor.generate(input_ids,
- max_length=args.max_length,
- do_sample=True,
- top_k=50,
- top_p=0.95,
- num_return_sequences=1)
+ tokenizer.padding_side = "left"
+ input_ids = tokenizer.encode(args.input, return_tensors="pt").to(torch.cuda.current_device())
+ outputs = generate(
+ actor,
+ input_ids,
+ tokenizer=tokenizer,
+ max_length=args.max_length,
+ do_sample=True,
+ top_k=50,
+ top_p=0.95,
+ num_return_sequences=1,
+ )
output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
- print(output)
+ print(f"[Output]: {''.join(output)}")
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'roberta'])
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
    # We suggest using a pretrained model from HuggingFace; use --pretrain to configure the model
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--model_path', type=str, default=None)
- parser.add_argument('--input', type=str, default='Question: How are you ? Answer:')
- parser.add_argument('--max_length', type=int, default=100)
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--model_path", type=str, default=None)
+ parser.add_argument("--input", type=str, default="Question: How are you ? Answer:")
+ parser.add_argument("--max_length", type=int, default=100)
args = parser.parse_args()
eval(args)
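
The new `tokenizer.padding_side = "left"` line matters for decoder-only models: generation continues from the last position of each row, so pad tokens must not sit between a prompt and its continuation. A small illustration with GPT-2:

```python
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

batch = tokenizer(["Hi", "A longer prompt"], return_tensors="pt", padding=True)
# with left padding, real tokens end at the same (last) position in every
# row, so generate() appends directly after each prompt
print(batch["input_ids"])
```
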
diff --git a/applications/Chat/examples/ray/1mmt_prompt.py b/applications/Chat/examples/ray/1mmt_prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..8de6219ec4e9fbdc471d0ac5f2a1c8f67efc41dc
--- /dev/null
+++ b/applications/Chat/examples/ray/1mmt_prompt.py
@@ -0,0 +1,181 @@
+import argparse
+import os
+import socket
+from functools import partial
+
+import pandas as pd
+import ray
+from coati.quant import llama_load_quant, low_resource_init
+from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
+from coati.ray.experience_maker_holder import ExperienceMakerHolder
+from coati.ray.utils import (
+ get_actor_from_args,
+ get_critic_from_args,
+ get_reward_model_from_args,
+ get_strategy_from_args,
+ get_tokenizer_from_args,
+)
+from torch.utils.data import DataLoader
+from transformers import AutoConfig
+from transformers.modeling_utils import no_init_weights
+
+
+def get_free_port():
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("", 0))
+ return s.getsockname()[1]
+
+
+def get_local_ip():
+ with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+ s.connect(("8.8.8.8", 80))
+ return s.getsockname()[0]
+
+
+def main(args):
+ master_addr = str(get_local_ip())
+ # trainer_env_info
+ trainer_port = str(get_free_port())
+ env_info_trainers = [
+ {
+ "local_rank": "0",
+ "rank": str(rank),
+ "world_size": str(args.num_trainers),
+ "master_port": trainer_port,
+ "master_addr": master_addr,
+ }
+ for rank in range(args.num_trainers)
+ ]
+
+ # maker_env_info
+ maker_port = str(get_free_port())
+ env_info_maker = {
+ "local_rank": "0",
+ "rank": "0",
+ "world_size": "1",
+ "master_port": maker_port,
+ "master_addr": master_addr,
+ }
+
+ # configure tokenizer
+ tokenizer = get_tokenizer_from_args(args.model)
+
+ def trainer_model_fn():
+ actor = get_actor_from_args(args.model, args.pretrain).half().cuda()
+ critic = get_critic_from_args(args.model, args.critic_pretrain).half().cuda()
+ return actor, critic
+
+ # configure Trainer
+ trainer_refs = [
+ DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
+ experience_maker_holder_name_list=["maker1"],
+ strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
+ model_fn=trainer_model_fn,
+ env_info=env_info_trainer,
+ train_batch_size=args.train_batch_size,
+ buffer_limit=16,
+ eval_performance=True,
+ debug=args.debug,
+ update_lora_weights=not (args.lora_rank == 0),
+ )
+ for i, env_info_trainer in enumerate(env_info_trainers)
+ ]
+
+ def model_fn():
+ actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
+ critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
+ reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
+ if args.initial_model_quant_ckpt is not None and args.model == "llama":
+ # quantize initial model
+ actor_cfg = AutoConfig.from_pretrained(args.pretrain)
+ with low_resource_init(), no_init_weights():
+ initial_model = get_actor_from_args(args.model, config=actor_cfg)
+ initial_model.model = (
+ llama_load_quant(
+ initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
+ )
+ .cuda()
+ .requires_grad_(False)
+ )
+ else:
+ initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
+ return actor, critic, reward_model, initial_model
+
+ # configure Experience Maker
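+    # NOTE: a single experience maker feeds every trainer, hence the file name
+    # 1mmt (1 maker, multiple trainers)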
+ experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
+ detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)],
+ strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
+ model_fn=model_fn,
+ env_info=env_info_maker,
+ experience_batch_size=args.experience_batch_size,
+ kl_coef=0.1,
+ debug=args.debug,
+ update_lora_weights=not (args.lora_rank == 0),
+ # sync_models_from_trainers=True,
+ # generation kwargs:
+ max_length=512,
+ do_sample=True,
+ temperature=1.0,
+ top_k=50,
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ eval_performance=True,
+ use_cache=True,
+ )
+
+    # uncomment the following lines if sync_models_from_trainers is True
+ # ray.get([
+ # trainer_ref.sync_models_to_remote_makers.remote()
+ # for trainer_ref in trainer_refs
+ # ])
+
+ wait_tasks = []
+
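+    # Each maker step yields experience_batch_size samples and the trainers jointly
+    # consume num_trainers * train_batch_size samples per step; with the defaults
+    # this gives 8 * 4 // (1 * 8) = 4 trainer steps.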
+ total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size)
+ for trainer_ref in trainer_refs:
+ wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
+
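+    # NOTE: each prompt batch is 4x the experience batch size, so a single
+    # dataloader fetch is assumed to cover several experience-making steps.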
+ dataset_size = args.experience_batch_size * 4
+
+ def build_dataloader():
+ def tokenize_fn(texts):
+ batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
+ return {k: v.cuda() for k, v in batch.items()}
+
+ dataset = pd.read_csv(args.prompt_path)["prompt"]
+ dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn)
+ return dataloader
+
+ wait_tasks.append(experience_holder_ref.workingloop.remote(build_dataloader, num_steps=args.experience_steps))
+
+ ray.get(wait_tasks)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--prompt_path", type=str, default=None)
+ parser.add_argument("--num_trainers", type=int, default=1)
+ parser.add_argument(
+ "--trainer_strategy",
+ choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
+ default="ddp",
+ )
+ parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--critic_pretrain", type=str, default=None)
+ parser.add_argument("--experience_steps", type=int, default=4)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--train_epochs", type=int, default=1)
+ parser.add_argument("--update_steps", type=int, default=2)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+
+ parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
+ parser.add_argument("--quant_bits", type=int, default=4)
+ parser.add_argument("--quant_group_size", type=int, default=128)
+ parser.add_argument("--debug", action="store_true")
+ args = parser.parse_args()
+ ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
+ main(args)
diff --git a/applications/Chat/examples/ray/mmmt_prompt.py b/applications/Chat/examples/ray/mmmt_prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c03a0468b020921872ff726dde7db9db9a828c7
--- /dev/null
+++ b/applications/Chat/examples/ray/mmmt_prompt.py
@@ -0,0 +1,201 @@
+import argparse
+import os
+import socket
+from functools import partial
+
+import pandas as pd
+import ray
+from coati.quant import llama_load_quant, low_resource_init
+from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
+from coati.ray.experience_maker_holder import ExperienceMakerHolder
+from coati.ray.utils import (
+ get_actor_from_args,
+ get_critic_from_args,
+ get_receivers_per_sender,
+ get_reward_model_from_args,
+ get_strategy_from_args,
+)
+from torch.utils.data import DataLoader
+from transformers import AutoConfig, AutoTokenizer
+from transformers.modeling_utils import no_init_weights
+
+
+def get_free_port():
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("", 0))
+ return s.getsockname()[1]
+
+
+def get_local_ip():
+ with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+ s.connect(("8.8.8.8", 80))
+ return s.getsockname()[0]
+
+
+def main(args):
+ master_addr = str(get_local_ip())
+ # trainer_env_info
+ trainer_port = str(get_free_port())
+ env_info_trainers = [
+ {
+ "local_rank": "0",
+ "rank": str(rank),
+ "world_size": str(args.num_trainers),
+ "master_port": trainer_port,
+ "master_addr": master_addr,
+ }
+ for rank in range(args.num_trainers)
+ ]
+
+ # maker_env_info
+ maker_port = str(get_free_port())
+ env_info_makers = [
+ {
+ "local_rank": "0",
+ "rank": str(rank),
+ "world_size": str(args.num_makers),
+ "master_port": maker_port,
+ "master_addr": master_addr,
+ }
+ for rank in range(args.num_makers)
+ ]
+
+ # configure tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
+ tokenizer.pad_token = tokenizer.eos_token
+
+ def model_fn():
+ actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
+        critic = get_critic_from_args(args.critic_model, args.critic_pretrain).requires_grad_(False).half().cuda()
+        reward_model = get_reward_model_from_args(args.critic_model, args.critic_pretrain).requires_grad_(False).half().cuda()
+ if args.initial_model_quant_ckpt is not None and args.model == "llama":
+ # quantize initial model
+ actor_cfg = AutoConfig.from_pretrained(args.pretrain)
+ with low_resource_init(), no_init_weights():
+ initial_model = get_actor_from_args(args.model, config=actor_cfg)
+ initial_model.model = (
+ llama_load_quant(
+ initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
+ )
+ .cuda()
+ .requires_grad_(False)
+ )
+ else:
+ initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
+ return actor, critic, reward_model, initial_model
+
+ # configure Experience Maker
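+    # NOTE: get_receivers_per_sender is assumed to partition the trainers among
+    # the makers, so that each maker pushes experience to a disjoint subset.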
+ experience_holder_refs = [
+ ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
+ detached_trainer_name_list=[
+ f"trainer{x}"
+ for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
+ ],
+ strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
+ model_fn=model_fn,
+ env_info=env_info_maker,
+ kl_coef=0.1,
+ debug=args.debug,
+ update_lora_weights=not (args.lora_rank == 0),
+ # sync_models_from_trainers=True,
+ # generation kwargs:
+ max_length=512,
+ do_sample=True,
+ temperature=1.0,
+ top_k=50,
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ eval_performance=True,
+ use_cache=True,
+ )
+ for i, env_info_maker in enumerate(env_info_makers)
+ ]
+
+ def trainer_model_fn():
+ actor = get_actor_from_args(args.model, args.pretrain, lora_rank=args.lora_rank).half().cuda()
+        critic = get_critic_from_args(args.critic_model, args.critic_pretrain, lora_rank=args.lora_rank).half().cuda()
+ return actor, critic
+
+ # configure Trainer
+ trainer_refs = [
+ DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
+ experience_maker_holder_name_list=[
+ f"maker{x}"
+ for x in get_receivers_per_sender(i, args.num_trainers, args.num_makers, allow_idle_sender=True)
+ ],
+ strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
+ model_fn=trainer_model_fn,
+ env_info=env_info_trainer,
+ train_batch_size=args.train_batch_size,
+ buffer_limit=16,
+ eval_performance=True,
+ debug=args.debug,
+ update_lora_weights=not (args.lora_rank == 0),
+ )
+ for i, env_info_trainer in enumerate(env_info_trainers)
+ ]
+
+ dataset_size = args.experience_batch_size * 4
+
+ def build_dataloader():
+ def tokenize_fn(texts):
+ batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
+ return {k: v.cuda() for k, v in batch.items()}
+
+ dataset = pd.read_csv(args.prompt_path)["prompt"]
+ dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn)
+ return dataloader
+
+    # uncomment the following lines if sync_models_from_trainers is True
+ # ray.get([
+ # trainer_ref.sync_models_to_remote_makers.remote()
+ # for trainer_ref in trainer_refs
+ # ])
+
+ wait_tasks = []
+
+ for experience_holder_ref in experience_holder_refs:
+ wait_tasks.append(experience_holder_ref.workingloop.remote(build_dataloader, num_steps=args.experience_steps))
+
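+    # Total samples produced by all makers, divided by the samples consumed per
+    # step across all trainers, gives the number of steps for each trainer.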
+ total_steps = (
+ args.experience_batch_size
+ * args.experience_steps
+ * args.num_makers
+ // (args.num_trainers * args.train_batch_size)
+ )
+ for trainer_ref in trainer_refs:
+ wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
+
+ ray.get(wait_tasks)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--prompt_path", type=str, default=None)
+ parser.add_argument("--num_makers", type=int, default=1)
+ parser.add_argument("--num_trainers", type=int, default=1)
+ parser.add_argument(
+ "--trainer_strategy",
+ choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
+ default="ddp",
+ )
+ parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--critic_pretrain", type=str, default=None)
+ parser.add_argument("--experience_steps", type=int, default=4)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--train_epochs", type=int, default=1)
+ parser.add_argument("--update_steps", type=int, default=2)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+
+ parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
+ parser.add_argument("--quant_bits", type=int, default=4)
+ parser.add_argument("--quant_group_size", type=int, default=128)
+ parser.add_argument("--debug", action="store_true")
+ args = parser.parse_args()
+
+ ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
+ main(args)
diff --git a/applications/Chat/examples/ray/requirements.txt b/applications/Chat/examples/ray/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e0275631807fc1c22cd3593cd1c48d29d1bbaff9
--- /dev/null
+++ b/applications/Chat/examples/ray/requirements.txt
@@ -0,0 +1 @@
+ray
diff --git a/applications/Chat/examples/ray/test_ci.sh b/applications/Chat/examples/ray/test_ci.sh
new file mode 100755
index 0000000000000000000000000000000000000000..895f7de0fea94c5a58081a52983f2dc393ff9780
--- /dev/null
+++ b/applications/Chat/examples/ray/test_ci.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+set -xe
+BASE=$(realpath "$(dirname "$0")")
+
+export RAY_NAMESPACE=admin
+export DATA=/data/scratch/chatgpt/prompts.csv
+
+# install requirements
+pip install -r ${BASE}/requirements.txt
+
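+# 2 makers + 2 trainers with one GPU each (4 GPUs total); the small OPT models keep this CI run light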
+python ${BASE}/mmmt_prompt.py --prompt_path "$DATA" --num_makers 2 --num_trainers 2 --trainer_strategy colossalai_gemini --model opt --critic_model opt --pretrain facebook/opt-350m --critic_pretrain facebook/opt-125m --experience_batch_size 4 --train_batch_size 2
diff --git a/applications/Chat/examples/requirements.txt b/applications/Chat/examples/requirements.txt
index 40e6edc7ea7303c516ededa9ecd360f0445f957d..5474dfa16b3ef8c88c2a8073163ec9d08e03bf86 100644
--- a/applications/Chat/examples/requirements.txt
+++ b/applications/Chat/examples/requirements.txt
@@ -1,2 +1,3 @@
pandas>=1.4.1
sentencepiece
+colossalai==0.3.3
diff --git a/applications/Chat/examples/test_ci.sh b/applications/Chat/examples/test_ci.sh
deleted file mode 100755
index 2b049163c8012f0d7954a805be41990ecab1d910..0000000000000000000000000000000000000000
--- a/applications/Chat/examples/test_ci.sh
+++ /dev/null
@@ -1,126 +0,0 @@
-#!/usr/bin/env bash
-
-set -xue
-
-if [ -z "$SFT_DATASET" ]; then
- echo "Please set \$SFT_DATASET to the path to sft dataset."
- exit 1
-fi
-
-if [ -z "$PROMPT_PATH" ]; then
- echo "Please set \$PROMPT_PATH to the path to prompts csv."
- exit 1
-fi
-
-if [ -z "$PRETRAIN_DATASET" ]; then
- echo "Please set \$PRETRAIN_DATASET to the path to alpaca data."
- exit 1
-fi
-
-BASE=$(realpath $(dirname $0))
-
-export OMP_NUM_THREADS=8
-
-# install requirements
-pip install -r ${BASE}/requirements.txt
-
-wandb init -m offline
-
-# train sft
-torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'bigscience/bloom-560m' \
- --model 'bloom' --strategy colossalai_zero2 --lora_rank 4\
- --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
- --save_path ${BASE}/output
-rm -rf ${BASE}/output
-
-torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
- --model 'gpt2' --strategy colossalai_zero2 \
- --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
- --save_path ${BASE}/output
-rm -rf ${BASE}/output
-
-torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \
- --model 'opt' --strategy colossalai_zero2 --lora_rank 4\
- --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
- --save_path ${BASE}/output
-rm -rf ${BASE}/output
-
-torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
- --model 'gpt2' --strategy ddp --lora_rank 4\
- --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
- --save_path ${BASE}/output
-
-#torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \
-# --model 'opt' --strategy naive \
-# --dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
-# --save_path ${BASE}/output
-
-rm -rf ${BASE}/output
-
-# train rm
-torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
- --pretrain 'facebook/opt-350m' --model 'opt' \
- --strategy colossalai_zero2 --loss_fn 'log_sig'\
- --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
- --test True --lora_rank 0 \
- --save_path ${BASE}/rm_ckpt_opt.pt
-
-torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
- --pretrain 'gpt2' --model 'gpt2' \
- --strategy colossalai_zero2 --loss_fn 'log_exp' \
- --dataset 'Dahoas/rm-static' \
- --test True --lora_rank 0 \
- --save_path ${BASE}/rm_ckpt_gpt.pt
-
-torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
- --pretrain 'gpt2' --model 'gpt2' \
- --strategy ddp --loss_fn 'log_exp' \
- --dataset 'Dahoas/rm-static' \
- --test True --lora_rank 4 \
- --save_path ${BASE}/rm_ckpt.pt
-rm -rf ${BASE}/rm_ckpt.pt
-
-torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
- --pretrain 'bigscience/bloom-560m' --model 'bloom' \
- --strategy colossalai_zero2 --loss_fn 'log_sig' \
- --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
- --test True --lora_rank 4 \
- --save_path ${BASE}/rm_ckpt.pt
-rm -rf ${BASE}/rm_ckpt.pt
-
-torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
- --pretrain 'microsoft/deberta-v3-large' --model 'deberta' \
- --strategy colossalai_zero2 --loss_fn 'log_sig' \
- --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
- --test True --lora_rank 4 \
- --save_path ${BASE}/rm_ckpt.pt
-rm -rf ${BASE}/rm_ckpt.pt
-
-torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
- --pretrain 'roberta-base' --model 'roberta' \
- --strategy colossalai_zero2 --loss_fn 'log_exp'\
- --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
- --test True --lora_rank 4 \
- --save_path ${BASE}/rm_ckpt.pt
-
-rm -rf ${BASE}/rm_ckpt.pt
-
-torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
- --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
- --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
- --pretrain 'facebook/opt-350m' --model opt \
- --rm_pretrain 'facebook/opt-350m' \
- --rm_path ${BASE}/rm_ckpt_opt.pt \
- --save_path ${BASE}/actor_checkpoint_prompts.pt
-rm -rf ${BASE}/rm_ckpt_opt.pt
-
-torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
- --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
- --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
- --pretrain 'gpt2' --model gpt2 \
- --rm_pretrain 'gpt2' \
- --rm_path ${BASE}/rm_ckpt_gpt.pt \
- --save_path ${BASE}/actor_checkpoint_prompts.pt
-rm -rf ${BASE}/rm_ckpt_gpt.pt
-
-rm -rf ${BASE}/actor_checkpoint_prompts.pt
diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py
index a584991cd34e00cfb9bf2c97a285e642b25268c5..8868e278d85e9fa2a5a5ce2c0a74b57d6103a1fd 100644
--- a/applications/Chat/examples/train_prompts.py
+++ b/applications/Chat/examples/train_prompts.py
@@ -1,169 +1,169 @@
import argparse
+import warnings
-import pandas as pd
import torch
import torch.distributed as dist
-from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
+from coati.dataset import PromptDataset, SupervisedDataset
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
-from coati.models.roberta import RoBERTaActor, RoBERTaCritic, RoBERTaRM
from coati.trainer import PPOTrainer
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
-from coati.utils import prepare_llama_tokenizer_and_embedding
+from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
-from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer, RobertaTokenizer
+from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
from colossalai.nn.optimizer import HybridAdam
def main(args):
# configure strategy
- if args.strategy == 'naive':
- strategy = NaiveStrategy()
- elif args.strategy == 'ddp':
+ if args.strategy == "ddp":
strategy = DDPStrategy()
- elif args.strategy == 'colossalai_gemini':
- strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
- elif args.strategy == 'colossalai_zero2':
- strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
+ elif args.strategy == "colossalai_gemini":
+ strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5)
+ elif args.strategy == "colossalai_zero2":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
if args.rm_path is not None:
- state_dict = torch.load(args.rm_path, map_location='cpu')
-
- # configure model
- if args.model == 'gpt2':
- initial_model = GPTActor(pretrained=args.pretrain)
- elif args.model == 'bloom':
- initial_model = BLOOMActor(pretrained=args.pretrain)
- elif args.model == 'opt':
- initial_model = OPTActor(pretrained=args.pretrain)
- elif args.model == 'llama':
- initial_model = LlamaActor(pretrained=args.pretrain)
- elif args.model == 'roberta':
- initial_model = RoBERTaActor(pretrained=args.pretrain)
- else:
- raise ValueError(f'Unsupported actor model "{args.model}"')
+ warnings.warn("LoRA weights should be merged with the model weights")
+ state_dict = torch.load(args.rm_path, map_location="cpu")
- if args.rm_model == None:
- rm_model_name = args.model
- else:
- rm_model_name = args.rm_model
-
- if rm_model_name == 'gpt2':
- reward_model = GPTRM(pretrained=args.rm_pretrain)
- elif rm_model_name == 'bloom':
- reward_model = BLOOMRM(pretrained=args.rm_pretrain)
- elif rm_model_name == 'opt':
- reward_model = OPTRM(pretrained=args.rm_pretrain)
- elif rm_model_name == 'llama':
- reward_model = LlamaRM(pretrained=args.rm_pretrain)
- elif rm_model_name == 'roberta':
- reward_model = RoBERTaRM(pretrained=args.rm_pretrain)
- else:
- raise ValueError(f'Unsupported reward model "{rm_model_name}"')
+ if args.lora_rank > 0:
+ warnings.warn("Lora is not supported yet.")
+ args.lora_rank = 0
- if args.rm_path is not None:
- reward_model.load_state_dict(state_dict)
+ with strategy.model_init_context():
+ # configure model
+ if args.model == "gpt2":
+ initial_model = GPTActor(pretrained=args.pretrain)
+ elif args.model == "bloom":
+ initial_model = BLOOMActor(pretrained=args.pretrain)
+ elif args.model == "opt":
+ initial_model = OPTActor(pretrained=args.pretrain)
+ elif args.model == "llama":
+ initial_model = LlamaActor(pretrained=args.pretrain)
+ else:
+ raise ValueError(f'Unsupported actor model "{args.model}"')
+
+ if args.rm_model is None:
+ rm_model_name = args.model
+ else:
+ rm_model_name = args.rm_model
+
+ if rm_model_name == "gpt2":
+ reward_model = GPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
+ elif rm_model_name == "bloom":
+ reward_model = BLOOMRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
+ elif rm_model_name == "opt":
+ reward_model = OPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
+ elif rm_model_name == "llama":
+ reward_model = LlamaRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
+ else:
+ raise ValueError(f'Unsupported reward model "{rm_model_name}"')
+
+ if args.rm_path is not None:
+ reward_model.load_state_dict(state_dict, strict=False)
- initial_model.to(torch.float16).to(torch.cuda.current_device())
- reward_model.to(torch.float16).to(torch.cuda.current_device())
+ initial_model.to(torch.bfloat16).to(torch.cuda.current_device())
+ reward_model.to(torch.bfloat16).to(torch.cuda.current_device())
- with strategy.model_init_context():
- if args.model == 'gpt2':
+ if args.model == "gpt2":
actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == 'bloom':
+ elif args.model == "bloom":
actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == 'opt':
+ elif args.model == "opt":
actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == 'llama':
+ elif args.model == "llama":
actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
- elif args.model == 'roberta':
- actor = RoBERTaActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
- if rm_model_name == 'gpt2':
- critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
- elif rm_model_name == 'bloom':
- critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
- elif rm_model_name == 'opt':
- critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
- elif rm_model_name == 'llama':
- critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
- elif rm_model_name == 'roberta':
- critic = RoBERTaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
+ if rm_model_name == "gpt2":
+ critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
+ elif rm_model_name == "bloom":
+ critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
+ elif rm_model_name == "opt":
+ critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
+ elif rm_model_name == "llama":
+ critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None:
- critic.load_state_dict(state_dict)
+ critic.load_state_dict(state_dict, strict=False)
del state_dict
- if args.strategy != 'colossalai_gemini':
- critic.to(torch.float16).to(torch.cuda.current_device())
- actor.to(torch.float16).to(torch.cuda.current_device())
+ actor.to(torch.bfloat16).to(torch.cuda.current_device())
+ critic.to(torch.bfloat16).to(torch.cuda.current_device())
# configure optimizer
- if args.strategy.startswith('colossalai'):
- actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
- critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
+ if args.strategy.startswith("colossalai"):
+ actor_optim = HybridAdam(actor.parameters(), lr=args.lr)
+ critic_optim = HybridAdam(critic.parameters(), lr=args.lr)
else:
- actor_optim = Adam(actor.parameters(), lr=1e-7)
- critic_optim = Adam(critic.parameters(), lr=1e-7)
+ actor_optim = Adam(actor.parameters(), lr=args.lr)
+ critic_optim = Adam(critic.parameters(), lr=args.lr)
# configure tokenizer
- if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
- elif args.model == 'bloom':
- tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
- elif args.model == 'opt':
- tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
- elif args.model == 'llama':
- tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
- tokenizer.eos_token = '<\s>'
- elif args.model == 'roberta':
- tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+ if args.model == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == "bloom":
+ tokenizer = BloomTokenizerFast.from_pretrained(
+ "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
+ )
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == "opt":
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == "llama":
+ tokenizer = LlamaTokenizer.from_pretrained(
+ "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
+ )
+        tokenizer.eos_token = "</s>"
+ tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
-
- if args.model == 'llama':
- tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, actor)
- else:
- tokenizer.pad_token = tokenizer.eos_token
-
- data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
-
- prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_dataset, max_datasets_size=16384)
+ # NOTE: generate() requires padding_side to be "left"
+ tokenizer.padding_side = "left"
+
+ prompt_dataset = PromptDataset(
+ tokenizer=tokenizer,
+ data_path=args.prompt_dataset,
+ max_datasets_size=args.max_datasets_size,
+ max_length=args.max_input_len,
+ )
if dist.is_initialized() and dist.get_world_size() > 1:
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
else:
prompt_sampler = None
- prompt_dataloader = DataLoader(prompt_dataset,
- shuffle=(prompt_sampler is None),
- sampler=prompt_sampler,
- batch_size=args.experience_batch_size)
-
- pretrain_dataset = SupervisedDataset(tokenizer=tokenizer,
- data_path=args.pretrain_dataset,
- max_datasets_size=16384,
- max_length=args.max_input_len)
+ prompt_dataloader = DataLoader(
+ prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.experience_batch_size
+ )
+
+ pretrain_dataset = SupervisedDataset(
+ tokenizer=tokenizer,
+ data_path=args.pretrain_dataset,
+ max_datasets_size=args.max_datasets_size,
+ max_length=args.max_input_len,
+ )
if dist.is_initialized() and dist.get_world_size() > 1:
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
else:
pretrain_sampler = None
- pretrain_dataloader = DataLoader(pretrain_dataset,
- shuffle=(pretrain_sampler is None),
- sampler=pretrain_sampler,
- batch_size=args.ptx_batch_size,
- collate_fn=data_collator)
+ pretrain_dataloader = DataLoader(
+ pretrain_dataset, shuffle=(pretrain_sampler is None), sampler=pretrain_sampler, batch_size=args.ptx_batch_size
+ )
- (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
+ # NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized.
+ (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+ (actor, actor_optim), (critic, critic_optim), reward_model, initial_model
+ )
# configure trainer
trainer = PPOTrainer(
@@ -174,60 +174,76 @@ def main(args):
initial_model,
actor_optim,
critic_optim,
+ tokenizer=tokenizer,
kl_coef=args.kl_coef,
ptx_coef=args.ptx_coef,
- max_epochs=args.max_epochs,
train_batch_size=args.train_batch_size,
max_length=args.max_seq_len,
use_cache=True,
do_sample=True,
temperature=1.0,
top_k=50,
- pad_token_id=tokenizer.pad_token_id,
- eos_token_id=tokenizer.eos_token_id,
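+        # NOTE: assumed rationale: Gemini manages placement itself, so the
+        # inference-only models are offloaded just for the other strategies.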
+ offload_inference_models=args.strategy != "colossalai_gemini",
)
- trainer.fit(prompt_dataloader=prompt_dataloader,
- pretrain_dataloader=pretrain_dataloader,
- num_episodes=args.num_episodes,
- max_timesteps=args.max_timesteps,
- update_timesteps=args.update_timesteps)
+ trainer.fit(
+ num_episodes=args.num_episodes,
+ num_collect_steps=args.num_collect_steps,
+ num_update_steps=args.num_update_steps,
+ prompt_dataloader=prompt_dataloader,
+ pretrain_dataloader=pretrain_dataloader,
+ log_dir=args.log_dir,
+ use_wandb=args.use_wandb,
+ )
+
+ if args.lora_rank > 0 and args.merge_lora_weights:
+        from coati.models.lora import LORA_MANAGER
+
+        # NOTE: set model to eval to merge LoRA weights
+ LORA_MANAGER.merge_weights = True
+ actor.eval()
# save model checkpoint after fitting
- strategy.save_model(actor, args.save_path, only_rank0=True)
+ strategy.save_pretrained(actor, path=args.save_path)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
- strategy.save_optimizer(actor_optim,
- 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
+ strategy.save_optimizer(
+ actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--prompt_dataset', type=str, default=None, help='path to the prompt dataset')
- parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
- parser.add_argument('--strategy',
- choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='colossalai_zero2',
- help='strategy to use')
- parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama', 'roberta'])
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama', 'roberta'])
- parser.add_argument('--rm_path', type=str, default=None)
- parser.add_argument('--rm_pretrain', type=str, default=None)
- parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--num_episodes', type=int, default=10)
- parser.add_argument('--max_timesteps', type=int, default=10)
- parser.add_argument('--update_timesteps', type=int, default=10)
- parser.add_argument('--max_epochs', type=int, default=5)
- parser.add_argument('--train_batch_size', type=int, default=8)
- parser.add_argument('--ptx_batch_size', type=int, default=1)
- parser.add_argument('--experience_batch_size', type=int, default=8)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument('--kl_coef', type=float, default=0.1)
- parser.add_argument('--ptx_coef', type=float, default=0.9)
- parser.add_argument('--max_input_len', type=int, default=96)
- parser.add_argument('--max_seq_len', type=int, default=128)
+ parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset")
+ parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
+ parser.add_argument("--max_datasets_size", type=int, default=50000)
+ parser.add_argument(
+ "--strategy",
+ choices=["ddp", "colossalai_gemini", "colossalai_zero2"],
+ default="colossalai_zero2",
+ help="strategy to use",
+ )
+ parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--tokenizer", type=str, default=None)
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
+ parser.add_argument("--rm_path", type=str, default=None)
+ parser.add_argument("--rm_pretrain", type=str, default=None)
+ parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
+ parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+ parser.add_argument("--num_episodes", type=int, default=10)
+ parser.add_argument("--num_collect_steps", type=int, default=10)
+ parser.add_argument("--num_update_steps", type=int, default=5)
+ parser.add_argument("--train_batch_size", type=int, default=8)
+ parser.add_argument("--ptx_batch_size", type=int, default=1)
+ parser.add_argument("--experience_batch_size", type=int, default=8)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument("--merge_lora_weights", type=bool, default=True)
+ parser.add_argument("--lr", type=float, default=1e-7)
+ parser.add_argument("--kl_coef", type=float, default=0.1)
+ parser.add_argument("--ptx_coef", type=float, default=0.9)
+ parser.add_argument("--max_input_len", type=int, default=96)
+ parser.add_argument("--max_seq_len", type=int, default=128)
+ parser.add_argument("--log_dir", default="logs", type=str)
+ parser.add_argument("--use_wandb", default=False, action="store_true")
args = parser.parse_args()
main(args)
diff --git a/applications/Chat/examples/train_prompts.sh b/applications/Chat/examples/train_prompts.sh
index 7f3b2636ca32862d03a260c44bfa4765f6f9990e..d04c416015b1566ee197941328cbc44f376f4545 100755
--- a/applications/Chat/examples/train_prompts.sh
+++ b/applications/Chat/examples/train_prompts.sh
@@ -1,13 +1,13 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
- local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
- | tail -n +2 \
- | nl -v 0 \
- | tee /dev/tty \
- | sort -g -k 2 \
- | awk '{print $1}' \
- | head -n $n)
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+ tail -n +2 |
+ nl -v 0 |
+ tee /dev/tty |
+ sort -g -k 2 |
+ awk '{print $1}' |
+ head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
@@ -17,4 +17,9 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 2
# torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2
-torchrun --standalone --nproc_per_node=2 train_prompts.py --prompt_dataset /path/to/data.json --strategy colossalai_zero2
+torchrun --standalone --nproc_per_node=2 train_prompts.py \
+ --pretrain_dataset /path/to/data.json \
+ --prompt_dataset /path/to/data.json \
+ --strategy colossalai_zero2 \
+ --num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \
+ --train_batch_size 2
diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py
index 48b12336fa6743714add52ee52c8a518c155f2ab..df6e8b6bdc26209dbbecf9023e9230aa34834393 100644
--- a/applications/Chat/examples/train_reward_model.py
+++ b/applications/Chat/examples/train_reward_model.py
@@ -1,26 +1,22 @@
import argparse
-from random import randint
+import warnings
-import loralib as lora
import torch
import torch.distributed as dist
from coati.dataset import HhRlhfDataset, RmStaticDataset
from coati.models import LogExpLoss, LogSigLoss
-from coati.models.base import RewardModel
from coati.models.bloom import BLOOMRM
-from coati.models.deberta import DebertaRM
from coati.models.gpt import GPTRM
from coati.models.llama import LlamaRM
from coati.models.opt import OPTRM
-from coati.models.roberta import RoBERTaRM
from coati.trainer import RewardModelTrainer
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
-from coati.utils import prepare_llama_tokenizer_and_embedding
+from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from datasets import load_dataset
from torch.optim import Adam
+from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
-from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer, LlamaTokenizer, RobertaTokenizer
+from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from colossalai.nn.optimizer import HybridAdam
@@ -28,72 +24,69 @@ from colossalai.nn.optimizer import HybridAdam
def train(args):
# configure strategy
- if args.strategy == 'naive':
- strategy = NaiveStrategy()
- elif args.strategy == 'ddp':
+ if args.strategy == "ddp":
strategy = DDPStrategy()
- elif args.strategy == 'colossalai_gemini':
- strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
- elif args.strategy == 'colossalai_zero2':
- strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
+ elif args.strategy == "colossalai_gemini":
+ strategy = GeminiStrategy(placement_policy="auto")
+ elif args.strategy == "colossalai_zero2":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model
+ if args.lora_rank > 0:
+ warnings.warn("Lora is not supported yet.")
+ args.lora_rank = 0
+
with strategy.model_init_context():
- if args.model == 'bloom':
- model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
- elif args.model == 'opt':
- model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
- elif args.model == 'gpt2':
- model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
- elif args.model == 'deberta':
- model = DebertaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
- elif args.model == 'llama':
- model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
- elif args.model == 'roberta':
- model = RoBERTaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ if args.model == "bloom":
+ model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
+ elif args.model == "opt":
+ model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
+ elif args.model == "gpt2":
+ model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
+ elif args.model == "llama":
+ model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
else:
raise ValueError(f'Unsupported model "{args.model}"')
+ model.to(torch.bfloat16).to(torch.cuda.current_device())
+
if args.model_path is not None:
state_dict = torch.load(args.model_path)
model.load_state_dict(state_dict)
- model = model.to(torch.float16)
-
# configure tokenizer
- if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
- elif args.model == 'bloom':
- tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
- elif args.model == 'opt':
- tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
- elif args.model == 'deberta':
- tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large')
- elif args.model == 'llama':
- tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
- elif args.model == 'roberta':
- tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+ if args.model == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == "bloom":
+ tokenizer = BloomTokenizerFast.from_pretrained(
+ "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
+ )
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == "opt":
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == "llama":
+ tokenizer = LlamaTokenizer.from_pretrained(
+ "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
+ )
+        tokenizer.eos_token = "</s>"
+ tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
- max_len = args.max_len
-
- if args.model == 'llama':
- tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)
- else:
- tokenizer.pad_token = tokenizer.eos_token
# configure optimizer
- if args.strategy.startswith('colossalai'):
- optim = HybridAdam(model.parameters(), lr=5e-6)
+ if args.strategy.startswith("colossalai"):
+ optim = HybridAdam(model.parameters(), lr=args.lr)
else:
- optim = Adam(model.parameters(), lr=5e-6)
+ optim = Adam(model.parameters(), lr=args.lr)
# configure loss function
- if args.loss_fn == 'log_sig':
+ if args.loss_fn == "log_sig":
loss_fn = LogSigLoss()
- elif args.loss_fn == 'log_exp':
+ elif args.loss_fn == "log_exp":
loss_fn = LogExpLoss()
else:
raise ValueError(f'Unsupported loss function "{args.loss_fn}"')
@@ -104,107 +97,112 @@ def train(args):
else:
data = load_dataset(args.dataset)
- if args.test:
- train_data = data['train'].select(range(100))
- eval_data = data['test'].select(range(10))
- else:
- train_data = data['train']
- eval_data = data['test']
- valid_data = data['test'].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5)))
-
- if args.dataset == 'Dahoas/rm-static':
- train_dataset = RmStaticDataset(train_data, tokenizer, max_len)
- valid_dataset = RmStaticDataset(valid_data, tokenizer, max_len)
- eval_dataset = RmStaticDataset(eval_data, tokenizer, max_len)
- elif args.dataset == 'Anthropic/hh-rlhf':
- train_dataset = HhRlhfDataset(train_data, tokenizer, max_len)
- valid_dataset = HhRlhfDataset(valid_data, tokenizer, max_len)
- eval_dataset = HhRlhfDataset(eval_data, tokenizer, max_len)
+ train_data = data["train"].select(range(min(args.max_datasets_size, len(data["train"]))))
+ eval_data = data["test"].select(range(min(args.max_datasets_size, len(data["test"]))))
+
+ if args.dataset == "Dahoas/rm-static":
+ train_dataset = RmStaticDataset(train_data, tokenizer, args.max_len)
+ eval_dataset = RmStaticDataset(eval_data, tokenizer, args.max_len)
+ elif args.dataset == "Anthropic/hh-rlhf":
+ train_dataset = HhRlhfDataset(train_data, tokenizer, args.max_len)
+ eval_dataset = HhRlhfDataset(eval_data, tokenizer, args.max_len)
else:
raise ValueError(f'Unsupported dataset "{args.dataset}"')
if dist.is_initialized() and dist.get_world_size() > 1:
- train_sampler = DistributedSampler(train_dataset,
- shuffle=True,
- seed=42,
- drop_last=True,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
- valid_sampler = DistributedSampler(valid_dataset,
- shuffle=True,
- seed=42,
- drop_last=True,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
- eval_sampler = DistributedSampler(eval_dataset,
- shuffle=True,
- seed=42,
- drop_last=True,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
+ train_sampler = DistributedSampler(
+ train_dataset,
+ shuffle=True,
+ seed=42,
+ drop_last=True,
+ rank=dist.get_rank(),
+ num_replicas=dist.get_world_size(),
+ )
+ eval_sampler = DistributedSampler(
+ eval_dataset,
+ shuffle=True,
+ seed=42,
+ drop_last=True,
+ rank=dist.get_rank(),
+ num_replicas=dist.get_world_size(),
+ )
else:
train_sampler = None
- valid_sampler = None
eval_sampler = None
- train_dataloader = DataLoader(train_dataset,
- shuffle=(train_sampler is None),
- sampler=train_sampler,
- batch_size=args.batch_size,
- pin_memory=True)
-
- valid_dataloader = DataLoader(valid_dataset,
- shuffle=(valid_sampler is None),
- sampler=valid_sampler,
- batch_size=args.batch_size,
- pin_memory=True)
-
- eval_dataloader = DataLoader(eval_dataset,
- shuffle=(eval_sampler is None),
- sampler=eval_sampler,
- batch_size=args.batch_size,
- pin_memory=True)
-
- (model, optim) = strategy.prepare((model, optim))
- trainer = RewardModelTrainer(model=model,
- strategy=strategy,
- optim=optim,
- loss_fn=loss_fn,
- train_dataloader=train_dataloader,
- valid_dataloader=valid_dataloader,
- eval_dataloader=eval_dataloader,
- max_epochs=args.max_epochs)
-
- trainer.fit()
+ train_dataloader = DataLoader(
+ train_dataset,
+ shuffle=(train_sampler is None),
+ sampler=train_sampler,
+ batch_size=args.batch_size,
+ pin_memory=True,
+ )
+
+ eval_dataloader = DataLoader(
+ eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True
+ )
+
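+    # NOTE: T_max for the cosine schedule is set to roughly 1% of one epoch's steps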
+    lr_scheduler = CosineAnnealingLR(optim, len(train_dataloader) // 100)
+ strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
+ model = strategy_dict["model"]
+ optim = strategy_dict["optimizer"]
+ lr_scheduler = strategy_dict["lr_scheduler"]
+ trainer = RewardModelTrainer(
+ model=model,
+ strategy=strategy,
+ optim=optim,
+ lr_scheduler=lr_scheduler,
+ loss_fn=loss_fn,
+ max_epochs=args.max_epochs,
+ )
+
+ trainer.fit(
+ train_dataloader=train_dataloader,
+ eval_dataloader=eval_dataloader,
+ log_dir=args.log_dir,
+ use_wandb=args.use_wandb,
+ )
+
+ if args.lora_rank > 0 and args.merge_lora_weights:
+ from coati.models.lora import LORA_MANAGER
+
+ # NOTE: set model to eval to merge LoRA weights
+ LORA_MANAGER.merge_weights = True
+ model.eval()
# save model checkpoint after fitting on only rank0
- strategy.save_model(model, args.save_path, only_rank0=True)
+ state_dict = model.state_dict()
+ torch.save(state_dict, args.save_path)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
- strategy.save_optimizer(trainer.optimizer,
- 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
+ strategy.save_optimizer(
+ trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--strategy',
- choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
- default='colossalai_zero2')
- parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta', 'llama', 'roberta'], default='bloom')
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--model_path', type=str, default=None)
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--dataset',
- type=str,
- choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'],
- default='Dahoas/rm-static')
- parser.add_argument('--subset', type=str, default=None)
- parser.add_argument('--save_path', type=str, default='rm_ckpt')
- parser.add_argument('--max_epochs', type=int, default=1)
- parser.add_argument('--batch_size', type=int, default=1)
- parser.add_argument('--max_len', type=int, default=512)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp'])
- parser.add_argument('--test', type=bool, default=False)
+ parser.add_argument(
+ "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="colossalai_zero2"
+ )
+ parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
+ parser.add_argument("--tokenizer", type=str, default=None)
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--model_path", type=str, default=None)
+ parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+ parser.add_argument(
+ "--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static"
+ )
+ parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None)
+ parser.add_argument("--max_datasets_size", type=int, default=1000000)
+ parser.add_argument("--save_path", type=str, default="rm_ckpt")
+ parser.add_argument("--max_epochs", type=int, default=1)
+ parser.add_argument("--batch_size", type=int, default=1)
+ parser.add_argument("--max_len", type=int, default=512)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument("--merge_lora_weights", type=bool, default=True)
+ parser.add_argument("--lr", type=float, default=9e-6)
+ parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
+ parser.add_argument("--log_dir", default="logs", type=str)
+ parser.add_argument("--use_wandb", default=False, action="store_true")
args = parser.parse_args()
train(args)
diff --git a/applications/Chat/examples/train_rm.sh b/applications/Chat/examples/train_rm.sh
index 80abe62d2a3fe9d70c0ab8be1b2e8e3b8afc5e03..c5ebaf708ddca9ddce2c39fda4e2cfc431ac80f9 100755
--- a/applications/Chat/examples/train_rm.sh
+++ b/applications/Chat/examples/train_rm.sh
@@ -1,13 +1,13 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
- local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
- | tail -n +2 \
- | nl -v 0 \
- | tee /dev/tty \
- | sort -g -k 2 \
- | awk '{print $1}' \
- | head -n $n)
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+ tail -n +2 |
+ nl -v 0 |
+ tee /dev/tty |
+ sort -g -k 2 |
+ awk '{print $1}' |
+ head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
@@ -16,9 +16,10 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
set_n_least_used_CUDA_VISIBLE_DEVICES 2
torchrun --standalone --nproc_per_node=2 train_reward_model.py \
- --pretrain
\
- --model 'bloom' \
- --strategy colossalai_zero2 \
- --loss_fn 'log_sig'\
- --save_path \
- --dataset 'Anthropic/hh-rlhf'\
+ --pretrain 'gpt2' \
+ --model 'gpt2' \
+ --strategy colossalai_zero2 \
+ --loss_fn 'log_exp' \
+ --dataset 'Anthropic/hh-rlhf' \
+ --batch_size 16 \
+ --max_epochs 10
diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py
index da499f068b17885ac468ecd1dcb9de49100f667c..66d08da3012015cfd8438749fe7cc5b2927f4285 100644
--- a/applications/Chat/examples/train_sft.py
+++ b/applications/Chat/examples/train_sft.py
@@ -1,196 +1,221 @@
import argparse
-import os
+import math
+import warnings
-import loralib as lora
import torch
import torch.distributed as dist
-from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
-from coati.models import convert_to_lora_module
+from coati.dataset import SFTDataset, SupervisedDataset
+from coati.models.bloom import BLOOMActor
+from coati.models.chatglm import ChatGLMActor
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
+from coati.models.gpt import GPTActor
+from coati.models.llama import LlamaActor
+from coati.models.opt import OPTActor
from coati.trainer import SFTTrainer
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
-from coati.utils import prepare_llama_tokenizer_and_embedding
+from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
-from transformers import AutoTokenizer, BloomConfig, BloomForCausalLM, BloomTokenizerFast, LlamaConfig, LlamaForCausalLM
-from transformers.models.gpt2.configuration_gpt2 import GPT2Config
-from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
+from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
-from transformers.models.opt.configuration_opt import OPTConfig
-from transformers.models.opt.modeling_opt import OPTForCausalLM
+from transformers.trainer import get_scheduler
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam
-from colossalai.tensor import ColoParameter
def train(args):
# configure strategy
- if args.strategy == 'naive':
- strategy = NaiveStrategy()
- elif args.strategy == 'ddp':
+ if args.strategy == "ddp":
strategy = DDPStrategy()
- elif args.strategy == 'colossalai_gemini':
- raise NotImplementedError(
- 'Gemini is not supported .from_pretrained() yet. We will update this after checkpoint io is ready.')
- strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
- elif args.strategy == 'colossalai_zero2':
- strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
- elif args.strategy == 'colossalai_zero2_cpu':
- strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
+ elif args.strategy == "colossalai_gemini":
+ strategy = GeminiStrategy(placement_policy="auto")
+ elif args.strategy == "colossalai_zero2":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
+ elif args.strategy == "colossalai_zero2_cpu":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model
+ if args.lora_rank > 0:
+ warnings.warn("Lora is not supported yet.")
+ args.lora_rank = 0
+
with strategy.model_init_context():
- if args.model == 'bloom':
- model = convert_to_lora_module(BloomForCausalLM.from_pretrained(args.pretrain),
- args.lora_rank).half().cuda()
- elif args.model == 'opt':
- model = convert_to_lora_module(OPTForCausalLM.from_pretrained(args.pretrain), args.lora_rank).half().cuda()
- elif args.model == 'gpt2':
- model = convert_to_lora_module(GPT2LMHeadModel.from_pretrained(args.pretrain), args.lora_rank).half().cuda()
- elif args.model == 'llama':
- model = convert_to_lora_module(LlamaForCausalLM.from_pretrained(args.pretrain),
- args.lora_rank).half().cuda()
+ if args.model == "bloom":
+ model = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
+ elif args.model == "opt":
+ model = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
+ elif args.model == "gpt2":
+ model = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
+ elif args.model == "llama":
+ model = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
+ elif args.model == "chatglm":
+ model = ChatGLMActor(pretrained=args.pretrain)
else:
raise ValueError(f'Unsupported model "{args.model}"')
- if args.grad_checkpoint:
- model.gradient_checkpointing_enable()
+
+ model.to(torch.bfloat16).to(torch.cuda.current_device())
# configure tokenizer
- if args.model == 'gpt2':
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ if args.model == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer)
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == "bloom":
+ tokenizer = BloomTokenizerFast.from_pretrained(
+ "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer
+ )
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'bloom':
- tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
+ elif args.model == "opt":
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer)
tokenizer.pad_token = tokenizer.eos_token
- elif args.model == 'opt':
- tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
- elif args.model == 'llama':
- tokenizer = AutoTokenizer.from_pretrained(
- args.pretrain,
- padding_side="right",
- use_fast=False,
+ elif args.model == "llama":
+ tokenizer = LlamaTokenizer.from_pretrained(
+ "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
+ )
+        tokenizer.eos_token = "</s>"
+ tokenizer.pad_token = tokenizer.unk_token
+ elif args.model == "chatglm":
+ tokenizer = ChatGLMTokenizer.from_pretrained(
+ "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True
)
- tokenizer.eos_token = '<\s>'
else:
raise ValueError(f'Unsupported model "{args.model}"')
- tokenizer.pad_token = tokenizer.eos_token
- max_len = args.max_len
- if args.model == 'llama':
- tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)
-
- if args.strategy == 'colossalai_gemini':
- # this is a hack to deal with the resized embedding
- # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatiblity
- for name, param in model.named_parameters():
- if not isinstance(param, ColoParameter):
- sub_module_name = '.'.join(name.split('.')[:-1])
- weight_name = name.split('.')[-1]
- sub_module = model.get_submodule(sub_module_name)
- setattr(sub_module, weight_name, ColoParameter(param))
- else:
- tokenizer.pad_token = tokenizer.eos_token
# configure optimizer
- if args.strategy.startswith('colossalai'):
+ if args.strategy.startswith("colossalai"):
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
else:
optim = Adam(model.parameters(), lr=args.lr)
- logger = get_dist_logger()
-
# configure dataset
- if args.dataset == 'yizhongw/self_instruct':
- train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
- eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')
+ if args.dataset == "yizhongw/self_instruct":
+ train_data = load_dataset(args.dataset, "super_natural_instructions", split="train")
+ eval_data = load_dataset(args.dataset, "super_natural_instructions", split="test")
+
+ if args.max_datasets_size is not None:
+ train_data = train_data.select(range(min(args.max_datasets_size, len(train_data))))
+ eval_data = eval_data.select(range(min(args.max_datasets_size, len(eval_data))))
- train_dataset = SFTDataset(train_data, tokenizer, max_len)
- eval_dataset = SFTDataset(eval_data, tokenizer, max_len)
+ train_dataset = SFTDataset(train_data, tokenizer, args.max_len)
+ eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len)
else:
- train_dataset = SupervisedDataset(tokenizer=tokenizer,
- data_path=args.dataset,
- max_datasets_size=args.max_datasets_size,
- max_length=max_len)
+ train_dataset = SupervisedDataset(
+ tokenizer=tokenizer,
+ data_path=args.dataset,
+ max_datasets_size=args.max_datasets_size,
+ max_length=args.max_len,
+ )
eval_dataset = None
- data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
if dist.is_initialized() and dist.get_world_size() > 1:
- train_sampler = DistributedSampler(train_dataset,
- shuffle=True,
- seed=42,
- drop_last=True,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
+ train_sampler = DistributedSampler(
+ train_dataset,
+ shuffle=True,
+ seed=42,
+ drop_last=True,
+ rank=dist.get_rank(),
+ num_replicas=dist.get_world_size(),
+ )
if eval_dataset is not None:
- eval_sampler = DistributedSampler(eval_dataset,
- shuffle=False,
- seed=42,
- drop_last=False,
- rank=dist.get_rank(),
- num_replicas=dist.get_world_size())
+ eval_sampler = DistributedSampler(
+ eval_dataset,
+ shuffle=False,
+ seed=42,
+ drop_last=False,
+ rank=dist.get_rank(),
+ num_replicas=dist.get_world_size(),
+ )
else:
train_sampler = None
eval_sampler = None
- train_dataloader = DataLoader(train_dataset,
- shuffle=(train_sampler is None),
- sampler=train_sampler,
- batch_size=args.batch_size,
- collate_fn=data_collator,
- pin_memory=True)
+ train_dataloader = DataLoader(
+ train_dataset,
+ shuffle=(train_sampler is None),
+ sampler=train_sampler,
+ batch_size=args.batch_size,
+ pin_memory=True,
+ )
if eval_dataset is not None:
- eval_dataloader = DataLoader(eval_dataset,
- shuffle=(eval_sampler is None),
- sampler=eval_sampler,
- batch_size=args.batch_size,
- collate_fn=data_collator,
- pin_memory=True)
+ eval_dataloader = DataLoader(
+ eval_dataset,
+ shuffle=(eval_sampler is None),
+ sampler=eval_sampler,
+ batch_size=args.batch_size,
+ pin_memory=True,
+ )
else:
eval_dataloader = None
- (model, optim) = strategy.prepare((model, optim))
- trainer = SFTTrainer(model=model,
- strategy=strategy,
- optim=optim,
- train_dataloader=train_dataloader,
- eval_dataloader=eval_dataloader,
- max_epochs=args.max_epochs,
- accumulation_steps=args.accumulation_steps)
-
- trainer.fit(logger=logger, use_wandb=args.use_wandb)
+ num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
+ max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch)
+ lr_scheduler = get_scheduler(
+ "cosine", optim, num_warmup_steps=math.ceil(max_steps * 0.03), num_training_steps=max_steps
+ )
+ strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
+ model = strategy_dict["model"]
+ optim = strategy_dict["optimizer"]
+ lr_scheduler = strategy_dict["lr_scheduler"]
+ trainer = SFTTrainer(
+ model=model,
+ strategy=strategy,
+ optim=optim,
+ lr_scheduler=lr_scheduler,
+ max_epochs=args.max_epochs,
+ accumulation_steps=args.accumulation_steps,
+ )
+ logger = get_dist_logger()
+ trainer.fit(
+ train_dataloader=train_dataloader,
+ eval_dataloader=eval_dataloader,
+ logger=logger,
+ log_dir=args.log_dir,
+ use_wandb=args.use_wandb,
+ )
+
+ if args.lora_rank > 0 and args.merge_lora_weights:
+ from coati.models.lora import LORA_MANAGER
+
+ # NOTE: set model to eval to merge LoRA weights
+ LORA_MANAGER.merge_weights = True
+ model.eval()
# save model checkpoint after fitting on only rank0
- strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
+ strategy.save_pretrained(model, path=args.save_path, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
- strategy.save_optimizer(trainer.optimizer,
- 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
- only_rank0=False)
+ strategy.save_optimizer(
+ trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--strategy',
- choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
- default='colossalai_zero2')
- parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
- parser.add_argument('--pretrain', type=str, default=None)
- parser.add_argument('--dataset', type=str, default=None)
- parser.add_argument('--max_datasets_size', type=int, default=None)
- parser.add_argument('--save_path', type=str, default='output')
- parser.add_argument('--need_optim_ckpt', type=bool, default=False)
- parser.add_argument('--max_epochs', type=int, default=3)
- parser.add_argument('--batch_size', type=int, default=4)
- parser.add_argument('--max_len', type=int, default=512)
- parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
- parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
- parser.add_argument('--lr', type=float, default=5e-6)
- parser.add_argument('--accumulation_steps', type=int, default=8)
- parser.add_argument('--use_wandb', default=False, action='store_true')
- parser.add_argument('--grad_checkpoint', default=False, action='store_true')
+ parser.add_argument(
+ "--strategy",
+ choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_zero2_cpu"],
+ default="colossalai_zero2",
+ )
+ parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama", "chatglm"], default="bloom")
+ parser.add_argument("--tokenizer", type=str, default=None)
+ parser.add_argument("--pretrain", type=str, default=None)
+ parser.add_argument("--dataset", type=str, default=None)
+ parser.add_argument("--max_datasets_size", type=int, default=None)
+ parser.add_argument("--save_path", type=str, default="output")
+ parser.add_argument("--need_optim_ckpt", type=bool, default=False)
+ parser.add_argument("--max_epochs", type=int, default=3)
+ parser.add_argument("--batch_size", type=int, default=4)
+ parser.add_argument("--max_len", type=int, default=512)
+ parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument("--merge_lora_weights", type=bool, default=True)
+ parser.add_argument("--lr", type=float, default=5e-6)
+ parser.add_argument("--accumulation_steps", type=int, default=8)
+ parser.add_argument("--log_dir", default="logs", type=str)
+ parser.add_argument("--use_wandb", default=False, action="store_true")
+ parser.add_argument("--grad_checkpoint", default=False, action="store_true")
args = parser.parse_args()
train(args)
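A note on the LoRA merge step above: setting `LORA_MANAGER.merge_weights = True` and then calling `model.eval()` folds the low-rank update matrices into the base weights, so the checkpoint written by `strategy.save_pretrained` loads without any LoRA-aware code. A minimal sketch of a layer with this behavior, assuming loralib-style semantics (`LoraLinear` below is illustrative, not coati's actual class):

```python
# Illustrative only: a loralib-style linear layer that merges its low-rank
# update into the base weight the first time it is switched to eval mode.
import torch
import torch.nn as nn


class LoraLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, r: int = 4, lora_alpha: int = 1):
        super().__init__(in_features, out_features)
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))   # down-projection
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))  # up-projection
        self.scaling = lora_alpha / r
        self.merged = False

    def eval(self):
        super().eval()
        if not self.merged:
            # W <- W + scaling * (B @ A): afterwards a plain state_dict load suffices
            self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
            self.merged = True
        return self
```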
diff --git a/applications/Chat/examples/train_sft.sh b/applications/Chat/examples/train_sft.sh
index c880f85825a77a98ea49ce691bb5cf4fcabca857..0fb4da3d3ce8bc6091374ca98e8bbed0738f8c1d 100755
--- a/applications/Chat/examples/train_sft.sh
+++ b/applications/Chat/examples/train_sft.sh
@@ -1,12 +1,28 @@
+set_n_least_used_CUDA_VISIBLE_DEVICES() {
+ local n=${1:-"9999"}
+ echo "GPU Memory Usage:"
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+ tail -n +2 |
+ nl -v 0 |
+ tee /dev/tty |
+ sort -g -k 2 |
+ awk '{print $1}' |
+ head -n $n)
+ export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+ echo "Now CUDA_VISIBLE_DEVICES is set to:"
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+
+set_n_least_used_CUDA_VISIBLE_DEVICES 4
+
torchrun --standalone --nproc_per_node=4 train_sft.py \
--pretrain "/path/to/LLaMa-7B/" \
--model 'llama' \
--strategy colossalai_zero2 \
- --log_interval 10 \
- --save_path /path/to/Coati-7B \
+ --save_path /path/to/Coati-7B \
--dataset /path/to/data.json \
--batch_size 4 \
--accumulation_steps 8 \
--lr 2e-5 \
--max_datasets_size 512 \
- --max_epochs 1 \
+ --max_epochs 1
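For context on the helper added above: `set_n_least_used_CUDA_VISIBLE_DEVICES` asks `nvidia-smi` for per-GPU memory usage, numbers the rows from 0 (`nl -v 0`), sorts them by usage ascending (`sort -g -k 2`), keeps the first `n` indices, and exports them as `CUDA_VISIBLE_DEVICES`. With the argument 4 it selects the four least-loaded GPUs, matching `--nproc_per_node=4` in the `torchrun` invocation.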
diff --git a/applications/Chat/inference/README.md b/applications/Chat/inference/README.md
index 434677c98fa58f7050098671c6243c6f70d023a4..eea4ef5b86ca99106a813ca0a70f49bdf5b65c5e 100644
--- a/applications/Chat/inference/README.md
+++ b/applications/Chat/inference/README.md
@@ -20,21 +20,21 @@ The data is from [LLaMA Int8 4bit ChatBot Guide v2](https://rentry.org/llama-tar
### 8-bit
-| Model | Min GPU RAM | Recommended GPU RAM | Min RAM/Swap | Card examples |
-| :---: | :---: | :---: | :---: | :---: |
-| LLaMA-7B | 9.2GB | 10GB | 24GB | 3060 12GB, RTX 3080 10GB, RTX 3090 |
-| LLaMA-13B | 16.3GB | 20GB | 32GB | RTX 3090 Ti, RTX 4090 |
-| LLaMA-30B | 36GB | 40GB | 64GB | A6000 48GB, A100 40GB |
-| LLaMA-65B | 74GB | 80GB | 128GB | A100 80GB |
+| Model | Min GPU RAM | Recommended GPU RAM | Min RAM/Swap | Card examples |
+| :-------: | :---------: | :-----------------: | :----------: | :--------------------------------: |
+| LLaMA-7B | 9.2GB | 10GB | 24GB | 3060 12GB, RTX 3080 10GB, RTX 3090 |
+| LLaMA-13B | 16.3GB | 20GB | 32GB | RTX 3090 Ti, RTX 4090 |
+| LLaMA-30B | 36GB | 40GB | 64GB | A6000 48GB, A100 40GB |
+| LLaMA-65B | 74GB | 80GB | 128GB | A100 80GB |
### 4-bit
-| Model | Min GPU RAM | Recommended GPU RAM | Min RAM/Swap | Card examples |
-| :---: | :---: | :---: | :---: | :---: |
-| LLaMA-7B | 3.5GB | 6GB | 16GB | RTX 1660, 2060, AMD 5700xt, RTX 3050, 3060 |
-| LLaMA-13B | 6.5GB | 10GB | 32GB | AMD 6900xt, RTX 2060 12GB, 3060 12GB, 3080, A2000 |
-| LLaMA-30B | 15.8GB | 20GB | 64GB | RTX 3080 20GB, A4500, A5000, 3090, 4090, 6000, Tesla V100 |
-| LLaMA-65B | 31.2GB | 40GB | 128GB | A100 40GB, 2x3090, 2x4090, A40, RTX A6000, 8000, Titan Ada |
+| Model | Min GPU RAM | Recommended GPU RAM | Min RAM/Swap | Card examples |
+| :-------: | :---------: | :-----------------: | :----------: | :--------------------------------------------------------: |
+| LLaMA-7B | 3.5GB | 6GB | 16GB | RTX 1660, 2060, AMD 5700xt, RTX 3050, 3060 |
+| LLaMA-13B | 6.5GB | 10GB | 32GB | AMD 6900xt, RTX 2060 12GB, 3060 12GB, 3080, A2000 |
+| LLaMA-30B | 15.8GB | 20GB | 64GB | RTX 3080 20GB, A4500, A5000, 3090, 4090, 6000, Tesla V100 |
+| LLaMA-65B | 31.2GB | 40GB | 128GB | A100 40GB, 2x3090, 2x4090, A40, RTX A6000, 8000, Titan Ada |
## General setup
@@ -75,7 +75,7 @@ E.g. you can set `export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH`.
Please ensure you have downloaded HF-format model weights of LLaMA models first.
-Then you can follow [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa). This lib provides efficient CUDA kernels and weight convertion script.
+Then you can follow [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa). This lib provides efficient CUDA kernels and a weight conversion script.
After installing this lib, we may convert the original HF-format LLaMA model weights to 4-bit version.
diff --git a/applications/Chat/inference/benchmark.py b/applications/Chat/inference/benchmark.py
index 59cd1eeea2aa841ae805d91adf279f668fdb3dd0..dbb5490a63dc2572957a02d39874376a1079f904 100644
--- a/applications/Chat/inference/benchmark.py
+++ b/applications/Chat/inference/benchmark.py
@@ -4,8 +4,8 @@ import argparse
from time import time
import torch
-from llama_gptq import load_quant
-from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
+from coati.quant import llama_load_quant, low_resource_init
+from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
def generate_prompt(instruction, input=None):
@@ -84,49 +84,58 @@ inst = [instructions[0]] * 4
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
- 'pretrained',
- help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.')
- parser.add_argument('--quant',
- choices=['8bit', '4bit'],
- default=None,
- help='Quantization mode. Default: None (no quantization, fp16).')
+ "pretrained",
+ help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.",
+ )
+ parser.add_argument(
+ "--quant",
+ choices=["8bit", "4bit"],
+ default=None,
+ help="Quantization mode. Default: None (no quantization, fp16).",
+ )
parser.add_argument(
- '--gptq_checkpoint',
+ "--gptq_checkpoint",
default=None,
- help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.')
- parser.add_argument('--gptq_group_size',
- type=int,
- default=128,
- help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
+ help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.",
+ )
+ parser.add_argument(
+ "--gptq_group_size",
+ type=int,
+ default=128,
+ help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.",
+ )
args = parser.parse_args()
- if args.quant == '4bit':
- assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'
+ if args.quant == "4bit":
+ assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint."
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
- if args.quant == '4bit':
- model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size)
+ if args.quant == "4bit":
+ with low_resource_init():
+ config = LlamaConfig.from_pretrained(args.pretrained)
+ model = LlamaForCausalLM(config)
+ model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size)
model.cuda()
else:
model = LlamaForCausalLM.from_pretrained(
args.pretrained,
- load_in_8bit=(args.quant == '8bit'),
+ load_in_8bit=(args.quant == "8bit"),
torch_dtype=torch.float16,
device_map="auto",
)
- if args.quant != '8bit':
- model.half() # seems to fix bugs for some users.
+ if args.quant != "8bit":
+ model.half() # seems to fix bugs for some users.
model.eval()
total_tokens = 0
start = time()
for instruction in instructions:
print(f"Instruction: {instruction}")
- resp, tokens = evaluate(model, tokenizer, instruction, temparature=0.2, num_beams=1)
+ resp, tokens = evaluate(model, tokenizer, instruction, temperature=0.2, num_beams=1)
total_tokens += tokens
print(f"Response: {resp}")
- print('\n----------------------------\n')
+ print("\n----------------------------\n")
duration = time() - start
- print(f'Total time: {duration:.3f} s, {total_tokens/duration:.3f} tokens/s')
- print(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB')
+ print(f"Total time: {duration:.3f} s, {total_tokens/duration:.3f} tokens/s")
+ print(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")
diff --git a/applications/Chat/inference/llama_gptq/__init__.py b/applications/Chat/inference/llama_gptq/__init__.py
deleted file mode 100644
index 51c8d6316290fe2fcef7d972803017c830d3e1b4..0000000000000000000000000000000000000000
--- a/applications/Chat/inference/llama_gptq/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .loader import load_quant
-
-__all__ = [
- 'load_quant',
-]
diff --git a/applications/Chat/inference/llama_gptq/loader.py b/applications/Chat/inference/llama_gptq/loader.py
deleted file mode 100644
index a5c6ac7d1589aa1873918b9c8b02edcfe13ed59f..0000000000000000000000000000000000000000
--- a/applications/Chat/inference/llama_gptq/loader.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import torch
-import torch.nn as nn
-import transformers
-from transformers import LlamaConfig, LlamaForCausalLM
-
-from .model_utils import find_layers
-from .quant import make_quant
-
-
-def load_quant(pretrained: str, checkpoint: str, wbits: int, groupsize: int):
- config = LlamaConfig.from_pretrained(pretrained)
-
- def noop(*args, **kwargs):
- pass
-
- torch.nn.init.kaiming_uniform_ = noop
- torch.nn.init.uniform_ = noop
- torch.nn.init.normal_ = noop
-
- torch.set_default_dtype(torch.half)
- transformers.modeling_utils._init_weights = False
- torch.set_default_dtype(torch.half)
- model = LlamaForCausalLM(config)
- torch.set_default_dtype(torch.float)
- model = model.eval()
- layers = find_layers(model)
- for name in ['lm_head']:
- if name in layers:
- del layers[name]
- make_quant(model, layers, wbits, groupsize)
-
- print(f'Loading model with {wbits} bits...')
- if checkpoint.endswith('.safetensors'):
- from safetensors.torch import load_file as safe_load
- model.load_state_dict(safe_load(checkpoint))
- else:
- model.load_state_dict(torch.load(checkpoint))
- model.seqlen = 2048
- print('Done.')
-
- return model
diff --git a/applications/Chat/inference/llama_gptq/model_utils.py b/applications/Chat/inference/llama_gptq/model_utils.py
deleted file mode 100644
index 62db171abb52cb88799a8b73d608f2617208cefe..0000000000000000000000000000000000000000
--- a/applications/Chat/inference/llama_gptq/model_utils.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py
-
-import torch
-import torch.nn as nn
-
-
-def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
- if type(module) in layers:
- return {name: module}
- res = {}
- for name1, child in module.named_children():
- res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
- return res
diff --git a/applications/Chat/inference/locustfile.py b/applications/Chat/inference/locustfile.py
index 51cdc68125bba42d29a91f03285847e0bde27ea8..333262e538ac46d75734f14e3b648b65c3a2ff3d 100644
--- a/applications/Chat/inference/locustfile.py
+++ b/applications/Chat/inference/locustfile.py
@@ -1,27 +1,26 @@
-from json import JSONDecodeError
-
from locust import HttpUser, task
-samples = [[
- dict(
- instruction='Who is the best player in the history of NBA?',
- response=
- 'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
- ),
- dict(instruction='continue this talk', response=''),
-], [
- dict(instruction='Who is the best player in the history of NBA?', response=''),
-]]
+samples = [
+ [
+ dict(
+ instruction="Who is the best player in the history of NBA?",
+ response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
+ ),
+ dict(instruction="continue this talk", response=""),
+ ],
+ [
+ dict(instruction="Who is the best player in the history of NBA?", response=""),
+ ],
+]
class GenerationUser(HttpUser):
-
@task
def generate(self):
for sample in samples:
- data = {'max_new_tokens': 64, 'history': sample}
- with self.client.post('/generate', json=data, catch_response=True) as response:
+ data = {"max_new_tokens": 64, "history": sample}
+ with self.client.post("/generate", json=data, catch_response=True) as response:
if response.status_code in (200, 406):
response.success()
else:
- response.failure('Response wrong')
+ response.failure("Response wrong")
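To drive this load test, something like `locust -f locustfile.py --host http://localhost:7070` should work, assuming the host and port match the inference server defaults; note that the task treats both 200 and a 406 rejection as a successful request.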
diff --git a/applications/Chat/inference/requirements.txt b/applications/Chat/inference/requirements.txt
index 511fe1a4f1f339b23f1a11162925703b94e15013..cb6275361736a730cb53a95d3d4090407fd3f6c4 100644
--- a/applications/Chat/inference/requirements.txt
+++ b/applications/Chat/inference/requirements.txt
@@ -10,4 +10,4 @@ uvicorn
git+https://github.com/huggingface/transformers
accelerate
bitsandbytes
-jieba
\ No newline at end of file
+jieba
diff --git a/applications/Chat/inference/server.py b/applications/Chat/inference/server.py
index b4627299397e6949576318de3938ee9c19aa390c..7c6a61b9e7f2c37caafc278c8f7923d3ca7429e1 100644
--- a/applications/Chat/inference/server.py
+++ b/applications/Chat/inference/server.py
@@ -1,22 +1,22 @@
import argparse
import os
from threading import Lock
-from typing import Dict, Generator, List, Optional
+from typing import Generator, List, Optional
import torch
import uvicorn
-from fastapi import FastAPI, HTTPException, Request
+from coati.quant import llama_load_quant, low_resource_init
+from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
-from llama_gptq import load_quant
from pydantic import BaseModel, Field
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from sse_starlette.sse import EventSourceResponse
-from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
-from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn, load_json
+from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
+from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn
-CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
+CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions."
MAX_LEN = 512
running_lock = Lock()
@@ -36,11 +36,11 @@ app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# set CORS
-origin_spec_from_env = os.environ.get('CORS_ORIGIN', None)
+origin_spec_from_env = os.environ.get("CORS_ORIGIN", None)
if origin_spec_from_env is not None:
# allow CORS from the specified origins
- origins = os.environ['CORS_ORIGIN'].split(',')
+ origins = os.environ["CORS_ORIGIN"].split(",")
else:
# allow CORS from all origins
origins = ["*"]
@@ -56,15 +56,15 @@ app.add_middleware(
def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
- #TODO(ver217): streaming generation does not support repetition_penalty now
+ # TODO(ver217): streaming generation does not support repetition_penalty now
model_kwargs = {
- 'max_generate_tokens': max_new_tokens,
- 'early_stopping': True,
- 'top_k': top_k,
- 'top_p': top_p,
- 'temperature': temperature,
- 'prepare_inputs_fn': model.prepare_inputs_for_generation,
- 'update_model_kwargs_fn': update_model_kwargs_fn,
+ "max_generate_tokens": max_new_tokens,
+ "early_stopping": True,
+ "top_k": top_k,
+ "top_p": top_p,
+ "temperature": temperature,
+ "prepare_inputs_fn": model.prepare_inputs_for_generation,
+ "update_model_kwargs_fn": update_model_kwargs_fn,
}
is_first_word = True
generator = LockedIterator(sample_streamingly(model, **inputs, **model_kwargs), running_lock)
@@ -81,9 +81,9 @@ def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
if is_first_word:
out_string = out_string.lstrip()
is_first_word = False
- elif current_sub_tokens[0].startswith('▁'):
+ elif current_sub_tokens[0].startswith("▁"):
# whitespace will be ignored by the frontend
- out_string = ' ' + out_string
+ out_string = " " + out_string
yield out_string
@@ -92,32 +92,33 @@ async def event_generator(request: Request, generator: Generator):
if await request.is_disconnected():
break
try:
- yield {'event': 'generate', 'data': next(generator)}
+ yield {"event": "generate", "data": next(generator)}
except StopIteration:
- yield {'event': 'end', 'data': ''}
+ yield {"event": "end", "data": ""}
break
-@app.post('/generate/stream')
-@limiter.limit('1/second')
+@app.post("/generate/stream")
+@limiter.limit("1/second")
def generate(data: GenerationTaskReq, request: Request):
prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
event_source = event_generator(
- request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature))
+ request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature)
+ )
return EventSourceResponse(event_source)
-@app.post('/generate')
-@limiter.limit('1/second')
+@app.post("/generate")
+@limiter.limit("1/second")
def generate_no_stream(data: GenerationTaskReq, request: Request):
prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
if prompt_processor.has_censored_words(prompt):
return prompt_processor.SAFE_RESPONSE
inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
with running_lock:
- output = model.generate(**inputs, **data.dict(exclude={'history'}))
+ output = model.generate(**inputs, **data.dict(exclude={"history"}))
output = output.cpu()
- prompt_len = inputs['input_ids'].size(1)
+ prompt_len = inputs["input_ids"].size(1)
response = output[0, prompt_len:]
out_string = tokenizer.decode(response, skip_special_tokens=True)
out_string = prompt_processor.postprocess_output(out_string)
@@ -126,30 +127,40 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
return out_string
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
- 'pretrained',
- help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.')
- parser.add_argument('--quant',
- choices=['8bit', '4bit'],
- default=None,
- help='Quantization mode. Default: None (no quantization, fp16).')
+ "pretrained",
+ help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.",
+ )
parser.add_argument(
- '--gptq_checkpoint',
+ "--quant",
+ choices=["8bit", "4bit"],
default=None,
- help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.')
- parser.add_argument('--gptq_group_size',
- type=int,
- default=128,
- help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
- parser.add_argument('--http_host', default='0.0.0.0')
- parser.add_argument('--http_port', type=int, default=7070)
- parser.add_argument('--profanity_file', default=None, help='Path to profanity words list. It should be a JSON file containing a list of words.')
+ help="Quantization mode. Default: None (no quantization, fp16).",
+ )
+ parser.add_argument(
+ "--gptq_checkpoint",
+ default=None,
+ help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.",
+ )
+ parser.add_argument(
+ "--gptq_group_size",
+ type=int,
+ default=128,
+ help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.",
+ )
+ parser.add_argument("--http_host", default="0.0.0.0")
+ parser.add_argument("--http_port", type=int, default=7070)
+ parser.add_argument(
+ "--profanity_file",
+ default=None,
+ help="Path to profanity words list. It should be a JSON file containing a list of words.",
+ )
args = parser.parse_args()
- if args.quant == '4bit':
- assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'
+ if args.quant == "4bit":
+ assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint."
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
@@ -159,18 +170,21 @@ if __name__ == '__main__':
censored_words = []
prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)
- if args.quant == '4bit':
- model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size)
+ if args.quant == "4bit":
+ with low_resource_init():
+ config = LlamaConfig.from_pretrained(args.pretrained)
+ model = LlamaForCausalLM(config)
+ model = llama_load_quant(model, args.gptq_checkpoint, 4, args.gptq_group_size)
model.cuda()
else:
model = LlamaForCausalLM.from_pretrained(
args.pretrained,
- load_in_8bit=(args.quant == '8bit'),
+ load_in_8bit=(args.quant == "8bit"),
torch_dtype=torch.float16,
device_map="auto",
)
- if args.quant != '8bit':
- model.half() # seems to fix bugs for some users.
+ if args.quant != "8bit":
+ model.half() # seems to fix bugs for some users.
model.eval()
config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
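For a quick smoke test of the non-streaming endpoint above, a minimal client might look like the following; the payload shape follows `locustfile.py`, the host and port are the server defaults, and the `1/second` rate limit applies:

```python
# Minimal sketch of a client for POST /generate (payload mirrors locustfile.py).
import requests

history = [
    {"instruction": "Who is the best player in the history of NBA?", "response": ""},
]
resp = requests.post(
    "http://localhost:7070/generate",
    json={"max_new_tokens": 64, "history": history},
    timeout=120,
)
resp.raise_for_status()
print(resp.json())  # FastAPI JSON-encodes the returned string
```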
diff --git a/applications/Chat/inference/tests/test_chat_prompt.py b/applications/Chat/inference/tests/test_chat_prompt.py
index f5737ebe8c097d73073bb21195341b378e7fc2f1..9835e71894c66fdeffd0bb625259da936af4d776 100644
--- a/applications/Chat/inference/tests/test_chat_prompt.py
+++ b/applications/Chat/inference/tests/test_chat_prompt.py
@@ -3,44 +3,49 @@ import os
from transformers import AutoTokenizer
from utils import ChatPromptProcessor, Dialogue
-CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
-tokenizer = AutoTokenizer.from_pretrained(os.environ['PRETRAINED_PATH'])
+CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions."
+tokenizer = AutoTokenizer.from_pretrained(os.environ["PRETRAINED_PATH"])
samples = [
- ([
- Dialogue(
- instruction='Who is the best player in the history of NBA?',
- response=
- 'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
- ),
- Dialogue(instruction='continue this talk', response=''),
- ], 128,
- 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
+ (
+ [
+ Dialogue(
+ instruction="Who is the best player in the history of NBA?",
+ response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
+ ),
+ Dialogue(instruction="continue this talk", response=""),
+ ],
+ 128,
+ "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n",
),
- ([
- Dialogue(
- instruction='Who is the best player in the history of NBA?',
- response=
- 'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
- ),
- Dialogue(instruction='continue this talk', response=''),
- ], 200,
- 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
+ (
+ [
+ Dialogue(
+ instruction="Who is the best player in the history of NBA?",
+ response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
+ ),
+ Dialogue(instruction="continue this talk", response=""),
+ ],
+ 200,
+ "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n",
),
- ([
- Dialogue(
- instruction='Who is the best player in the history of NBA?',
- response=
- 'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
- ),
- Dialogue(instruction='continue this talk', response=''),
- ], 211,
- 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n'
+ (
+ [
+ Dialogue(
+ instruction="Who is the best player in the history of NBA?",
+ response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1",
+ ),
+ Dialogue(instruction="continue this talk", response=""),
+ ],
+ 211,
+ "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n",
),
- ([
- Dialogue(instruction='Who is the best player in the history of NBA?', response=''),
- ], 128,
- 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n'
+ (
+ [
+ Dialogue(instruction="Who is the best player in the history of NBA?", response=""),
+ ],
+ 128,
+ "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n",
),
]
@@ -52,5 +57,5 @@ def test_chat_prompt_processor():
assert prompt == result
-if __name__ == '__main__':
+if __name__ == "__main__":
test_chat_prompt_processor()
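The three token budgets exercised above trace the truncation policy of `ChatPromptProcessor.preprocess_prompt`: with 128 new tokens reserved, the full dialogue history fits in the prompt; at 200, the earlier exchange is dropped and only the final instruction survives; at 211, even the final instruction is token-truncated (to "continue this") so that the generation budget still fits within the processor's `max_len`.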
diff --git a/applications/Chat/inference/utils.py b/applications/Chat/inference/utils.py
index 37944be70a3bf9631f0a995bd1753d71d6c8b5aa..af018adf6e9de84c846281c9aa26a5459da14bf1 100644
--- a/applications/Chat/inference/utils.py
+++ b/applications/Chat/inference/utils.py
@@ -1,9 +1,9 @@
+import json
import re
from threading import Lock
from typing import Any, Callable, Generator, List, Optional
-import json
-import jieba
+import jieba
import torch
import torch.distributed as dist
import torch.nn as nn
@@ -20,9 +20,9 @@ except ImportError:
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
-def prepare_logits_processor(top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- temperature: Optional[float] = None) -> LogitsProcessorList:
+def prepare_logits_processor(
+ top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
+) -> LogitsProcessorList:
processor_list = LogitsProcessorList()
if temperature is not None and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
@@ -41,29 +41,30 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
return unfinished_sequences.max() == 0
-def sample_streamingly(model: nn.Module,
- input_ids: torch.Tensor,
- max_generate_tokens: int,
- early_stopping: bool = False,
- eos_token_id: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- top_k: Optional[int] = None,
- top_p: Optional[float] = None,
- temperature: Optional[float] = None,
- prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
- update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
- **model_kwargs) -> Generator:
-
+def sample_streamingly(
+ model: nn.Module,
+ input_ids: torch.Tensor,
+ max_generate_tokens: int,
+ early_stopping: bool = False,
+ eos_token_id: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ temperature: Optional[float] = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+ **model_kwargs,
+) -> Generator:
logits_processor = prepare_logits_processor(top_k, top_p, temperature)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
for _ in range(max_generate_tokens):
- model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
- 'input_ids': input_ids
- }
+ model_inputs = (
+ prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids}
+ )
outputs = model(**model_inputs)
- next_token_logits = outputs['logits'][:, -1, :]
+ next_token_logits = outputs["logits"][:, -1, :]
# pre-process distribution
next_token_logits = logits_processor(input_ids, next_token_logits)
# sample
@@ -107,27 +108,28 @@ def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
- [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+ )
return model_kwargs
class Dialogue(BaseModel):
- instruction: str = Field(min_length=1, example='Count up from 1 to 500.')
- response: str = Field(example='')
+ instruction: str = Field(min_length=1, example="Count up from 1 to 500.")
+ response: str = Field(example="")
-def _format_dialogue(instruction: str, response: str = ''):
- return f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}'
+def _format_dialogue(instruction: str, response: str = ""):
+ return f"\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}"
-STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))
+STOP_PAT = re.compile(r"(###|instruction:).*", flags=(re.I | re.S))
class ChatPromptProcessor:
- SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.'
+ SAFE_RESPONSE = "The input/response contains inappropriate content, please rephrase your prompt."
- def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str]=[]):
+ def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str] = []):
self.tokenizer = tokenizer
self.context = context
self.max_len = max_len
@@ -138,42 +140,48 @@ class ChatPromptProcessor:
def preprocess_prompt(self, history: List[Dialogue], max_new_tokens: int) -> str:
if self.context_len is None:
- self.context_len = len(self.tokenizer(self.context)['input_ids'])
+ self.context_len = len(self.tokenizer(self.context)["input_ids"])
if self.dialogue_placeholder_len is None:
self.dialogue_placeholder_len = len(
- self.tokenizer(_format_dialogue(''), add_special_tokens=False)['input_ids'])
+ self.tokenizer(_format_dialogue(""), add_special_tokens=False)["input_ids"]
+ )
prompt = self.context
# the last dialogue must be in the prompt
last_dialogue = history.pop()
# the response of the last dialogue is empty
- assert last_dialogue.response == ''
- if len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False)
- ['input_ids']) + max_new_tokens + self.context_len >= self.max_len:
+ assert last_dialogue.response == ""
+ if (
+ len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False)["input_ids"])
+ + max_new_tokens
+ + self.context_len
+ >= self.max_len
+ ):
# to avoid truncate placeholder, apply truncate to the original instruction
- instruction_truncated = self.tokenizer(last_dialogue.instruction,
- add_special_tokens=False,
- truncation=True,
- max_length=(self.max_len - max_new_tokens - self.context_len -
- self.dialogue_placeholder_len))['input_ids']
+ instruction_truncated = self.tokenizer(
+ last_dialogue.instruction,
+ add_special_tokens=False,
+ truncation=True,
+ max_length=(self.max_len - max_new_tokens - self.context_len - self.dialogue_placeholder_len),
+ )["input_ids"]
instruction_truncated = self.tokenizer.decode(instruction_truncated).lstrip()
prompt += _format_dialogue(instruction_truncated)
return prompt
- res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)['input_ids'])
+ res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)["input_ids"])
rows = []
for dialogue in history[::-1]:
text = _format_dialogue(dialogue.instruction, dialogue.response)
- cur_len = len(self.tokenizer(text, add_special_tokens=False)['input_ids'])
+ cur_len = len(self.tokenizer(text, add_special_tokens=False)["input_ids"])
if res_len - cur_len < 0:
break
res_len -= cur_len
rows.insert(0, text)
- prompt += ''.join(rows) + _format_dialogue(last_dialogue.instruction)
+ prompt += "".join(rows) + _format_dialogue(last_dialogue.instruction)
return prompt
def postprocess_output(self, output: str) -> str:
- output = STOP_PAT.sub('', output)
+ output = STOP_PAT.sub("", output)
return output.strip()
def has_censored_words(self, text: str) -> bool:
@@ -182,8 +190,8 @@ class ChatPromptProcessor:
intersection = set(jieba.cut(text.lower())) & self.censored_words
return len(intersection) > 0
-class LockedIterator:
+class LockedIterator:
def __init__(self, it, lock: Lock) -> None:
self.lock = lock
self.it = iter(it)
@@ -195,6 +203,7 @@ class LockedIterator:
with self.lock:
return next(self.it)
+
def load_json(path: str):
with open(path) as f:
- return json.load(f)
\ No newline at end of file
+ return json.load(f)
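Since `prepare_logits_processor` is the core of the sampling path reformatted above, here is a small self-contained illustration of the warper stack it builds; the vocabulary size and hyperparameters below are arbitrary placeholders:

```python
# Temperature rescales the distribution, then top-k/top-p mask the tail,
# then one token is drawn -- the same per-step flow as sample_streamingly.
import torch
from transformers.generation import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

processors = LogitsProcessorList(
    [TemperatureLogitsWarper(0.7), TopKLogitsWarper(50), TopPLogitsWarper(0.95)]
)
input_ids = torch.zeros(1, 4, dtype=torch.long)  # dummy context
next_token_logits = torch.randn(1, 32000)        # dummy vocab logits
filtered = processors(input_ids, next_token_logits)
probs = torch.softmax(filtered, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
```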
diff --git a/applications/Chat/requirements-test.txt b/applications/Chat/requirements-test.txt
index e079f8a6038dd2dc8512967540f96ee0de172067..93d48bcb6f7928b9cd71944be03933c95ddb26a5 100644
--- a/applications/Chat/requirements-test.txt
+++ b/applications/Chat/requirements-test.txt
@@ -1 +1,2 @@
pytest
+colossalai==0.3.3
diff --git a/applications/Chat/requirements.txt b/applications/Chat/requirements.txt
index af7ff67861eb73573489e9ae46f1a2d29eaa02b3..e56aaca0e7cb4344f52e9e1b43cc89110472b85c 100644
--- a/applications/Chat/requirements.txt
+++ b/applications/Chat/requirements.txt
@@ -2,7 +2,7 @@ transformers>=4.20.1
tqdm
datasets
loralib
-colossalai>=0.2.4
+colossalai==0.3.3
torch<2.0.0, >=1.12.1
langchain
tokenizers
@@ -11,3 +11,4 @@ sse_starlette
wandb
sentencepiece
gpustat
+tensorboard
diff --git a/applications/Chat/setup.py b/applications/Chat/setup.py
index a285a6dff4bf9cfe6905494de83a0baacfd795cb..eb44b6203ef8177798830c0e51799163ee8146a8 100644
--- a/applications/Chat/setup.py
+++ b/applications/Chat/setup.py
@@ -2,40 +2,42 @@ from setuptools import find_packages, setup
def fetch_requirements(path):
- with open(path, 'r') as fd:
+ with open(path, "r") as fd:
return [r.strip() for r in fd.readlines()]
def fetch_readme():
- with open('README.md', encoding='utf-8') as f:
+ with open("README.md", encoding="utf-8") as f:
return f.read()
def fetch_version():
- with open('version.txt', 'r') as f:
+ with open("version.txt", "r") as f:
return f.read().strip()
setup(
- name='coati',
+ name="coati",
version=fetch_version(),
- packages=find_packages(exclude=(
- 'tests',
- 'benchmarks',
- '*.egg-info',
- )),
- description='Colossal-AI Talking Intelligence',
+ packages=find_packages(
+ exclude=(
+ "tests",
+ "benchmarks",
+ "*.egg-info",
+ )
+ ),
+ description="Colossal-AI Talking Intelligence",
long_description=fetch_readme(),
- long_description_content_type='text/markdown',
- license='Apache Software License 2.0',
- url='https://github.com/hpcaitech/Coati',
- install_requires=fetch_requirements('requirements.txt'),
- python_requires='>=3.6',
+ long_description_content_type="text/markdown",
+ license="Apache Software License 2.0",
+ url="https://github.com/hpcaitech/Coati",
+ install_requires=fetch_requirements("requirements.txt"),
+ python_requires=">=3.6",
classifiers=[
- 'Programming Language :: Python :: 3',
- 'License :: OSI Approved :: Apache Software License',
- 'Environment :: GPU :: NVIDIA CUDA',
- 'Topic :: Scientific/Engineering :: Artificial Intelligence',
- 'Topic :: System :: Distributed Computing',
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache Software License",
+ "Environment :: GPU :: NVIDIA CUDA",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: System :: Distributed Computing",
],
)
diff --git a/applications/Chat/tests/test_benchmarks.sh b/applications/Chat/tests/test_benchmarks.sh
new file mode 100755
index 0000000000000000000000000000000000000000..3fdb2518134231e2da1a8f4bbc4e5906b43122ef
--- /dev/null
+++ b/applications/Chat/tests/test_benchmarks.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+set -xue
+
+echo "Hint: You can run this script with 'verbose' as the first argument to run all strategies."
+
+if [[ $# -ne 0 && "$1" == "verbose" ]]; then
+ STRATEGIES=(
+ 'ddp'
+ 'colossalai_gemini'
+ 'colossalai_gemini_cpu'
+ 'colossalai_zero2'
+ 'colossalai_zero2_cpu'
+ 'colossalai_zero1'
+ 'colossalai_zero1_cpu'
+ )
+else
+ STRATEGIES=(
+ 'colossalai_zero2'
+ )
+fi
+
+BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
+BENCHMARKS_DIR=$BASE_DIR/benchmarks
+
+echo "[Test]: testing benchmarks ..."
+
+for strategy in ${STRATEGIES[@]}; do
+ torchrun --standalone --nproc_per_node 1 $BENCHMARKS_DIR/benchmark_opt_lora_dummy.py \
+ --model 125m --critic_model 125m --strategy ${strategy} --lora_rank 4 \
+ --num_episodes 2 --num_collect_steps 4 --num_update_steps 2 \
+ --train_batch_size 2 --experience_batch_size 4
+done
diff --git a/applications/Chat/tests/test_checkpoint.py b/applications/Chat/tests/test_checkpoint.py
index 4c05a343169905dc03d91b1f4c39530ec8834848..9c08aa36c9b40eba008ed6c3fb0d1bb6fdef0a4c 100644
--- a/applications/Chat/tests/test_checkpoint.py
+++ b/applications/Chat/tests/test_checkpoint.py
@@ -6,7 +6,8 @@ import pytest
import torch
import torch.distributed as dist
from coati.models.gpt import GPTActor
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
+from coati.models.utils import calc_action_log_probs
+from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from colossalai.nn.optimizer import HybridAdam
@@ -16,39 +17,37 @@ GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
def get_data(batch_size: int, seq_len: int = 10) -> dict:
- input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda')
+ input_ids = torch.randint(0, 50257, (batch_size, seq_len), device="cuda")
attention_mask = torch.ones_like(input_ids)
return dict(input_ids=input_ids, attention_mask=attention_mask)
-def run_test_checkpoint(strategy):
- BATCH_SIZE = 2
+def train_step(strategy: Strategy, actor: GPTActor, actor_optim: HybridAdam, batch_size: int = 8):
+ data = get_data(batch_size)
+ action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool)
+ actor_logits = actor(data["input_ids"], data["attention_mask"])["logits"]
+ action_log_probs = calc_action_log_probs(actor_logits, data["input_ids"], action_mask.size(1))
+ loss = action_log_probs.sum()
+ strategy.backward(loss, actor, actor_optim)
+ strategy.optimizer_step(actor_optim)
- if strategy == 'ddp':
+
+def run_test_checkpoint(strategy_name: str, shard: bool):
+ if strategy_name == "ddp":
strategy = DDPStrategy()
- elif strategy == 'colossalai_gemini':
- strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
- elif strategy == 'colossalai_zero2':
- strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
+ elif strategy_name == "colossalai_gemini":
+ strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5)
+ elif strategy_name == "colossalai_zero2":
+ strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
- raise ValueError(f'Unsupported strategy "{strategy}"')
+ raise ValueError(f"Unsupported strategy '{strategy_name}'")
with strategy.model_init_context():
actor = GPTActor(config=GPT_CONFIG).cuda()
-
actor_optim = HybridAdam(actor.parameters())
-
actor, actor_optim = strategy.prepare((actor, actor_optim))
- def run_step():
- data = get_data(BATCH_SIZE)
- action_mask = torch.ones_like(data['attention_mask'], dtype=torch.bool)
- action_log_probs = actor(data['input_ids'], action_mask.size(1), data['attention_mask'])
- loss = action_log_probs.sum()
- strategy.backward(loss, actor, actor_optim)
- strategy.optimizer_step(actor_optim)
-
- run_step()
+ train_step(strategy, actor, actor_optim)
ctx = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
@@ -57,38 +56,36 @@ def run_test_checkpoint(strategy):
dist.broadcast_object_list(rank0_dirname)
rank0_dirname = rank0_dirname[0]
- model_path = os.path.join(rank0_dirname, 'model.pt')
- optim_path = os.path.join(rank0_dirname, f'optim-r{dist.get_rank()}.pt')
-
- strategy.save_model(actor, model_path, only_rank0=True)
- strategy.save_optimizer(actor_optim, optim_path, only_rank0=False)
-
+    model_path = os.path.join(rank0_dirname, "model" if shard else "model.pt")
+ strategy.save_model(actor, model_path)
+ optim_path = os.path.join(rank0_dirname, "optim" if shard else "optim.pt")
+ strategy.save_optimizer(actor_optim, optim_path)
dist.barrier()
strategy.load_model(actor, model_path, strict=False)
strategy.load_optimizer(actor_optim, optim_path)
-
dist.barrier()
- run_step()
+ train_step(strategy, actor, actor_optim)
-def run_dist(rank, world_size, port, strategy):
- os.environ['RANK'] = str(rank)
- os.environ['LOCAL_RANK'] = str(rank)
- os.environ['WORLD_SIZE'] = str(world_size)
- os.environ['MASTER_ADDR'] = 'localhost'
- os.environ['MASTER_PORT'] = str(port)
- run_test_checkpoint(strategy)
+def run_dist(rank: int, world_size: int, port: int, strategy_name: str, shard: bool):
+ os.environ["RANK"] = str(rank)
+ os.environ["LOCAL_RANK"] = str(rank)
+ os.environ["WORLD_SIZE"] = str(world_size)
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = str(port)
+ run_test_checkpoint(strategy_name, shard)
@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [2])
-@pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini'])
+@pytest.mark.parametrize("world_size", [4])
+@pytest.mark.parametrize("strategy_name", ["ddp", "colossalai_gemini", "colossalai_zero2"])
+@pytest.mark.parametrize("shard", [False, True])
@rerun_if_address_is_in_use()
-def test_checkpoint(world_size, strategy):
- spawn(run_dist, world_size, strategy=strategy)
+def test_checkpoint(world_size: int, strategy_name: str, shard: bool):
+ spawn(run_dist, world_size, strategy_name=strategy_name, shard=shard)
-if __name__ == '__main__':
- test_checkpoint(2, 'colossalai_zero2')
+if __name__ == "__main__":
+ test_checkpoint(2, "colossalai_gemini", shard=False)
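The reworked checkpoint test is stricter than its predecessor: `train_step` now runs a real forward/backward through `calc_action_log_probs`, the model and optimizer state are round-tripped through `save_model`/`save_optimizer` and the matching load calls in both sharded and unsharded layouts, and a second training step after reload confirms the restored state is actually usable, across the `ddp`, `colossalai_gemini`, and `colossalai_zero2` strategies on four ranks.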
diff --git a/applications/Chat/tests/test_data.py b/applications/Chat/tests/test_data.py
deleted file mode 100644
index 2e4d4ceac05fa603b98e4ad1c9b098e221345e83..0000000000000000000000000000000000000000
--- a/applications/Chat/tests/test_data.py
+++ /dev/null
@@ -1,118 +0,0 @@
-import os
-from copy import deepcopy
-
-import pytest
-import torch
-import torch.distributed as dist
-from coati.experience_maker import NaiveExperienceMaker
-from coati.models.base import RewardModel
-from coati.models.gpt import GPTActor, GPTCritic
-from coati.replay_buffer import NaiveReplayBuffer
-from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
-from transformers.models.gpt2.configuration_gpt2 import GPT2Config
-
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-
-GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
-
-
-def get_data(batch_size: int, seq_len: int = 10) -> dict:
- input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda')
- attention_mask = torch.ones_like(input_ids)
- return dict(input_ids=input_ids, attention_mask=attention_mask)
-
-
-def gather_and_equal(tensor: torch.Tensor) -> bool:
- world_size = dist.get_world_size()
- outputs = [torch.empty_like(tensor) for _ in range(world_size)]
- dist.all_gather(outputs, tensor.contiguous())
- for t in outputs[1:]:
- if not torch.equal(outputs[0], t):
- return False
- return True
-
-
-def run_test_data(strategy):
- EXPERINCE_BATCH_SIZE = 4
- SAMPLE_BATCH_SIZE = 2
-
- if strategy == 'ddp':
- strategy = DDPStrategy()
- elif strategy == 'colossalai':
- strategy = ColossalAIStrategy(placement_policy='cuda')
- else:
- raise ValueError(f'Unsupported strategy "{strategy}"')
-
- actor = GPTActor(config=GPT_CONFIG).cuda()
- critic = GPTCritic(config=GPT_CONFIG).cuda()
-
- initial_model = deepcopy(actor)
- reward_model = RewardModel(deepcopy(critic.model)).cuda()
-
- experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model)
- replay_buffer = NaiveReplayBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)
-
- # experience of all ranks should be the same
- for _ in range(2):
- data = get_data(EXPERINCE_BATCH_SIZE)
- assert gather_and_equal(data['input_ids'])
- assert gather_and_equal(data['attention_mask'])
- experience = experience_maker.make_experience(**data,
- do_sample=True,
- max_length=16,
- eos_token_id=50256,
- pad_token_id=50256)
- assert gather_and_equal(experience.sequences)
- assert gather_and_equal(experience.action_log_probs)
- assert gather_and_equal(experience.values)
- assert gather_and_equal(experience.reward)
- assert gather_and_equal(experience.advantages)
- assert gather_and_equal(experience.action_mask)
- assert gather_and_equal(experience.attention_mask)
- replay_buffer.append(experience)
-
- # replay buffer's data should be the same
- buffer_size = torch.tensor([len(replay_buffer)], device='cuda')
- assert gather_and_equal(buffer_size)
- for item in replay_buffer.items:
- assert gather_and_equal(item.sequences)
- assert gather_and_equal(item.action_log_probs)
- assert gather_and_equal(item.values)
- assert gather_and_equal(item.reward)
- assert gather_and_equal(item.advantages)
- assert gather_and_equal(item.action_mask)
- assert gather_and_equal(item.attention_mask)
-
- # dataloader of each rank should have the same size and different batch
- dataloader = strategy.setup_dataloader(replay_buffer)
- dataloader_size = torch.tensor([len(dataloader)], device='cuda')
- assert gather_and_equal(dataloader_size)
- for experience in dataloader:
- assert not gather_and_equal(experience.sequences)
- assert not gather_and_equal(experience.action_log_probs)
- assert not gather_and_equal(experience.values)
- assert not gather_and_equal(experience.reward)
- assert not gather_and_equal(experience.advantages)
- # action mask and attention mask may be same
-
-
-def run_dist(rank, world_size, port, strategy):
- os.environ['RANK'] = str(rank)
- os.environ['LOCAL_RANK'] = str(rank)
- os.environ['WORLD_SIZE'] = str(world_size)
- os.environ['MASTER_ADDR'] = 'localhost'
- os.environ['MASTER_PORT'] = str(port)
- run_test_data(strategy)
-
-
-@pytest.mark.skip
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [2])
-@pytest.mark.parametrize('strategy', ['ddp', 'colossalai'])
-@rerun_if_address_is_in_use()
-def test_data(world_size, strategy):
- spawn(run_dist, world_size, strategy=strategy)
-
-
-if __name__ == '__main__':
- test_data(2, 'colossalai')
diff --git a/applications/Chat/tests/test_dataset.py b/applications/Chat/tests/test_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec61bbb13fd7aa511208904c8bbef0c228af3d25
--- /dev/null
+++ b/applications/Chat/tests/test_dataset.py
@@ -0,0 +1,241 @@
+import json
+import os
+import tempfile
+from typing import Optional
+
+import pytest
+import torch
+from coati.dataset.prompt_dataset import PromptDataset
+from coati.dataset.reward_dataset import HhRlhfDataset, RmStaticDataset
+from coati.dataset.sft_dataset import IGNORE_INDEX, SFTDataset, SupervisedDataset
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
+from datasets import load_dataset
+from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, PreTrainedTokenizer
+from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
+
+SFT_DATASET = [
+ {
+ "instruction": "Provide a list of the top 10 most popular mobile games in Asia",
+ "input": "",
+ "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
+ "id": 0,
+ },
+ {
+ "instruction": "Please provide an action plan for reducing carbon footprint on a corporate level",
+ "input": "",
+ "output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
+ "id": 1,
+ },
+ {
+ "instruction": "Write a persuasive email to your boss explaining why you should have a pay raise",
+ "input": "",
+ "output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
+ "id": 2,
+ },
+]
+
+PROMPT_DATASET = [
+ {
+ "instruction": 'Edit this paragraph to make it more concise: "Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends."',
+ "id": 0,
+ },
+ {"instruction": "Write a descriptive paragraph about a memorable vacation you went on", "id": 1},
+ {"instruction": "Write a persuasive essay arguing why homework should be banned in schools", "id": 2},
+ {"instruction": "Create a chart comparing the statistics on student debt in the United States.", "id": 3},
+]
+
+
+def make_tokenizer(model: str):
+ if model == "gpt2":
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+ tokenizer.pad_token = tokenizer.eos_token
+ elif model == "bloom":
+ tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
+ tokenizer.pad_token = tokenizer.eos_token
+ elif model == "opt":
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+ tokenizer.pad_token = tokenizer.eos_token
+ elif model == "llama":
+ tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+ tokenizer.pad_token = tokenizer.unk_token
+ elif model == "chatglm":
+ tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+ else:
+ raise ValueError(f"Unsupported model '{model}'")
+ return tokenizer
+
+
+def check_content(input_ids_stripped: torch.Tensor, tokenizer: PreTrainedTokenizer, model: str):
+ if model == "opt":
+ # NOTE: Contrary to GPT2, OPT adds the EOS token to the beginning of every prompt.
+ assert input_ids_stripped[0] == tokenizer.eos_token_id
+ input_ids_stripped = input_ids_stripped[1:]
+ elif model == "llama":
+ assert input_ids_stripped[0] == tokenizer.bos_token_id
+ input_ids_stripped = input_ids_stripped[1:]
+ elif model == "chatglm":
+ assert input_ids_stripped[0] == tokenizer.bos_token_id
+ assert input_ids_stripped[-1] == tokenizer.eos_token_id
+ input_ids_stripped = input_ids_stripped[1:-1]
+ assert torch.all(input_ids_stripped != tokenizer.pad_token_id)
+ assert torch.all(input_ids_stripped != tokenizer.bos_token_id)
+ assert torch.all(input_ids_stripped != tokenizer.eos_token_id)
+    # NOTE: when sep/cls token ids are None (as for the tokenizers tested here),
+    # `tensor != None` degrades to a plain bool, so these asserts are safe.
+    assert input_ids_stripped != tokenizer.sep_token_id
+    assert input_ids_stripped != tokenizer.cls_token_id
+ if model == "chatglm":
+ assert torch.all(input_ids_stripped != tokenizer.mask_token_id)
+ else:
+ assert input_ids_stripped != tokenizer.mask_token_id
+
+
+@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
+@pytest.mark.parametrize("max_length", [32, 1024])
+@pytest.mark.parametrize("max_datasets_size", [2])
+def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ dataset_name = "prompt_dataset.json"
+ with open(os.path.join(tmp_dir, dataset_name), "w") as f:
+ json.dump(PROMPT_DATASET, f)
+ tokenizer = make_tokenizer(model)
+ assert tokenizer.padding_side in ("left", "right")
+ prompt_dataset = PromptDataset(
+ data_path=os.path.join(tmp_dir, dataset_name),
+ tokenizer=tokenizer,
+ max_datasets_size=max_datasets_size,
+ max_length=max_length,
+ )
+ assert len(prompt_dataset) == min(max_datasets_size, len(PROMPT_DATASET))
+ for i in range(len(prompt_dataset)):
+ assert isinstance(prompt_dataset[i], dict)
+ assert list(prompt_dataset[i].keys()) == ["input_ids", "attention_mask"]
+ input_ids = prompt_dataset[i]["input_ids"]
+ attention_mask = prompt_dataset[i]["attention_mask"]
+ attention_mask = attention_mask.bool()
+ assert input_ids.shape == attention_mask.shape == torch.Size([max_length])
+ assert torch.all(input_ids[torch.logical_not(attention_mask)] == tokenizer.pad_token_id)
+ check_content(input_ids.masked_select(attention_mask), tokenizer, model)
+
+
+@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
+@pytest.mark.parametrize(
+ ["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"), ("Dahoas/rm-static", None)]
+)
+@pytest.mark.parametrize("max_datasets_size", [32])
+@pytest.mark.parametrize("max_length", [32, 1024])
+def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], max_datasets_size: int, max_length: int):
+ data = load_dataset(dataset_path, data_dir=subset)
+ assert max_datasets_size <= len(data["train"]) and max_datasets_size <= len(data["test"])
+ train_data = data["train"].select(range(max_datasets_size))
+ test_data = data["test"].select(range(max_datasets_size))
+ tokenizer = make_tokenizer(model)
+ assert tokenizer.padding_side in ("left", "right")
+
+ if dataset_path == "Anthropic/hh-rlhf":
+ train_dataset = HhRlhfDataset(train_data, tokenizer, max_length)
+ test_dataset = HhRlhfDataset(test_data, tokenizer, max_length)
+ elif dataset_path == "Dahoas/rm-static":
+ train_dataset = RmStaticDataset(train_data, tokenizer, max_length)
+ test_dataset = RmStaticDataset(test_data, tokenizer, max_length)
+ else:
+ raise ValueError(f'Unsupported dataset "{dataset_path}"')
+
+ assert len(train_dataset) == len(test_dataset) == max_datasets_size
+ for i in range(max_datasets_size):
+ chosen_ids, c_mask, reject_ids, r_mask = train_dataset[i]
+ assert chosen_ids.shape == c_mask.shape == reject_ids.shape == r_mask.shape == torch.Size([max_length])
+ c_mask = c_mask.to(torch.bool)
+ r_mask = r_mask.to(torch.bool)
+ if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id:
+ check_content(chosen_ids.masked_select(c_mask)[:-1], tokenizer, model)
+ assert torch.all(chosen_ids.masked_select(torch.logical_not(c_mask)) == tokenizer.pad_token_id)
+ else:
+ check_content(chosen_ids.masked_select(c_mask), tokenizer, model)
+ assert torch.all(c_mask)
+ if reject_ids.masked_select(r_mask)[-1] == tokenizer.eos_token_id:
+ check_content(reject_ids.masked_select(r_mask)[:-1], tokenizer, model)
+ assert torch.all(reject_ids.masked_select(torch.logical_not(r_mask)) == tokenizer.pad_token_id)
+ else:
+ check_content(reject_ids.masked_select(r_mask), tokenizer, model)
+ assert torch.all(r_mask)
+
+ chosen_ids, c_mask, reject_ids, r_mask = test_dataset[i]
+ assert chosen_ids.shape == c_mask.shape == reject_ids.shape == r_mask.shape == torch.Size([max_length])
+ c_mask = c_mask.to(torch.bool)
+ r_mask = r_mask.to(torch.bool)
+ if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id:
+ check_content(chosen_ids.masked_select(c_mask)[:-1], tokenizer, model)
+ assert torch.all(chosen_ids.masked_select(torch.logical_not(c_mask)) == tokenizer.pad_token_id)
+ else:
+ check_content(chosen_ids.masked_select(c_mask), tokenizer, model)
+ assert torch.all(c_mask)
+ if reject_ids.masked_select(r_mask)[-1] == tokenizer.eos_token_id:
+ check_content(reject_ids.masked_select(r_mask)[:-1], tokenizer, model)
+ assert torch.all(reject_ids.masked_select(torch.logical_not(r_mask)) == tokenizer.pad_token_id)
+ else:
+ check_content(reject_ids.masked_select(r_mask), tokenizer, model)
+ assert torch.all(r_mask)
+
+
+@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"])
+@pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None])
+@pytest.mark.parametrize("max_dataset_size", [2])
+@pytest.mark.parametrize("max_length", [32, 1024])
+def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: int, max_length: int):
+ tokenizer = make_tokenizer(model)
+ if dataset_path == "yizhongw/self_instruct":
+ data = load_dataset(dataset_path, "super_natural_instructions")
+ train_data = data["train"].select(range(max_dataset_size))
+ sft_dataset = SFTDataset(train_data, tokenizer, max_length)
+ else:
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ dataset_name = "sft_dataset.json"
+ with open(os.path.join(tmp_dir, dataset_name), "w") as f:
+ json.dump(SFT_DATASET, f)
+ sft_dataset = SupervisedDataset(
+ tokenizer=tokenizer,
+ data_path=os.path.join(tmp_dir, dataset_name),
+ max_datasets_size=max_dataset_size,
+ max_length=max_length,
+ )
+ assert len(sft_dataset) == min(max_dataset_size, len(SFT_DATASET))
+
+ if isinstance(tokenizer, ChatGLMTokenizer):
+ for i in range(max_dataset_size):
+ assert isinstance(sft_dataset[i], dict)
+ assert list(sft_dataset[i].keys()) == ["input_ids", "labels"]
+ input_ids = sft_dataset[i]["input_ids"]
+ labels = sft_dataset[i]["labels"]
+ assert input_ids.shape == labels.shape == torch.Size([max_length])
+
+ ignore_mask = labels == IGNORE_INDEX
+ assert input_ids.masked_select(torch.logical_not(ignore_mask))[0] == tokenizer.bos_token_id
+ check_content(input_ids.masked_select(torch.logical_not(ignore_mask)), tokenizer, model)
+ return
+
+ for i in range(max_dataset_size):
+ assert isinstance(sft_dataset[i], dict)
+ assert list(sft_dataset[i].keys()) == ["input_ids", "labels", "attention_mask"]
+ input_ids = sft_dataset[i]["input_ids"]
+ labels = sft_dataset[i]["labels"]
+ attention_mask = sft_dataset[i]["attention_mask"].to(torch.bool)
+ assert input_ids.shape == labels.shape == attention_mask.shape == torch.Size([max_length])
+ if input_ids.masked_select(attention_mask)[-1] == tokenizer.eos_token_id:
+ check_content(input_ids.masked_select(attention_mask)[:-1], tokenizer, model)
+ assert torch.all(input_ids.masked_select(torch.logical_not(attention_mask)) == tokenizer.pad_token_id)
+ else:
+ check_content(input_ids.masked_select(attention_mask), tokenizer, model)
+ assert torch.all(attention_mask)
+ ignore_mask = labels == IGNORE_INDEX
+ prompt_mask = torch.logical_and(ignore_mask, attention_mask)
+ check_content(input_ids.masked_select(prompt_mask), tokenizer, model)
+ assert torch.all(input_ids.masked_select(ignore_mask ^ prompt_mask) == tokenizer.pad_token_id)
+
+
+if __name__ == "__main__":
+ test_sft_dataset(model="bloom", dataset_path="yizhongw/self_instruct", max_dataset_size=2, max_length=256)
+
+ test_reward_dataset(
+ model="gpt2", dataset_path="Anthropic/hh-rlhf", subset="harmless-base", max_datasets_size=8, max_length=256
+ )
+
+ test_prompt_dataset(model="opt", max_datasets_size=2, max_length=128)
diff --git a/applications/Chat/tests/test_experience.py b/applications/Chat/tests/test_experience.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9591259800d94e5758cbf31104ab2b114aa5a19
--- /dev/null
+++ b/applications/Chat/tests/test_experience.py
@@ -0,0 +1,130 @@
+import copy
+import os
+
+import pytest
+import torch
+import torch.distributed as dist
+from coati.experience_buffer import NaiveExperienceBuffer
+from coati.experience_maker import NaiveExperienceMaker
+from coati.models.base import RewardModel
+from coati.models.gpt import GPTActor, GPTCritic
+from coati.trainer.ppo import _set_default_generate_kwargs
+from coati.trainer.strategies import DDPStrategy, GeminiStrategy
+from coati.trainer.strategies.colossalai import LowLevelZeroStrategy
+from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
+
+
+def get_data(batch_size: int, seq_len: int = 10) -> dict:
+ input_ids = torch.randint(0, 50257, (batch_size, seq_len), device="cuda")
+ attention_mask = torch.ones_like(input_ids)
+ return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+def gather_and_equal(tensor: torch.Tensor) -> bool:
+ world_size = dist.get_world_size()
+ outputs = [torch.empty_like(tensor) for _ in range(world_size)]
+ dist.all_gather(outputs, tensor.contiguous())
+ for t in outputs[1:]:
+ if not torch.equal(outputs[0], t):
+ return False
+ return True
+
+
+def make_and_consume_experience(strategy):
+ EXPERIENCE_BATCH_SIZE = 4
+ SAMPLE_BATCH_SIZE = 2
+
+ if strategy == "ddp":
+ strategy = DDPStrategy()
+ elif strategy == "colossalai-zero2":
+ strategy = LowLevelZeroStrategy()
+ elif strategy == "colossalai-gemini":
+ strategy = GeminiStrategy(placement_policy="static")
+ else:
+ raise ValueError(f'Unsupported strategy "{strategy}"')
+
+ with strategy.model_init_context():
+ actor = GPTActor(config=GPT_CONFIG).cuda()
+ critic = GPTCritic(config=GPT_CONFIG).cuda()
+
+ initial_model = GPTActor(config=GPT_CONFIG).cuda()
+ reward_model = RewardModel(model=copy.deepcopy(critic.model)).cuda()
+
+ actor, critic, initial_model, reward_model = strategy.prepare(actor, critic, initial_model, reward_model)
+
+ class MockTokenizer:
+ def __init__(self):
+ self.padding_side = "left"
+ self.eos_token_id = 0
+ self.pad_token_id = 0
+
+ tokenizer = MockTokenizer()
+ experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer)
+ data_buffer = NaiveExperienceBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)
+
+ generate_kwargs = dict(do_sample=True, max_length=16)
+ generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
+
+ # experience of all ranks should be the same
+ for _ in range(2):
+ data = get_data(EXPERIENCE_BATCH_SIZE)
+ assert gather_and_equal(data["input_ids"])
+ assert gather_and_equal(data["attention_mask"])
+ experience = experience_maker.make_experience(**data, do_sample=True, max_length=16)
+ assert gather_and_equal(experience.sequences)
+ assert gather_and_equal(experience.action_log_probs)
+ assert gather_and_equal(experience.values)
+ assert gather_and_equal(experience.reward)
+ assert gather_and_equal(experience.advantages)
+ assert gather_and_equal(experience.action_mask)
+ assert gather_and_equal(experience.attention_mask)
+ data_buffer.append(experience)
+
+ # data buffer's data should be the same
+ buffer_size = torch.tensor([len(data_buffer)], device="cuda")
+ assert gather_and_equal(buffer_size)
+ for item in data_buffer.items:
+ assert gather_and_equal(item.sequences)
+ assert gather_and_equal(item.action_log_probs)
+ assert gather_and_equal(item.values)
+ assert gather_and_equal(item.reward)
+ assert gather_and_equal(item.advantages)
+ assert gather_and_equal(item.action_mask)
+ assert gather_and_equal(item.attention_mask)
+
+ # dataloader of each rank should have the same size and different batch
+ dataloader = strategy.setup_dataloader(data_buffer)
+ dataloader_size = torch.tensor([len(dataloader)], device="cuda")
+ assert gather_and_equal(dataloader_size)
+ for experience in dataloader:
+ assert not gather_and_equal(experience.sequences)
+ assert not gather_and_equal(experience.action_log_probs)
+ assert not gather_and_equal(experience.values)
+ assert not gather_and_equal(experience.reward)
+ assert not gather_and_equal(experience.advantages)
+ # action mask and attention mask may be same
+
+
+def run_dist(rank, world_size, port, strategy):
+ os.environ["RANK"] = str(rank)
+ os.environ["LOCAL_RANK"] = str(rank)
+ os.environ["WORLD_SIZE"] = str(world_size)
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = str(port)
+ make_and_consume_experience(strategy)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [2])
+@pytest.mark.parametrize("strategy", ["ddp", "colossalai-zero2", "colossalai-gemini"])
+@rerun_if_address_is_in_use()
+def test_experience(world_size, strategy):
+ spawn(run_dist, world_size, strategy=strategy)
+
+
+if __name__ == "__main__":
+ test_experience(2, "colossalai-zero2")
diff --git a/applications/Chat/tests/test_inference.sh b/applications/Chat/tests/test_inference.sh
new file mode 100755
index 0000000000000000000000000000000000000000..849db06e58abdf9893a544d70ce9312bc5129b7a
--- /dev/null
+++ b/applications/Chat/tests/test_inference.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+
+set -xue
+
+BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
+EXAMPLES_DIR=$BASE_DIR/examples
+
+echo "[Test]: testing inference ..."
+
+# HACK: skip llama due to oom
+for model in 'gpt2' 'bloom' 'opt'; do
+ python $EXAMPLES_DIR/inference.py --model $model
+done
diff --git a/applications/Chat/tests/test_models.py b/applications/Chat/tests/test_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2c22ac6a3b986f5d05f0dc1c54088a021c6d838
--- /dev/null
+++ b/applications/Chat/tests/test_models.py
@@ -0,0 +1,245 @@
+import copy
+from typing import Any, Callable, Dict, Tuple
+
+import pytest
+import torch
+import torch.nn as nn
+from coati.models.base import Actor, Critic, RewardModel, get_base_model
+from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
+from coati.models.chatglm import ChatGLMActor
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
+from coati.models.generation import generate
+from coati.models.gpt import GPTRM, GPTActor, GPTCritic
+from coati.models.llama import LlamaActor
+from coati.models.lora import LoraLinear, convert_to_lora_module
+from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
+from coati.models.opt import OPTRM, OPTActor, OPTCritic
+from coati.models.utils import calc_action_log_probs, masked_mean
+
+
+@pytest.mark.parametrize("batch_size", [4])
+@pytest.mark.parametrize("seq_len", [32])
+@pytest.mark.parametrize(
+ "actor_maker",
+ [
+ lambda: BLOOMActor(),
+ lambda: GPTActor(),
+ # HACK: skip llama due to long execution time
+ # lambda: LlamaActor(),
+ lambda: OPTActor(),
+ ],
+)
+@pytest.mark.parametrize(
+ "generate_kwargs",
+ [
+ {
+ "max_length": 64,
+ "use_cache": True,
+ "do_sample": True,
+ "temperature": 1.0,
+ "top_k": 50,
+ }
+ ],
+)
+def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]):
+ class MockTokenizer:
+ def __init__(self):
+ self.padding_side = "left"
+ self.eos_token_id = 0
+ self.pad_token_id = 0
+
+ actor = actor_maker()
+ input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
+ tokenizer = MockTokenizer()
+ sequences = generate(actor.cuda(), input_ids, tokenizer, **generate_kwargs)
+ assert sequences.shape == (batch_size, generate_kwargs["max_length"])
+
+
+def test_utils():
+ fn_input = {"tensor": torch.ones((10,)), "mask": torch.randint(0, 2, (10,))}
+ fn_output = masked_mean(dim=0, **fn_input)
+ assert fn_output.dim() == 0
+ assert torch.allclose(fn_output, torch.tensor(1.0))
+
+ batch_size = 4
+ seq_len = 32
+ num_labels = 10
+ num_actions = 2
+ fn_input = {
+ "logits": torch.randn((batch_size, seq_len, num_labels)),
+ "sequences": torch.randint(0, num_labels, (batch_size, seq_len)),
+ "num_actions": num_actions,
+ }
+ fn_output = calc_action_log_probs(**fn_input)
+ assert fn_output.shape == (batch_size, num_actions)
+
+
+@pytest.mark.parametrize("lora_rank", [4])
+@pytest.mark.parametrize("num_dim", [32])
+@pytest.mark.parametrize("num_layers", [4])
+def test_lora(lora_rank: int, num_dim: int, num_layers: int):
+ model = nn.ModuleList([nn.Linear(num_dim, num_dim) for _ in range(num_layers)])
+ lora_model = convert_to_lora_module(model, lora_rank)
+ assert isinstance(lora_model, nn.ModuleList)
+ for i in range(num_layers):
+ assert isinstance(lora_model[i], LoraLinear)
+ assert lora_model[i].lora_A.shape == (lora_rank, num_dim)
+ assert lora_model[i].lora_B.shape == (num_dim, lora_rank)
+
+ old_model = copy.deepcopy(lora_model)
+ for i in range(num_layers):
+ assert isinstance(lora_model[i], LoraLinear)
+ assert torch.allclose(old_model[i].weight, lora_model[i].weight)
+ assert torch.allclose(old_model[i].bias, lora_model[i].bias)
+ assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A)
+ optimizer = torch.optim.Adam(lora_model.parameters())
+ x = torch.randn(8, num_dim)
+ for i in range(num_layers):
+ x = lora_model[i](x)
+ loss = x.sum()
+ loss.backward()
+ optimizer.step()
+ for i in range(num_layers):
+ assert isinstance(lora_model[i], LoraLinear)
+ assert torch.allclose(old_model[i].weight, lora_model[i].weight)
+ assert torch.allclose(old_model[i].bias, lora_model[i].bias)
+ assert not torch.allclose(
+ old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A
+ )
+
+
+@pytest.mark.parametrize("batch_size", [8])
+@pytest.mark.parametrize("seq_len", [128])
+@pytest.mark.parametrize(
+ "models_maker",
+ [
+ lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
+ lambda: (GPTActor(), GPTCritic(), GPTRM()),
+ # HACK: skip llama due to long execution time
+ # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
+ lambda: (OPTActor(), OPTCritic(), OPTRM()),
+ lambda: (ChatGLMActor(), None, None),
+ ],
+)
+@torch.no_grad()
+def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], batch_size: int, seq_len: int):
+ actor_input = {
+ "input_ids": torch.randint(0, 100, (batch_size, seq_len)),
+ "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
+ }
+ critic_input = {
+ "sequences": torch.randint(0, 100, (batch_size, seq_len)),
+ "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
+ }
+ rm_input = {
+ "sequences": torch.randint(0, 100, (batch_size, seq_len)),
+ "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
+ }
+
+ actor, critic, rm = models_maker()
+ if isinstance(actor, ChatGLMActor):
+ actor = actor.float()
+ tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+ chatglm_special_token = torch.tensor([tokenizer.gmask_token_id, tokenizer.bos_token_id]).repeat(batch_size, 1)
+ actor_input = {
+ "input_ids": torch.cat(
+ (
+ torch.randint(0, 100, (batch_size, seq_len // 2)),
+ chatglm_special_token,
+ torch.randint(0, 100, (batch_size, seq_len // 2 - 2)),
+ ),
+ dim=1,
+ ),
+ "attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len)),
+ }
+ assert isinstance(actor, Actor)
+ get_base_model(actor)
+ actor_output = actor(**actor_input)
+ assert actor_output.logits.shape[:2] == (batch_size, seq_len)
+
+ if critic:
+ assert isinstance(critic, Critic)
+ get_base_model(critic)
+ critic_output = critic(**critic_input)
+ assert critic_output.shape == (batch_size,)
+
+ if rm:
+ assert isinstance(rm, RewardModel)
+ get_base_model(rm)
+ rm_output = rm(**rm_input)
+ assert rm_output.shape == (batch_size,)
+
+
+@pytest.mark.parametrize("batch_size", [16])
+@pytest.mark.parametrize("seq_len", [128])
+@pytest.mark.parametrize("num_labels", [100])
+def test_loss(batch_size: int, seq_len: int, num_labels: int):
+ loss = GPTLMLoss()
+ loss_input = {
+ "logits": torch.randn(batch_size, seq_len, num_labels),
+ "labels": torch.randint(0, num_labels, (batch_size, seq_len)),
+ }
+ loss(**loss_input)
+
+ loss = PolicyLoss()
+ loss_input = {
+ "log_probs": torch.randn(
+ batch_size,
+ ),
+ "old_log_probs": torch.randn(
+ batch_size,
+ ),
+ "advantages": torch.randn(
+ batch_size,
+ ),
+ }
+ loss(**loss_input)
+
+ loss = ValueLoss()
+ loss_input = {
+ "values": torch.randn(
+ batch_size,
+ ),
+ "old_values": torch.randn(
+ batch_size,
+ ),
+ "reward": torch.randn(
+ batch_size,
+ ),
+ }
+ loss(**loss_input)
+
+ loss = LogSigLoss()
+ loss_input = {
+ "chosen_reward": torch.randn(
+ batch_size,
+ ),
+ "reject_reward": torch.randn(
+ batch_size,
+ ),
+ }
+ loss(**loss_input)
+
+ loss = LogExpLoss()
+ loss_input = {
+ "chosen_reward": torch.randn(
+ batch_size,
+ ),
+ "reject_reward": torch.randn(
+ batch_size,
+ ),
+ }
+ loss(**loss_input)
+
+
+if __name__ == "__main__":
+ generate_kwargs = dict(max_length=40, use_cache=True, do_sample=True, temperature=1.0, top_k=50)
+ test_generation(lambda: LlamaActor(), batch_size=4, seq_len=32, generate_kwargs=generate_kwargs)
+
+ test_utils()
+
+ test_lora(lora_rank=2, num_dim=8, num_layers=2)
+
+ test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128)
+
+ test_loss(batch_size=8, seq_len=128, num_labels=100)
diff --git a/applications/Chat/tests/test_train.sh b/applications/Chat/tests/test_train.sh
new file mode 100755
index 0000000000000000000000000000000000000000..68fca7fbf8c0434cd0883fd0810fc6be74d3fe38
--- /dev/null
+++ b/applications/Chat/tests/test_train.sh
@@ -0,0 +1,233 @@
+#!/usr/bin/env bash
+
+set_n_least_used_CUDA_VISIBLE_DEVICES() {
+ local n=${1:-"9999"}
+ echo "GPU Memory Usage:"
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+ tail -n +2 |
+ nl -v 0 |
+ tee /dev/tty |
+ sort -g -k 2 |
+ awk '{print $1}' |
+ head -n $n)
+ export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+ echo "Now CUDA_VISIBLE_DEVICES is set to:"
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+
+set_n_least_used_CUDA_VISIBLE_DEVICES 4
+
+set -xu
+
+if [ -z "$SFT_DATASET" ]; then
+ echo "Please set \$SFT_DATASET to the path to sft dataset."
+ exit 1
+fi
+
+if [ -z "$PROMPT_DATASET" ]; then
+ echo "Please set \$PROMPT_DATASET to the path to prompts csv."
+ exit 1
+fi
+
+if [ -z "$PRETRAIN_DATASET" ]; then
+ echo "Please set \$PRETRAIN_DATASET to the path to alpaca data."
+ exit 1
+fi
+
+NUM_RETRY=3
+BASE_DIR=$(dirname $(dirname $(realpath $BASH_SOURCE)))
+EXAMPLES_DIR=$BASE_DIR/examples
+MODELS_DIR=$BASE_DIR/examples/models_config
+MODELS=('gpt2' 'bloom' 'opt' 'llama')
+STRATEGIES=('ddp' 'colossalai_gemini' 'colossalai_zero2')
+
+
+export OMP_NUM_THREADS=8
+
+# install requirements
+pip install -r $EXAMPLES_DIR/requirements.txt
+
+python $EXAMPLES_DIR/download_model.py --model-dir $MODELS_DIR --config-only
+
+get_pretrain() {
+ local model=$1
+ if [[ $model == "gpt2" ]]; then
+ echo "gpt2"
+ elif [[ $model == "bloom" ]]; then
+ echo "bigscience/bloom-560m"
+ elif [[ $model == "opt" ]]; then
+ echo "facebook/opt-350m"
+ else
+ echo "Unknown model $model"
+ exit 1
+ fi
+}
+
+random_choice() {
+ local arr=("$@")
+ local len=${#arr[@]}
+ local idx=$((RANDOM % len))
+ echo ${arr[$idx]}
+}
+
+echo "[Test]: testing sft ..."
+
+# FIXME: This is a hack to skip tests that are not working
+# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
+# - llama-*: These tests can be passed locally, skipped for long execution time
+# - *-gemini: Gemini plugin does not support `from_pretrained` yet
+SKIPPED_TESTS=(
+ "gpt2-ddp"
+ "llama-ddp"
+ "llama-colossalai_gemini"
+ "llama-colossalai_zero2"
+)
+
+GRAD_CKPTS=('' '--grad_checkpoint')
+for lora_rank in '0'; do
+ for model in ${MODELS[@]}; do
+ strategies=($(shuf -e "${STRATEGIES[@]}"))
+ for strategy in ${strategies[@]}; do
+ if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then
+ echo "[Test]: Skipped $model-$strategy-$lora_rank"
+ continue
+ elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then
+ echo "[Test]: Skipped $model-$strategy"
+ continue
+ fi
+ pretrain=$(get_pretrain $model)
+ pretrain_model=""
+ if [[ $lora_rank -gt 0 ]]; then
+ pretrain_model="--pretrain $pretrain"
+ fi
+ grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
+ for i in $(seq $NUM_RETRY); do
+ echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
+ torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_sft.py \
+ $pretrain_model --tokenizer $MODELS_DIR/$model \
+ --model $model --strategy $strategy --lora_rank $lora_rank $grad_ckpt \
+ --dataset $SFT_DATASET --max_datasets_size 8 \
+ --max_epochs 1 --batch_size 1 --accumulation_steps 1 --lr 1e-8 \
+ --save_path $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank}
+ passed=$?
+ if [ $passed -eq 0 ]; then
+ break
+ fi
+ done
+ if [ $passed -ne 0 ]; then
+ echo "[Test]: Failed $model-$strategy-$lora_rank"
+ exit 1
+ fi
+ done
+ done
+done
+
+echo "[Test]: testing reward model ..."
+
+# FIXME: This is a hack to skip tests that are not working
+# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
+# - llama-*: These tests can be passed locally, skipped for long execution time
+# - *-gemini: Gemini plugin does not support `from_pretrained` yet
+SKIPPED_TESTS=(
+ "gpt2-ddp"
+ "llama-ddp"
+ "llama-colossalai_gemini"
+ "llama-colossalai_zero2"
+)
+
+LOSS_FNS=('log_sig' 'log_exp')
+DATASETS=('Anthropic/hh-rlhf' 'Dahoas/rm-static')
+for lora_rank in '0'; do
+ for model in ${MODELS[@]}; do
+ strategies=($(shuf -e "${STRATEGIES[@]}"))
+ for strategy in ${strategies[@]}; do
+ if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then
+ echo "[Test]: Skipped $model-$strategy-$lora_rank"
+ continue
+ elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then
+ echo "[Test]: Skipped $model-$strategy"
+ continue
+ fi
+ pretrain=$(get_pretrain $model)
+ pretrain_model=""
+ if [[ $lora_rank -gt 0 ]]; then
+ pretrain_model="--pretrain $pretrain"
+ fi
+ loss_fn=$(random_choice "${LOSS_FNS[@]}")
+ dataset=$(random_choice "${DATASETS[@]}")
+ subset=$(if [[ $dataset == "Dahoas/rm-static" ]]; then echo "None"; else echo "harmless-base"; fi)
+ for i in $(seq $NUM_RETRY); do
+ echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
+ torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_reward_model.py \
+ $pretrain_model --tokenizer $MODELS_DIR/$model \
+ --dataset $dataset --subset $subset --max_datasets_size 8 \
+ --model $model --strategy $strategy --lora_rank $lora_rank \
+ --loss_fn $loss_fn --batch_size 1 --lr 1e-8 \
+ --save_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
+ passed=$?
+ if [ $passed -eq 0 ]; then
+ break
+ fi
+ done
+ if [ $passed -ne 0 ]; then
+ echo "[Test]: Failed to train reward model $model-$strategy-$lora_rank"
+ exit 1
+ fi
+ done
+ done
+done
+
+echo "[Test]: testing RLHF ..."
+
+# FIXME: This is a hack to skip tests that are not working
+# - gpt2-ddp: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
+# - llama-*: These tests can be passed locally, skipped for long execution time
+# - *-gemini: Gemini plugin does not support `from_pretrained` yet
+SKIPPED_TESTS=(
+ "gpt2-ddp"
+ "llama-ddp"
+ "llama-colossalai_gemini"
+ "llama-colossalai_zero2"
+)
+
+for model in ${MODELS[@]}; do
+ for lora_rank in '0'; do
+ strategies=($(shuf -e "${STRATEGIES[@]}"))
+ for strategy in ${strategies[@]}; do
+ if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then
+ echo "[Test]: Skipped $model-$strategy-$lora_rank"
+ continue
+ elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy " ]]; then
+ echo "[Test]: Skipped $model-$strategy"
+ continue
+ fi
+ rm_pretrain=$(get_pretrain $model)
+ rm_pretrain_model=""
+ if [[ $lora_rank -gt 0 ]]; then
+ rm_pretrain_model="--rm_pretrain $rm_pretrain"
+ fi
+ for i in $(seq $NUM_RETRY); do
+ echo "[Test]: $model-$strategy-$lora_rank, attempt $i"
+ torchrun --standalone --nproc_per_node=4 $EXAMPLES_DIR/train_prompts.py \
+ --prompt_dataset $PROMPT_DATASET --pretrain_dataset $PRETRAIN_DATASET --max_datasets_size 32 \
+ --strategy $strategy --model $model --tokenizer $MODELS_DIR/$model \
+ --num_episodes 1 --num_collect_steps 1 --num_update_steps 1 --lr 1e-8 \
+ --experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \
+ --pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \
+ $rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \
+ --save_path $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts
+ passed=$?
+ if [ $passed -eq 0 ]; then
+ break
+ fi
+ done
+ if [ $passed -ne 0 ]; then
+ echo "[Test]: Failed to train RLHF $model-$strategy-$lora_rank"
+ exit 1
+ fi
+ done
+ rm -rf $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank}
+ rm $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
+ done
+done
+rm -rf $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts
diff --git a/applications/Colossal-LLaMA-2/README.md b/applications/Colossal-LLaMA-2/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..34967c04360c7a436b754d76d76e12fb4d0e9237
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/README.md
@@ -0,0 +1,390 @@
+# Colossal-LLaMA-2
+
+## Table of Contents
+- [News](#news)
+- [Colossal-LLaMA-2-7B](#colossal-llama-2-7b)
+ - [Performance Evaluation](#performance-evaluation)
+ - [Examples](#examples)
+ - [Training Logs](#training-logs)
+  - [Import from Transformers (Inference)](#import-from-transformers-inference)
+- [Usage](#usage)
+ - [Install](#install)
+ - [How to run](#how-to-run)
+- [Technical Insights](#technical-insights)
+ - [Data](#data)
+ - [Tokenizer](#tokenizer)
+ - [Training Strategy](#training-strategy)
+ - [Bridging Any Domain-specific Large Models](#bridging-any-domain-specific-large-models)
+- [Citations](#citations)
+
+## News
+* [2023/09] [One Half-Day of Training Using a Few Hundred Dollars Yields Similar Results to Mainstream Large Models, Open-Source and Commercial-Free Domain-Specific LLM Solution](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution)
+[[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA-2)
+[[blog]](https://www.hpc-ai.tech/blog/one-half-day-of-training-using-a-few-hundred-dollars-yields-similar-results-to-mainstream-large-models-open-source-and-commercial-free-domain-specific-llm-solution)
+[[model weights]](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base)
+
+## Colossal-LLaMA-2-7B
+The [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team has introduced the open-source model **Colossal-LLaMA-2-7B-base**. This model, a derivation of LLaMA-2, has undergone continual pre-training involving approximately 8.5 billion tokens over a duration of 15 hours with 64 A800 GPUs. At a cost of **less than $1,000**, you can achieve results **similar to those that cost millions of dollars to pretrain from scratch**. It is licensed under the LLaMA-2 license and [Apache 2.0 License](https://github.com/hpcaitech/ColossalAI/blob/main/LICENSE) **without any additional commercial use restrictions**. This solution can also be used to build models of specific domain knowledge or tasks.
+
+Colossal-LLaMA-2-7B-base is designed to accommodate both the Chinese and English languages, featuring an expansive context window spanning 4096 tokens. Remarkably, it has exhibited exceptional performance when benchmarked against models of equivalent scale in standard Chinese and English evaluation metrics, including C-Eval and MMLU, among others.
+
+❗️**Important notice**:
+* All training data used for this project is collected from well-known public datasets.
+* We do not use any testing data from the evaluation benchmarks for training.
+
+### Performance Evaluation
+We conducted a comprehensive evaluation on 4 datasets and compared our Colossal-LLaMA-2-7b-base model with various other models.
+
+* We use 5-shot for MMLU and calculate scores based on the logits of the first predicted token (a minimal sketch of this first-token scoring appears after the notes below).
+* We use 5-shot for CMMLU and calculate scores based on the logits of the first predicted token.
+* We use 5-shot for AGIEval and only calculate scores for 4-choice questions using a combined metric of exact match and the logits of the first predicted token. If either the exact match or the logits of the first predicted token is correct, the model gets the score.
+* We use 0-shot for GAOKAO-Bench and only calculate scores for 4-choice questions based on the logits of the first predicted token.
+* The generation config for all datasets is greedy search.
+* We also provide CEval scores from its latest leaderboard or from the official repository of each model.
+
+| | Backbone | Tokens Consumed | | MMLU | CMMLU | AGIEval | GAOKAO | CEval |
+| :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :-----: | :----: | :----: | :------------------------------: |
+| | | - | | 5-shot | 5-shot | 5-shot | 0-shot | 5-shot |
+| Baichuan-7B | - | 1.2T | | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 |
+| Baichuan-13B-Base | - | 1.4T | | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 |
+| Baichuan2-7B-Base | - | 2.6T | | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 |
+| Baichuan2-13B-Base | - | 2.6T | | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 |
+| ChatGLM-6B | - | 1.0T | | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 |
+| ChatGLM2-6B | - | 1.4T | | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 |
+| InternLM-7B | - | 1.6T | | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 |
+| Qwen-7B (original) | - | 2.2T | | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 |
+| | | | | | | | | |
+| Llama-2-7B | - | 2.0T | | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - |
+| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | | 37.43 | 29.92 | 32.00 | 27.57 | - |
+| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | | 38.56 | 31.52 | 30.99 | 25.95 | - |
+| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 |
+| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | | 43.73 | 42.04 | 37.64 | 30.61 | - |
+| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | | 48.41 | 38.31 | 38.45 | 27.72 | - |
+| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | | 49.96 | 41.10 | 39.83 | 33.00 | - |
+| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | | 50.25 | 40.99 | 40.04 | 30.54 | - |
+| | | | | | | | | |
+| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | | 53.06 | 49.89 | 51.48 | 58.82 | 50.2 |
+
+> The score in parentheses corresponds to the scores in the official repository of the model.
+>
+> We use zero-shot for ChatGLM models.
+>
+> Qwen-7B is no longer accessible on Hugging Face; we used the latest version available before it was taken down. For the MMLU dataset only, the prompt is "xxx Answer:" (with the space after ":" removed), and we calculate the logits over " A", " B", " C" and " D" for Qwen-7B. Qwen-7B tends to be much more deterministic than other models; for example, the logits over " A" can be `-inf`, making its softmax exactly `0`.
+>
+> For other models and other datasets, we calculate logits over "A", "B", "C" and "D".
+
+❗️ For more details on the evaluation methods and reproduction of the results, please refer to [ColossalEval](https://github.com/hpcaitech/ColossalAI/tree/main/applications/ColossalEval).
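+
+For illustration, the first-predicted-token scoring can be sketched as follows. This is a minimal sketch; the function name and flow are assumptions for illustration, not the actual ColossalEval implementation:
+```python
+# Minimal illustrative sketch of first-predicted-token scoring (not ColossalEval code).
+import torch
+
+def score_choice_question(model, tokenizer, prompt: str, choices=("A", "B", "C", "D")) -> str:
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+    with torch.no_grad():
+        # Logits for the first token the model would predict after the prompt.
+        next_token_logits = model(**inputs).logits[0, -1]
+    # Compare the logits assigned to each candidate answer token.
+    choice_ids = [tokenizer(c, add_special_tokens=False).input_ids[0] for c in choices]
+    return choices[int(torch.argmax(next_token_logits[choice_ids]))]
+```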
+
+### Examples
+| Question Type | Question | Colossal-LLaMA-2-7b-base |
+| :------: | :-----------------------: | :----------- |
+| Knowledge Quiz<br>(Tourism) | 请问武汉的最佳的旅行季节是什么时候? | 3-4月份到武汉赏樱,10月份到武汉看黄陂枫树。武汉的最佳的旅行季节是春秋,3-4月份到武汉赏樱,10月份到武汉看黄陂枫树。武汉,简称“汉”别称江城,是湖北省省会,中部六省唯一的副省级市,特大城市,中国中部地区的中心城市,全国重要的工业基地、科教基地和综合交通枢纽。 |
+| Knowledge Quiz<br>(Computer Science) | 在计算机中,什么是RAM?与ROM有什么区别? | RAM是随机存取内存,可以随时读写,但是断电之后会丢失数据。存储数据比较快,但是需要不断的进行通电供电。ROM是只读内存,只能读取数据,不能写入数据。断电之后,存储的数据也不会丢失,但是需要不断进行通电。两者的价格和性能上会有区别,而且在使用的时候,RAM需要不断的供电,而ROM断电数据不丢失。 |
+| Knowledge Quiz<br>(Literature) | 古代文学与现代文学有哪些区别和共通之处? | 古代文学与现代文学有三点区别,现代文学更加通俗易懂,具有可读性,而古代文学相对来说更加注重形式,缺乏可读性。现代文学和古代文学,同属于中华文学大花园中的一枝花蕾。但是它们由于历史的原因,有着不同的特点,具体表现在,现代文学有更丰富的描写方法,更强烈的主观色彩和更突出的抒情色彩,而古代文学的主要特征是在创作风格上的多样化。 |
+| Emotional Analysis | 请判断这个句子的情绪是否正面:我迫不及待想见到她!!!”。 | 这一句是情绪是正向的。是在期待见面的喜悦情绪。 |
+| Translation | 将以下句子翻译成英语。 我喜欢看电影和读书。 | I like watching movies and reading books. |
+| Logical Reasoning | All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books? | To find the weight of 2 books, we can use the following steps:<br>1. Determine the weight of 1 book.<br>2. Multiply the weight of 1 book by the number of books.<br>Step 1: Determine the weight of 1 book.<br>Since 10 books weigh 5kg, we can divide 5kg by 10 to find the weight of 1 book.<br>5kg ÷ 10 = 0.5kg<br>Step 2: Multiply the weight of 1 book by the number of books.<br>To find the weight of 2 books, we multiply the weight of 1 book by 2.<br>0.5kg × 2 = 1kg<br>So, the weight of 2 books is 1kg. |
+| Information Extraction | The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence. | The date mentioned in the sentence is October 20th to October 23rd, 2021. |
+| Error Correction | Please correct the syntax error and restructure the sentence: "Went to the store me and my friend." | The given sentence "Went to the store me and my friend." is incorrect due to a syntax error and it lacks a subject. The corrected version with proper grammar and sentence structure would be: "My friend and I went to the store." |
+
+❗️ For more examples of question answering, please refer to [Colossal-LLaMA-2-7B-base Examples](docs/example.md).
+
+### Training Logs
+We also recorded the training logs for the experiment:
+
+### Import from Transformers (Inference)
+To load the Colossal-LLaMA-2-7B-base model with Transformers, use the following code:
+```Python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+model = AutoModelForCausalLM.from_pretrained("hpcai-tech/Colossal-LLaMA-2-7b-base", device_map="auto", trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained("hpcai-tech/Colossal-LLaMA-2-7b-base", trust_remote_code=True)
+input = "离离原上草,"
+inputs = tokenizer(input, return_tensors='pt')
+inputs = inputs.to('cuda:0')
+pred = model.generate(**inputs,
+ max_new_tokens=256,
+ do_sample=True,
+ top_k=50,
+ top_p=0.95,
+ num_return_sequences=1)
+print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)[len(input):])
+```
+
+You can also download model weights from [🤗HuggingFace](https://huggingface.co/hpcai-tech/Colossal-LLaMA-2-7b-base).
+
+## Usage
+### Install
+
+#### 0. Pre-requisite
+1. This experiment was performed on 8 computing nodes with 64 A800 GPUs in total for LLaMA-2-7B (**about 1000 USD in cost**). The nodes are connected with RDMA, and GPUs within one node are fully connected with NVLink. The script was tested with CUDA 11.7; CUDA 11.7 or higher is required. You can also complete the training in about 5 days on a single 8×A100/A800 server.
+
+2. PyTorch. The PyTorch version should be greater than 1.12.1 and less than 2.0.0.
+
+
+#### 1. Install required packages
+```bash
+cd Colossal-LLaMA-2
+pip install -r requirements.txt
+```
+#### 2. Install `xentropy`, `layer_norm` and `rotary`
+```bash
+git clone git@github.com:Dao-AILab/flash-attention.git
+# From the flash-attention root folder:
+cd csrc/xentropy && pip install .
+# Return to the flash-attention root folder, then:
+cd csrc/layer_norm && pip install .
+# Return to the flash-attention root folder, then:
+cd csrc/rotary && pip install .
+```
+
+### How to run
+
+#### 1. Init Tokenizer Preparation
+Initialize a new tokenizer with additional Chinese tokens. The additional Chinese tokens are stored in `jsonl` format, as follows (a minimal sketch of the expansion itself appears after the argument list below):
+```json
+{"piece": "你好"}
+{"piece": "人工智能"}
+```
+Command to initialize new tokenizer:
+```bash
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION='python'
+python colossal_llama2/tokenizer/init_tokenizer.py \
+    --source_tokenizer_dir "<SOURCE_TOKENIZER_DIR>" \
+    --target_tokenizer_dir "<TARGET_TOKENIZER_DIR>" \
+    --expand_tokens_file "<EXPAND_TOKENS_FILE>.jsonl"
+```
+Here are the details about the CLI arguments:
+* Source tokenizer directory: `--source_tokenizer_dir`. Directory to the source tokenizer. It should at least contain three files: `special_tokens_map.json`, `tokenizer.model` and `tokenizer_config.json`.
+* Target tokenizer directory: `--target_tokenizer_dir`. Directory to the target tokenizer.
+* Tokens to be added: `--expand_tokens_file`. Additional tokens to be added to the tokenizer.
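+
+Under the hood, the expansion roughly amounts to appending new pieces to the SentencePiece model proto. The following is an illustrative sketch only (the helper name `expand_tokenizer` is hypothetical; see `init_tokenizer.py` for the actual logic):
+```python
+# Illustrative sketch: append new pieces to a SentencePiece model proto.
+import json
+from sentencepiece import sentencepiece_model_pb2 as sp_pb2
+
+def expand_tokenizer(source_model: str, expand_tokens_file: str, target_model: str) -> None:
+    proto = sp_pb2.ModelProto()
+    with open(source_model, "rb") as f:
+        proto.ParseFromString(f.read())
+    existing = {p.piece for p in proto.pieces}
+    with open(expand_tokens_file, encoding="utf-8") as f:
+        for line in f:
+            piece = json.loads(line)["piece"]
+            if piece not in existing:
+                new_piece = proto.pieces.add()  # appended at the end of the vocabulary
+                new_piece.piece = piece
+                new_piece.score = 0.0
+    with open(target_model, "wb") as f:
+        f.write(proto.SerializeToString())
+```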
+
+#### 2. Init Model Preparation
+Initialize the new model checkpoint by calculating the mean values from the original model checkpoint (a minimal sketch of this initialization appears at the end of this step).
+Command to initialize new model checkpoint:
+```bash
+python colossal_llama2/model/init_model.py \
+    --source_model_and_tokenizer_path "<SOURCE_MODEL_AND_TOKENIZER_PATH>" \
+    --target_tokenizer_path "<TARGET_TOKENIZER_PATH>" \
+    --target_model_path "<TARGET_MODEL_PATH>"
+```
+"" can be the same as "".
+
+Here are the details about the CLI arguments:
+* Source model and tokenizer path: `--source_model_and_tokenizer_path`. The source folder contains both the model and the tokenizer, for example, a LLaMA-2 model in Hugging Face format.
+* Target tokenizer path: `--target_tokenizer_path`. Path to the new tokenizer folder generated in the previous step.
+* Target model path: `--target_model_path`. Path to save the new model in Hugging Face format.
+
+❗️**Important**: Once you initialize the new model checkpoint, copy your new tokenizer files (`special_tokens_map.json`, `tokenizer.model` and `tokenizer_config.json`) to your new model folder.
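+
+Conceptually, the mean-value initialization resembles the minimal sketch below (the helper name `expand_embeddings` and the flow are illustrative assumptions; refer to `init_model.py` for the actual implementation):
+```python
+# Illustrative sketch: initialize newly added embedding rows with the mean of
+# the original embeddings. Not the actual init_model.py code.
+import torch
+from transformers import AutoModelForCausalLM
+
+def expand_embeddings(model_path: str, new_vocab_size: int, save_path: str) -> None:
+    model = AutoModelForCausalLM.from_pretrained(model_path)
+    old_weight = model.get_input_embeddings().weight.data.clone()
+    old_vocab_size = old_weight.size(0)
+    mean_row = old_weight.mean(dim=0, keepdim=True)
+    model.resize_token_embeddings(new_vocab_size)  # appends newly initialized rows
+    with torch.no_grad():
+        model.get_input_embeddings().weight.data[old_vocab_size:] = mean_row
+        if model.get_output_embeddings() is not None:
+            model.get_output_embeddings().weight.data[old_vocab_size:] = mean_row
+    model.save_pretrained(save_path)
+```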
+
+#### 3. Data Preparation
+Raw data should be in `jsonl` format. Each data point should have the following fields:
+* `source` (str, compulsory): This part is ignored when calculating loss. It can be empty by default.
+* `target` (str, compulsory): Loss is calculated on this part (see the sketch after the examples below).
+* `category` (str, compulsory): Tags for each data point.
+
+Examples:
+```JSON
+{"source": "", "target": "Lionel Andrés Messi(Spanish pronunciation: [ljoˈnel anˈdɾes ˈmesi] (i); born 24 June 1987), also known as Leo Messi, is an Argentine professional footballer who plays as a forward for and captains both Major League Soccer club Inter Miami and the Argentina national team.", "category": "sports"}
+{"source": "猜谜语:一身卷卷细毛,吃的青青野草,过了数九寒冬,无私献出白毛。(打一动物)", "target": "白羊", "category": "riddle"}
+```
+You are allowed to customize the category tags or use `unknown` to define the category.
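+
+As a rough illustration of the `source`/`target` split above, a data point could be tokenized so that only the `target` tokens contribute to the loss. This is a minimal, hypothetical sketch under assumed tokenization, not the actual dataset code:
+```python
+# Hypothetical sketch: mask `source` tokens so loss is computed on `target` only.
+IGNORE_INDEX = -100  # the default ignore_index of PyTorch's CrossEntropyLoss
+
+def build_example(tokenizer, source: str, target: str, max_length: int = 4096) -> dict:
+    source_ids = tokenizer(source, add_special_tokens=False).input_ids
+    target_ids = tokenizer(target, add_special_tokens=False).input_ids
+    input_ids = (source_ids + target_ids)[:max_length]
+    # Labels mirror input_ids, but source positions are ignored by the loss.
+    labels = ([IGNORE_INDEX] * len(source_ids) + target_ids)[:max_length]
+    return {"input_ids": input_ids, "labels": labels}
+```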
+
+Command to convert jsonl dataset to arrow format:
+```bash
+python prepare_pretrain_dataset.py \
+    --data_input_dirs "<JSONL_DIR_1>,<JSONL_DIR_2>,<JSONL_DIR_3>" \
+    --tokenizer_dir "<TOKENIZER_DIR>" \
+ --data_cache_dir "jsonl_to_arrow_cache" \
+ --data_jsonl_output_dir "spliced_tokenized_output_jsonl" \
+ --data_arrow_output_dir "spliced_tokenized_output_arrow" \
+ --max_length 4096 \
+ --num_spliced_dataset_bins 10
+```
+Here are the details about the CLI arguments:
+* Source data directory: `data_input_dirs`. Each `<JSONL_DIR>` can contain multiple files in `jsonl` format.
+* Tokenizer directory: `tokenizer_dir`. Path to the tokenizer in Hugging Face format.
+* Data cache directory: `data_cache_dir`. Directory to store Hugging Face data cache. Default case will create `cache` folder locally.
+* Output directory for jsonl format: `data_jsonl_output_dir`. Output directory to store converted dataset in jsonl format.
+* Output directory for arrow format: `data_arrow_output_dir`. Output directory to store converted dataset in arrow format, which can be used for training directly.
+* Max length: `max_length`. Max length of spliced samples. Default value is 4096.
+* Number of bins for each category: `num_spliced_dataset_bins`. Used for bucket-based training.
+
+#### 4. Command Line Arguments for Training
+You can use `colossalai run` to launch multi-nodes training:
+```bash
+colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
+train.py --OTHER_CONFIGURATIONS
+```
+Here is a sample hostfile:
+```bash
+hostname1
+hostname2
+hostname3
+hostname4
+```
+Make sure the master node can access all nodes (including itself) via passwordless SSH.
+
+Here are the details about the CLI arguments:
+* Pre-trained model path: `--pretrained`. Path to the pre-trained model in Hugging Face format.
+* Dataset path: `--dataset`. Path to the pre-tokenized dataset.
+* Booster plugin: `--plugin`. `gemini`, `gemini_auto`, `zero2`, `zero2_cpu` and `3d` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins/).
+* Intermediate checkpoint to load: `--load_checkpoint`. Path to the intermediate checkpoint. A saved checkpoint contains the states for `lr_scheduler`, `optimizer`, `running_states.json` and `modelling`. If `load_checkpoint` points to the `modelling` folder, only the model weights will be loaded without any other states, to support multi-stage training.
+* Save interval: `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
+* Checkpoint directory: `--save_dir`. The directory path to save checkpoints and intermediate states. Intermediate states include `lr_scheduler`, `optimizer`, `running_states.json` and `modelling`.
+* Tensorboard directory: `--tensorboard_dir`. The path to save tensorboard logs.
+* Configuration file: `--config_file`. The path to save the configuration file.
+* Number of epochs: `--num_epochs`. Number of training epochs. The default value is 1.
+* Micro batch size: `--micro_batch_size`. Batch size per GPU. The default value is 1.
+* Learning rate: `--lr`. The default value is 3e-4.
+* Max length: `--max_length`. Max context length. The default value is 4096.
+* Mixed precision: `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
+* Gradient clipping: `--gradient_clipping`. The default value is 1.0.
+* Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
+* Warmup steps: `-s`, `--warmup_steps`. The default value is calculated from a 0.025 warmup ratio.
+* Gradient checkpointing: `--use_grad_checkpoint`. The default value is `False`. This saves memory at the cost of speed. Consider enabling this option when training with a large batch size.
+* Flash attention: `--use_flash_attn`. If you want to use flash attention, you must install `flash-attn` and the related packages. The default value is `False`. This helps accelerate training while saving memory. We recommend always using flash attention.
+* Freeze non-embedding parameters: `--freeze_non_embeds_params`. Freeze non-embedding parameters. It can be helpful to align embeddings after extending vocabulary size.
+* Tensor parallelism size: `--tp`. TP size for 3d Parallelism. The default value is 1.
+* Zero stage: `--zero`. Zero stage for 3d Parallelism. The default value is 1.
+
+#### 5. Running Command
+An [example bash script](train.example.sh) is also provided for the experiment. Here are the steps to run the experiment:
+* Create your own hostfile: `cp hostfile.example hostfile`.
+* Create your own bash: `cp train.example.sh train.sh`.
+* Add your real host ip or host name into the `hostfile`.
+* Update global variables and parameters in your `train.sh`.
+* Run the experiment by `bash train.sh`
+
+Here are the details about the global variables for each experiment:
+* `PROJECT_NAME`: Project name for each experiment.
+* `PARENT_SAVE_DIR`: Parent folder to save model checkpoint.
+* `PARENT_TENSORBOARD_DIR`: Parent folder to save tensorboard logs.
+* `PARENT_CONFIG_FILE`: Parent folder to save configuration for each experiment.
+* `PRETRAINED_MODEL_PATH`: Path to the local pre-trained model checkpoint.
+* `dataset`: Paths to all prepared data. Typically, it is a list of subfolders within the output path of the data preparation step (`--data_arrow_output_dir`); if there are multiple subfolders, list them all, e.g.,
+```bash
+declare -a dataset=(
+    "<DIR_1>/part-00000"
+    "<DIR_1>/part-00001"
+    "<DIR_2>/part-00000"
+)
+```
+## Technical Insights
+In order to enhance LLaMA-2's capabilities for understanding and generating Chinese content, the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) team proposes continuing the pre-training of the LLaMA-2 model using both Chinese and English corpora. The overall pipeline can be described as follows:
+
+
+
+
+
+### Data
+Large language models such as LLaMA-2 have undergone training using a heterogeneous blend of high-quality datasets, yielding promising outcomes. Enhancing LLaMA-2's performance for the Chinese corpus, while preserving its proficiency in English, critically hinges on two pivotal factors: the composition of the dataset, which encompasses both English and Chinese content, and the quality of each constituent dataset.
+
+The following figure shows the data processing pipeline conducted for Colossal-LLaMA-2.
+
+
+
+
+❗️**Important**: We will open-source our data-processing toolkit soon. Stay tuned!
+
+### Tokenizer
+The original LLaMA-2 vocabulary comprises fewer than a thousand Chinese characters and thus proves inadequate for encoding Chinese texts comprehensively. Moreover, the use of byte tokens makes it difficult for transformer encoders to capture the semantic nuances of Chinese characters.
+
+To address the above issues, we extend LLaMA-2 vocabulary from 32,000 to 69,104. To adapt the LLaMA-2 model for use with the Colossal-LLaMA-2 tokenizer, we initialize the new word embeddings by calculating the mean values from the original LLaMA-2 embeddings and subsequently append these new rows to the end of the original embedding matrices.
+
+Advantages of extending vocabulary size:
+* Improve the compression rate of string sequence encoding.
+* Enhance the integrity of information.
+* Enable encoded sequences to contain more valuable information, thereby theoretically enhancing the ability for chapter-level encoding.
+
+Disadvantages of large vocabulary size under low-resource settings:
+* The presence of numerous unused tokens can be attributed to the limited training dataset, where an excessive number of tokens might not have been effectively learned.
+* Excessive vocabulary expansion leads to an increase in embedding-related parameters, resulting in higher memory usage, which, in turn, affects the efficiency of the training process.
+
+To balance both sides, we finally constructed our vocabulary with a size of 69,104. The following table presents a comparison of various models at the 7B level.
+
+| Model | Vocabulary Size | Compression Rate | Average Length of Samples (token-level) |
+| :-----------: | :---------: | :----: | :----: |
+| Colossal-LLaMA-2 | 69104 | 0.659 | 73.682 |
+| LLaMA-2-7B | 32000 | 1.205 | 134.689 |
+| Atom-7B | 65000 | 0.634 | 70.915 |
+| Baichuan-7B | 64000 | 0.678 | 75.857 |
+| Baichuan2-7B-base | 125696 | 0.570 | 63.761 |
+| Chatglm2-6B | 64789 | 0.645 | 72.178 |
+| InternLM-7B | 103168 | 0.566 | 63.349 |
+| Qwen-7B | 151643 | 0.578 | 64.703 |
+| Tigerbot-7B-base | 60515 | 0.630 | 70.515 |
+| Yayi-7B-llama2 | 32005 | 1.214 | 135.689 |
+| Chinese-llama-2-7b | 55296 | 0.668 | 74.690 |
+| Chinese-Falcon-7B | 90046 | 0.669 | 74.858 |
+| LinkSoul-Chinese-Llama-2-7b | 40076 | 0.958 | 107.089 |
+| Ziya-LLaMA-13B-v1.1 | 39410 | 0.958 | 107.074 |
+
+
+### Training Strategy
+#### Multi-stage Training
+In order to enhance the model's performance and harness the full potential of the original LLaMA-2, we have developed a multi-stage training strategy. This strategy is designed to systematically unlock the model's capabilities over a series of stages.
+
+Therefore, we have divided the training process into three stages:
+* Large-scale pre-training stage (Conducted by LLaMA-2): This initial stage is aimed at establishing the model's foundational capabilities from the ground up. It necessitates the use of a substantial dataset comprising no less than 1 trillion tokens.
+* Chinese knowledge injection stage: In this stage, we introduce Chinese knowledge into the model. It requires access to a high-quality dataset rich in comprehensive knowledge relevant to the Chinese language.
+* Knowledge replay stage: Knowledge is replayed through a question-answering (QA) mechanism, encompassing both the Chinese and English domains.
+
+Following the completion of this multi-stage training process, the model exhibits notable improvements in performance across both English and Chinese benchmarks.
+
+The following figure illustrates the three stages for training Colossal-LLaMA-2.
+
+
+
+
+
+#### Bucket-based Training
+Our experiments have revealed that the distributions within the training dataset, as well as the arrangement of various topic-related data points, significantly impact the overall performance of the model, particularly in the context of continual pre-training of LLaMA-2.
+
+In an effort to achieve a more balanced distribution and exert control over the dataset's ordering, we have adopted a method where we divide each sub-dataset into discrete bins. These bins are then combined to construct individual data buckets, with one bin contributed by each sub-dataset.
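+
+A minimal sketch of one way to realize this bucketing, assuming each sub-dataset is an in-memory list (the exact binning used in our pipeline may differ):
+
+```python
+import random
+from typing import Dict, List
+
+
+def build_buckets(sub_datasets: Dict[str, list], num_buckets: int) -> List[list]:
+    """Split every sub-dataset into `num_buckets` bins, then assemble each bucket
+    from one bin per sub-dataset so each bucket mirrors the global mixture."""
+    bins = {
+        name: [data[i::num_buckets] for i in range(num_buckets)]
+        for name, data in sub_datasets.items()
+    }
+    buckets = []
+    for i in range(num_buckets):
+        bucket = [sample for name in sub_datasets for sample in bins[name][i]]
+        random.shuffle(bucket)  # shuffle within a bucket; the bucket order stays controlled
+        buckets.append(bucket)
+    return buckets
+```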
+
+### Bridging Any Domain-specific Large Models
+Applying the above process to perform knowledge transfer in any field allows for the cost-effective construction of lightweight domain-specific foundational large models.
+
+
+
+
+
+## Citations
+```bibtex
+@article{bian2021colossal,
+ title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
+ author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
+ journal={arXiv preprint arXiv:2110.14883},
+ year={2021}
+}
+```
+```bibtex
+@misc{touvron2023llama,
+ title={Llama 2: Open Foundation and Fine-Tuned Chat Models},
+ author={Hugo Touvron and Louis Martin and Kevin Stone and Peter Albert and Amjad Almahairi and Yasmine Babaei and Nikolay Bashlykov and Soumya Batra and Prajjwal Bhargava and Shruti Bhosale and Dan Bikel and Lukas Blecher and Cristian Canton Ferrer and Moya Chen and Guillem Cucurull and David Esiobu and Jude Fernandes and Jeremy Fu and Wenyin Fu and Brian Fuller and Cynthia Gao and Vedanuj Goswami and Naman Goyal and Anthony Hartshorn and Saghar Hosseini and Rui Hou and Hakan Inan and Marcin Kardas and Viktor Kerkez and Madian Khabsa and Isabel Kloumann and Artem Korenev and Punit Singh Koura and Marie-Anne Lachaux and Thibaut Lavril and Jenya Lee and Diana Liskovich and Yinghai Lu and Yuning Mao and Xavier Martinet and Todor Mihaylov and Pushkar Mishra and Igor Molybog and Yixin Nie and Andrew Poulton and Jeremy Reizenstein and Rashi Rungta and Kalyan Saladi and Alan Schelten and Ruan Silva and Eric Michael Smith and Ranjan Subramanian and Xiaoqing Ellen Tan and Binh Tang and Ross Taylor and Adina Williams and Jian Xiang Kuan and Puxin Xu and Zheng Yan and Iliyan Zarov and Yuchen Zhang and Angela Fan and Melanie Kambadur and Sharan Narang and Aurelien Rodriguez and Robert Stojnic and Sergey Edunov and Thomas Scialom},
+ year={2023},
+ eprint={2307.09288},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL}
+}
+```
+```bibtex
+@article{dao2023flashattention2,
+  title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
+  author={Dao, Tri},
+  year={2023}
+}
+```
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/__init__.py b/applications/Colossal-LLaMA-2/colossal_llama2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..56fafa58b3f43decb7699b93048b8b87e0f695aa
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/__init__.py
@@ -0,0 +1,2 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/__init__.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..56fafa58b3f43decb7699b93048b8b87e0f695aa
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/__init__.py
@@ -0,0 +1,2 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2cfb2ef6264b52801449ea72ccac0e1d1701bb5
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
@@ -0,0 +1,219 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import numpy as np
+import os
+import random
+from dataclasses import dataclass
+from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
+
+import torch
+from datasets import dataset_dict, load_from_disk
+from datasets import Dataset as HFDataset
+from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import _get_default_group
+from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
+from transformers.tokenization_utils import PreTrainedTokenizer
+import torch.nn.functional as F
+
+DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
+PathType = Union[str, os.PathLike]
+
+
+def load_tokenized_dataset(
+ dataset_paths: Union[PathType, List[PathType]], mode: str = "train"
+) -> Optional[DatasetType]:
+ """
+ Load pre-tokenized dataset.
+    Each instance of the dataset is a dictionary in
+    `{'input_ids': List[int], 'labels': List[int], 'sequence': str}` format.
+ """
+ mode_map = {"train": "train", "dev": "validation", "test": "test"}
+ assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}"
+
+ if isinstance(dataset_paths, (str, os.PathLike)):
+ dataset_paths = [dataset_paths]
+
+ datasets = [] # `List[datasets.dataset_dict.Dataset]`
+ for ds_path in dataset_paths:
+ ds_path = os.path.abspath(ds_path)
+        assert os.path.exists(ds_path), f"File path {ds_path} does not exist"
+ ds_dict = load_from_disk(dataset_path=ds_path, keep_in_memory=False)
+ if isinstance(ds_dict, HFDataset):
+ datasets.append(ds_dict)
+ else:
+ if mode_map[mode] in ds_dict:
+ datasets.append(ds_dict[mode_map[mode]])
+ if len(datasets) == 0:
+ return None
+ if len(datasets) == 1:
+ return datasets.pop()
+ return ConcatDataset(datasets=datasets)
+
+
+@dataclass
+class DataCollatorForSupervisedDataset(object):
+ """
+ Collate instances for supervised dataset.
+ Each instance is a tokenized dictionary with fields
+ `input_ids`(List[int]), `labels`(List[int]) and `sequence`(str).
+ """
+
+ tokenizer: PreTrainedTokenizer
+ max_length: int = 4096
+ ignore_index: int = -100
+
+ def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
+ """
+
+ Args:
+ instances (`Sequence[Dict[str, List[int]]]`):
+ Mini-batch samples, each sample is stored in an individual dictionary.
+
+ Returns:
+ (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
+ `input_ids`: `torch.Tensor` of shape (bsz, max_len);
+ `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
+ `labels`: `torch.Tensor` of shape (bsz, max_len), which contains `IGNORE_INDEX`.
+ """
+ assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
+ f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
+ f"but now `{self.tokenizer.pad_token_id}`"
+ )
+
+ # `List[torch.Tensor]`
+ batch_input_ids = [
+ torch.LongTensor(instance["input_ids"][: self.max_length])
+ if len(instance["input_ids"]) > self.max_length
+ else torch.LongTensor(instance["input_ids"])
+ for instance in instances
+ ]
+ batch_labels = [
+ torch.LongTensor(instance["labels"][: self.max_length])
+ if len(instance["labels"]) > self.max_length
+ else torch.LongTensor(instance["labels"])
+ for instance in instances
+ ]
+
+ if self.tokenizer.padding_side == "right":
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ sequences=batch_input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id,
+ ) # (bsz, max_len)
+ labels = torch.nn.utils.rnn.pad_sequence(
+ sequences=batch_labels,
+ batch_first=True,
+ padding_value=self.ignore_index,
+ ) # (bsz, max_len)
+ # pad to max
+ to_pad = self.max_length - input_ids.size(1)
+ input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
+ labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
+ elif self.tokenizer.padding_side == "left":
+ reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
+ reversed_input_ids = torch.nn.utils.rnn.pad_sequence(
+ sequences=reversed_input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id,
+ ) # (bsz, max_len)
+ input_ids = torch.flip(reversed_input_ids, dims=(1,)) # (bsz, max_len)
+ reversed_labels = [seq.flip(dims=(0,)) for seq in batch_labels]
+ reversed_labels = torch.nn.utils.rnn.pad_sequence(
+ sequences=reversed_labels,
+ batch_first=True,
+ padding_value=self.ignore_index,
+ ) # (bsz, max_len)
+ labels = torch.flip(reversed_labels, dims=(1,)) # (bsz, max_len)
+ else:
+ raise RuntimeError(
+ f"`{self.tokenizer.__class__.__name__}.padding_side` can only be `left` or `right`, "
+ f"but now `{self.tokenizer.padding_side}`"
+ )
+
+ attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # `torch.BoolTensor`, (bsz, max_len)
+
+ return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
+
+
+class StatefulDistributedSampler(DistributedSampler):
+ """
+ Stateful distributed sampler for multi-stage training.
+ """
+
+ def __init__(
+ self,
+ dataset: DatasetType,
+ num_replicas: Optional[int] = None,
+ rank: Optional[int] = None,
+ shuffle: bool = True,
+ seed: int = 0,
+ drop_last: bool = False,
+ ) -> None:
+ super().__init__(
+ dataset=dataset,
+ num_replicas=num_replicas,
+ rank=rank,
+ shuffle=shuffle,
+ seed=seed,
+ drop_last=drop_last,
+ )
+ self.start_index = 0
+
+ def __iter__(self) -> Iterator:
+ iterator = super().__iter__()
+ indices = list(iterator)
+ indices = indices[self.start_index :]
+ return iter(indices)
+
+ def __len__(self) -> int:
+ return self.num_samples - self.start_index
+
+ def set_start_index(self, start_index: int) -> None:
+ self.start_index = start_index
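+
+    # A hypothetical resumption sketch (variable names assumed): after restoring
+    # `sample_start_index` from a checkpoint (see `utils/ckpt_io.py`), skip the
+    # samples already consumed in the current epoch:
+    #
+    #     sampler = StatefulDistributedSampler(dataset, num_replicas=world_size, rank=rank)
+    #     sampler.set_start_index(start_index=sample_start_index)
+    #     dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)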
+
+
+def setup_distributed_dataloader(
+ dataset: DatasetType,
+ batch_size: int = 1,
+ shuffle: bool = False,
+ seed: int = 1024,
+ drop_last: bool = False,
+ pin_memory: bool = False,
+ num_workers: int = 0,
+ collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
+ process_group: Optional[ProcessGroup] = None,
+ **kwargs,
+) -> DataLoader:
+ """
+ Setup dataloader for distributed training.
+ """
+ _kwargs = kwargs.copy()
+ process_group = process_group or _get_default_group()
+ sampler = StatefulDistributedSampler(
+ dataset=dataset,
+ num_replicas=process_group.size(),
+ rank=process_group.rank(),
+ shuffle=shuffle,
+ seed=seed,
+ drop_last=drop_last,
+ )
+
+ # Deterministic dataloader
+ def seed_worker(worker_id: int) -> None:
+ worker_seed = seed
+ np.random.seed(worker_seed)
+ torch.manual_seed(worker_seed)
+ random.seed(worker_seed)
+
+ return DataLoader(
+ dataset=dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ num_workers=num_workers,
+ collate_fn=collate_fn,
+ pin_memory=pin_memory,
+ drop_last=drop_last,
+ worker_init_fn=seed_worker,
+ **_kwargs,
+ )
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c21f325ae6227a9df42da1aec6f1354e72da84b
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/spliced_and_tokenized_dataset.py
@@ -0,0 +1,183 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Splicing multiple pre-tokenized sequence data points
+"""
+
+import random
+import warnings
+from copy import deepcopy
+from datasets import dataset_dict
+from typing import Any, Callable, Dict, Iterable, List, Union, Tuple
+
+from torch.utils.data import ConcatDataset, Dataset, IterableDataset
+from transformers.models.llama.tokenization_llama import LlamaTokenizer
+from transformers.tokenization_utils import PreTrainedTokenizer
+
+IGNORE_INDEX = -100
+
+DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
+
+
+def supervised_tokenize(
+ data_point: Dict[str, str], tokenizer: LlamaTokenizer, ignore_index: int = None, max_length: int = 4096
+) -> Dict[str, Union[int, str, List[int]]]:
+ """
+    A tokenization function to tokenize an original pretraining data point as follows:
+ {"source": "", "target": "Beijing, the capital of the People's Republic of China, ...", "category": "geography"}
+ """
+ assert tokenizer.add_bos_token is False and tokenizer.add_eos_token is False, (
+ "Initially set `tokenizer.add_bos_token` and `tokenizer.add_eos_token` to False, "
+ "add and manually later"
+ )
+ if ignore_index is None:
+ ignore_index = IGNORE_INDEX
+
+ source_text = data_point["source"] # `str`
+ target_text = data_point["target"] # `str`
+ is_null_source = len(source_text) == 0
+
+ source_text = tokenizer.bos_token + source_text
+ target_text += tokenizer.eos_token
+ sequence_text = source_text + target_text
+
+ tokenized = tokenizer([source_text, sequence_text])["input_ids"]
+ sequence_input_ids = tokenized[1]
+ sequence_labels = deepcopy(sequence_input_ids)
+
+ source_length = len(tokenized[0])
+ if not is_null_source:
+ sequence_labels[:source_length] = [ignore_index for _ in range(source_length)]
+
+ # sequence truncation.
+ if len(sequence_input_ids) > max_length:
+ sequence_input_ids = sequence_input_ids[:max_length]
+ sequence_labels = sequence_labels[:max_length]
+
+ return dict(
+ input_ids=sequence_input_ids,
+ labels=sequence_labels,
+ seq_length=len(sequence_input_ids),
+ seq_category=data_point["category"],
+ )
+
+
+class ClosedToConstantLengthSplicedDataset(IterableDataset):
+ """
+ Define an iterable dataset that returns a (close to) constant length data point spliced from multiple
+ original independent (pre-tokenized) data points.
+ """
+
+ def __init__(
+ self,
+ dataset: DSType,
+ tokenizer: PreTrainedTokenizer,
+ max_length: int = 4096,
+ num_packed_sequences: int = 8,
+ fetch_sequence_func: Callable[[Any], Tuple[List[int], List[int]]] = None,
+ input_ids_field: str = "input_ids",
+ labels_field: str = "labels",
+ infinite: bool = False,
+ shuffle: bool = True,
+ error_strict: bool = False,
+ ) -> None:
+ self.tokenizer = tokenizer
+ self.dataset = dataset
+ self.max_length = max_length
+ self.infinite = infinite
+ self.max_buffer_size = max_length * num_packed_sequences # e.g., 4096 * 16
+ self.shuffle = shuffle
+
+ # Callable[[Dict[str, Any]], Tuple[List[int], List[int]]],
+ # A function that fetch sequence input_ids and labels from the original data point
+ if fetch_sequence_func is None:
+ self.fetch_sequence_func = lambda data_point: (data_point[input_ids_field], data_point[labels_field])
+ else:
+ self.fetch_sequence_func = fetch_sequence_func
+ self.input_ids_field = input_ids_field
+ self.labels_field = labels_field
+
+ self.error_strict = error_strict
+ self.current_size = 0 # `int`, current packed data size.
+
+ def __len__(self) -> int:
+ return len(self.dataset)
+
+ def __iter__(self) -> Iterable[Dict[str, List[int]]]:
+ iterator = iter(self.dataset)
+ more_data_points = True
+ while more_data_points is True:
+ buffer, buffer_len = [], 0
+ while True:
+ # ending condition.
+ if buffer_len >= self.max_buffer_size:
+ break
+ try:
+ # `Tuple[List[int], List[int]]`
+ seq_input_ids, seq_labels = self.fetch_sequence_func(next(iterator))
+ buffer.append({self.input_ids_field: seq_input_ids, self.labels_field: seq_labels})
+ buffer_len += len(buffer[-1][self.input_ids_field])
+ except StopIteration:
+ if self.infinite is True:
+ iterator = iter(self.dataset)
+ warnings.warn("The dataset reached end and the iterator is reset to the start.")
+ else:
+ more_data_points = False
+ break
+ examples = [] # `List[Dict[str, List[int]]]`, save buffered spliced data points.
+ spliced_input_ids, spliced_labels = [], [] # `List[int]`, `List[int]`
+ for i, data_point in enumerate(buffer):
+ # TODO(2023-09-18) check errors for each unspliced tokenized data point
+ seq_input_ids = data_point[self.input_ids_field]
+ seq_labels = data_point[self.labels_field]
+ # Handle special case:
+ # If the length of an original data point (i.e., input_ids length of a data point before splicing)
+ # exceeds `max_length`, truncate it.
+ if len(seq_input_ids) > self.max_length:
+ truncated_seq_input_ids = seq_input_ids[: self.max_length]
+ truncated_label_ids = seq_labels[: self.max_length]
+ if set(truncated_label_ids) == {IGNORE_INDEX}:
+ if self.error_strict is True:
+ raise ValueError(
+ f"Find an out-of-bounds length({len(seq_input_ids)}) data point "
+ f"with all label values as {IGNORE_INDEX}."
+ )
+ else:
+ warnings.warn(f"Filter an error truncated data point (labels all {IGNORE_INDEX})")
+ continue # Skip the current error data point.
+ spliced_data_point = {
+ self.input_ids_field: truncated_seq_input_ids,
+ self.labels_field: truncated_label_ids,
+ }
+ examples.append(spliced_data_point)
+ warnings.warn("Find a data point to be truncated.")
+ continue
+
+            # Check whether appending the current sequence would exceed `max_length`.
+ if len(spliced_input_ids) + len(seq_input_ids) > self.max_length:
+ spliced_data_point = {
+ self.input_ids_field: spliced_input_ids,
+ self.labels_field: spliced_labels,
+ } # `Dict[str, List[int]]`
+ # Update.
+ spliced_input_ids, spliced_labels = [], []
+ spliced_input_ids.extend(seq_input_ids)
+ spliced_labels.extend(seq_labels)
+ examples.append(spliced_data_point)
+ else:
+ spliced_input_ids.extend(seq_input_ids)
+ spliced_labels.extend(seq_labels)
+            # Handle the residual spliced data point at the end of the dataset.
+ if self.infinite is False and more_data_points is False and len(spliced_input_ids) > 0:
+ examples.append(
+ {
+ self.input_ids_field: spliced_input_ids,
+ self.labels_field: spliced_labels
+ }
+ )
+ if self.shuffle:
+ random.shuffle(examples)
+ for spliced_data_point in examples:
+ # TODO(2023-09-18): check errors for each spliced tokenized data point.
+ self.current_size += 1
+ yield spliced_data_point
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py b/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..67e487f43b082f0cac6a3466f711c79c132ec80d
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py
@@ -0,0 +1,111 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+Initialize a new model for the updated tokenizer by calculating mean embedding values from the original model
+"""
+import argparse
+
+import numpy as np
+import torch
+from transformers import LlamaTokenizer, LlamaForCausalLM
+
+from colossalai.logging import get_dist_logger
+
+
+logger = get_dist_logger()
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--source_model_and_tokenizer_path",
+ type=str,
+ required=True,
+ default=None,
+ help="Source path of model & tokenizer",
+ )
+ parser.add_argument("--target_tokenizer_path", type=str, required=True, default=None, help="Target tokenizer path")
+ parser.add_argument("--target_model_path", type=str, required=True, default=None, help="Target model path")
+ args = parser.parse_args()
+
+ source_tokenizer = LlamaTokenizer.from_pretrained(args.source_model_and_tokenizer_path)
+ source_tokenizer.add_bos_token = False
+ source_tokenizer.add_eos_token = False
+ if source_tokenizer.pad_token is None:
+ source_tokenizer.pad_token = source_tokenizer.unk_token
+ source_vocab = source_tokenizer.get_vocab()
+
+ target_tokenizer = LlamaTokenizer.from_pretrained(args.target_tokenizer_path)
+ target_tokenizer.add_bos_token = False
+ target_tokenizer.add_eos_token = False
+ if target_tokenizer.pad_token is None:
+ target_tokenizer.pad_token = target_tokenizer.unk_token
+ target_vocab = target_tokenizer.get_vocab()
+ target_inverted_vocab = {v: k for k, v in target_vocab.items()}
+
+ assert len(target_vocab) > len(
+ source_vocab
+ ), f"Target vocab size({len(target_vocab)}) must be greater than source vocab size({len(source_vocab)})"
+
+ gpu_device = torch.device("cuda:0")
+ cpu_device = torch.device("cpu")
+
+ source_model = LlamaForCausalLM.from_pretrained(args.source_model_and_tokenizer_path)
+ source_model.eval()
+ source_model = source_model.to(gpu_device)
+
+ source_input_embeddings = source_model.get_input_embeddings()
+ assert isinstance(source_input_embeddings, torch.nn.Embedding)
+ assert source_input_embeddings.weight.shape[0] == len(source_vocab)
+ source_input_embeddings.eval()
+
+ source_output_embeddings = source_model.get_output_embeddings()
+ assert isinstance(source_output_embeddings, torch.nn.Linear)
+ assert source_output_embeddings.bias is None
+ assert source_output_embeddings.weight.shape[0] == len(source_vocab)
+ source_output_embeddings.eval()
+
+ input_embeddings = source_input_embeddings.weight.cpu().detach().numpy()
+ output_embeddings = source_output_embeddings.weight.cpu().detach().numpy()
+ for i in range(len(source_vocab), len(target_vocab)):
+ if i % 500 == 0:
+ logger.info(f"processing {i}/{len(target_vocab)} target tokens")
+ target_token = target_inverted_vocab[i]
+ target_to_source_token_ids = torch.LongTensor(source_tokenizer([target_token])["input_ids"][0])
+ target_to_source_token_ids = target_to_source_token_ids.to(gpu_device)
+
+ target_to_source_input_embedding = (
+ source_input_embeddings.weight[target_to_source_token_ids]
+ .mean(dim=0)
+ .unsqueeze(dim=0)
+ .cpu()
+ .detach()
+ .numpy()
+ )
+ target_to_source_output_embedding = (
+ source_output_embeddings.weight[target_to_source_token_ids]
+ .mean(dim=0)
+ .unsqueeze(dim=0)
+ .cpu()
+ .detach()
+ .numpy()
+ )
+
+ input_embeddings = np.concatenate((input_embeddings, target_to_source_input_embedding), axis=0)
+ output_embeddings = np.concatenate((output_embeddings, target_to_source_output_embedding), axis=0)
+
+ source_model = source_model.to(cpu_device)
+ assert isinstance(source_model, LlamaForCausalLM)
+
+ # expand
+ source_model.resize_token_embeddings(new_num_tokens=len(target_vocab))
+ source_model.model.embed_tokens.weight.data = torch.Tensor(input_embeddings)
+ source_model.lm_head.weight.data = torch.Tensor(output_embeddings)
+
+ source_model = source_model.half()
+ source_model.save_pretrained(save_directory=args.target_model_path)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py b/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..43297633db1a5b35142ddf04f0e0eb4b8a9d3ccb
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+"""
+Initialize new tokenizer for continual pre-training
+"""
+
+import argparse
+import os
+import json
+from typing import List, Union
+
+from transformers.models.llama.tokenization_llama import LlamaTokenizer
+from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
+
+from colossalai.logging import get_dist_logger
+
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+
+logger = get_dist_logger()
+
+
+def expand_vocab_tokenizer(
+ source_tokenizer_dir: Union[str, os.PathLike], target_tokenizer_dir: Union[str, os.PathLike], new_tokens: List[str]
+) -> None:
+ """Expand tokenizer for continue pre-training."""
+ if os.path.exists(target_tokenizer_dir):
+ raise RuntimeError(f"Find existed directory {target_tokenizer_dir}")
+
+ source_tokenizer = LlamaTokenizer.from_pretrained(source_tokenizer_dir)
+ logger.info(source_tokenizer)
+ source_sp_processor = source_tokenizer.sp_model
+ source_spm = sp_pb2_model.ModelProto()
+ source_spm.ParseFromString(source_sp_processor.serialized_model_proto())
+
+ logger.info(f"Source tokenizer size: {len(source_sp_processor)}")
+
+ # Add new tokens to source tokenizer.
+ source_spm_tokens = set([p.piece for p in source_spm.pieces])
+ for piece in new_tokens:
+ assert isinstance(piece, str), f"Invalid token({piece}) type {type(piece)}"
+ if piece in source_spm_tokens:
+            # Skip existing tokens.
+ continue
+ new_p = sp_pb2_model.ModelProto().SentencePiece()
+ new_p.piece = piece
+ new_p.score = 0
+ source_spm.pieces.append(new_p)
+ logger.info(f"Expand vocab from {len(source_spm_tokens)} to {len(source_spm.pieces)}")
+
+ # Save
+ os.makedirs(target_tokenizer_dir)
+ target_tokenizer_model_path = os.path.join(target_tokenizer_dir, "tokenizer.model")
+ with open(file=target_tokenizer_model_path, mode="wb") as fp:
+ fp.write(source_spm.SerializeToString())
+
+ target_tokenizer = LlamaTokenizer(vocab_file=target_tokenizer_model_path)
+ target_tokenizer.save_pretrained(save_directory=target_tokenizer_dir)
+ logger.info(f"Successfully save expand tokenizer to {target_tokenizer_dir}")
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--source_tokenizer_dir", type=str, required=True, default=None, help="Source tokenizer directory"
+ )
+ parser.add_argument(
+ "--target_tokenizer_dir", type=str, required=True, default=None, help="Target tokenizer directory"
+ )
+ parser.add_argument(
+ "--expand_tokens_file",
+ type=str,
+ required=True,
+ default=None,
+ help="Path of the file containing tokens to be extended",
+ )
+ args = parser.parse_args()
+
+ expand_tokens = []
+ with open(file=args.expand_tokens_file, mode="r", encoding="utf-8") as fp_reader:
+ for line in fp_reader:
+ item = json.loads(line)
+ # e.g., {"piece": "你好"}
+ token = item["piece"]
+ if token in expand_tokens:
+ continue
+ expand_tokens.append(token)
+ expand_tokens.sort(key=lambda t: len(t), reverse=False)
+
+ expand_vocab_tokenizer(
+ source_tokenizer_dir=args.source_tokenizer_dir,
+ target_tokenizer_dir=args.target_tokenizer_dir,
+ new_tokens=expand_tokens,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/__init__.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..56fafa58b3f43decb7699b93048b8b87e0f695aa
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/__init__.py
@@ -0,0 +1,2 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py
new file mode 100644
index 0000000000000000000000000000000000000000..85decf37dd0b084197d9eeffcec77edb900365e0
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/ckpt_io.py
@@ -0,0 +1,88 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+Helper functions for IO
+"""
+
+import json
+import os
+from typing import Any, Dict, Tuple, Union
+
+import torch
+from torch.optim.optimizer import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+
+from colossalai.booster import Booster
+from colossalai.cluster import DistCoordinator
+
+
+def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
+ """
+ Load file in JSON format
+ """
+ with open(file=file_path, mode="r", encoding="utf-8") as fp:
+ return json.load(fp)
+
+
+def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
+ """
+ Save as JSON format
+ """
+ with open(file=file_path, mode="w", encoding="utf-8") as fp:
+ json.dump(data, fp=fp, ensure_ascii=False, indent=4)
+
+
+def save_checkpoint(
+ save_dir: Union[str, os.PathLike],
+ booster: Booster,
+ model: torch.nn.Module,
+ optimizer: Optimizer,
+ lr_scheduler: _LRScheduler,
+ epoch: int,
+ step: int,
+ batch_size: int,
+ coordinator: DistCoordinator,
+) -> None:
+ """
+    Save model checkpoint, optimizer, LR scheduler and intermediate running states.
+ """
+
+ save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
+ os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
+
+ booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
+
+ booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
+ booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
+ running_states = {
+ "epoch": epoch,
+ "step": step,
+ "sample_start_index": step * batch_size,
+ }
+ if coordinator.is_master():
+ save_json(running_states, os.path.join(save_dir, "running_states.json"))
+
+
+def load_checkpoint(
+ load_dir: Union[str, os.PathLike],
+ booster: Booster,
+ model: torch.nn.Module,
+ optimizer: Optimizer,
+ lr_scheduler: _LRScheduler,
+) -> Tuple[int, int, int]:
+ """
+ Load model checkpoint, optimizer, LR scheduler and intermedidate running states.
+ """
+
+ # Update booster params states.
+ booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
+ booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
+ booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
+
+ running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
+ return (
+ running_states["epoch"],
+ running_states["step"],
+ running_states["sample_start_index"],
+ )
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c58c59307a6a6b2d1523e67633548edeb0e354e
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
@@ -0,0 +1,216 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from types import MethodType
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from transformers.models.llama.modeling_llama import (
+ LlamaRMSNorm,
+ LlamaAttention,
+ LlamaModel,
+ LlamaForCausalLM,
+ apply_rotary_pos_emb,
+ repeat_kv,
+)
+
+from colossalai.logging import get_dist_logger
+from einops import rearrange
+
+from flash_attn.bert_padding import pad_input, unpad_input
+from flash_attn.flash_attn_interface import (
+ flash_attn_func,
+ flash_attn_varlen_kvpacked_func,
+)
+from flash_attn.ops.rms_norm import rms_norm
+
+
+logger = get_dist_logger()
+
+
+def _prepare_decoder_attention_mask(
+ self: LlamaModel,
+ attention_mask: torch.BoolTensor,
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+) -> Optional[torch.Tensor]:
+ """
+    Build the decoder attention mask.
+ """
+ if past_key_values_length > 0 and attention_mask is not None:
+ attention_mask = torch.cat(
+ tensors=(
+ torch.full(
+ size=(input_shape[0], past_key_values_length),
+ fill_value=True,
+ dtype=attention_mask.dtype,
+ device=attention_mask.device,
+ ),
+ attention_mask,
+ ),
+ dim=-1,
+ ) # (bsz, past_key_values_length + q_len)
+ if attention_mask is not None and torch.all(attention_mask):
+ return None # Faster
+ return attention_mask
+
+
+def attention_forward(
+ self: LlamaAttention,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """
+ Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
+ """
+ if output_attentions:
+ logger.warning(
+ "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
+ "return `None` instead."
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ if self.config.pretraining_tp > 1:
+ q_slicing, kv_slicing = (
+ dim // self.config.pretraining_tp
+ for dim in (
+ self.num_heads * self.head_dim,
+ self.num_key_value_heads * self.head_dim,
+ )
+ ) # `Tuple[int, int]`
+ q_slices, k_slices, v_slices = (
+ proj.weight.split(slicing, dim=0)
+ for proj, slicing in (
+ (self.q_proj, q_slicing),
+ (self.k_proj, kv_slicing),
+ (self.v_proj, kv_slicing),
+ )
+ ) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
+ q, k, v = (
+ torch.cat(
+ [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
+ dim=-1,
+ )
+ for slices in (q_slices, k_slices, v_slices)
+ )
+ # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
+ # (bsz, q_len, num_heads * head_dim),
+ # (bsz, q_len, num_key_value_heads * head_dim),
+ # (bsz, q_len, num_key_value_heads * head_dim)
+ else:
+ q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
+ # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
+ # (bsz, q_len, num_heads * head_dim),
+ # (bsz, q_len, num_key_value_heads * head_dim),
+ # (bsz, q_len, num_key_value_heads * head_dim)
+
+ # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
+ # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
+ # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
+ q, k, v = (
+ states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
+ for states, num_heads in (
+ (q, self.num_heads),
+ (k, self.num_key_value_heads),
+ (v, self.num_key_value_heads),
+ )
+ )
+ kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
+ past_kv_len = 0
+ if past_key_value is not None:
+ # if `past_key_value` is not None, `kv_len` > `q_len`.
+ past_kv_len = past_key_value[0].shape[-2]
+ kv_len += past_kv_len
+
+ # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
+ cos, sin = self.rotary_emb(v, seq_len=kv_len)
+ # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
+ q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ k = torch.cat([past_key_value[0], k], dim=2)
+ v = torch.cat([past_key_value[1], v], dim=2)
+
+ past_key_value = (k, v) if use_cache else None
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
+ # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
+ v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
+ # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
+
+ key_padding_mask = attention_mask
+ # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
+ q, k, v = (states.transpose(1, 2) for states in (q, k, v))
+
+ if past_kv_len > 0:
+ q = torch.cat(
+ tensors=(
+ torch.full(
+ size=(bsz, past_kv_len, self.num_heads, self.head_dim),
+ fill_value=0.0,
+ dtype=q.dtype,
+ device=q.device,
+ ),
+ q,
+ ),
+ dim=1,
+ ) # (bsz, past_kv_len + q_len, num_heads, head_dim)
+
+ if key_padding_mask is None:
+ # (bsz, past_kv_len + q_len, num_heads, head_dim)
+        output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True)
+ output = rearrange(output, pattern="... h d -> ... (h d)") # (bsz, past_kv_len + q_len, num_heads * head_dim)
+ else:
+ q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
+ kv, _, cu_kv_lens, max_kv_len = unpad_input(
+ hidden_states=torch.stack(tensors=(k, v), dim=2),
+ attention_mask=key_padding_mask,
+ )
+ output_unpad = flash_attn_varlen_kvpacked_func(
+ q=q,
+ kv=kv,
+ cu_seqlens_q=cu_q_lens,
+ cu_seqlens_k=cu_kv_lens,
+ max_seqlen_q=max_q_len,
+ max_seqlen_k=max_kv_len,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=True,
+ )
+ output = pad_input(
+ hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
+ indices=indices,
+ batch=bsz,
+ seqlen=past_kv_len + q_len,
+ ) # (bsz, past_kv_len + q_len, num_heads * head_dim)
+
+ if past_kv_len > 0:
+ # Strip off the zero query outputs.
+ output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
+ output = self.o_proj(output) # (bsz, q_len, hidden_size)
+ return output, None, past_key_value
+
+
+def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
+ """
+    Forward function for RMSNorm.
+ """
+ return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
+
+
+def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
+ for name, module in model.named_modules():
+ if isinstance(module, LlamaAttention):
+ module.forward = MethodType(attention_forward, module)
+ if isinstance(module, LlamaModel):
+ module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
+ if isinstance(module, LlamaRMSNorm):
+ module.forward = MethodType(rms_norm_forward, module)
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py
new file mode 100644
index 0000000000000000000000000000000000000000..82677160d868301b357f83241fd4ae1592d0b841
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/froze.py
@@ -0,0 +1,18 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from transformers.models.llama import LlamaForCausalLM
+
+
+def freeze_non_embeds_parameters(model: LlamaForCausalLM) -> None:
+ """Freeze all parameters except embeddings."""
+ for name, params in model.named_parameters():
+ if "embed_tokens" not in name and "lm_head" not in name:
+ params.requires_grad = False
+ else:
+ params.requires_grad = True
+
+
+def unfreeze_parameters(model: LlamaForCausalLM) -> None:
+    """Unfreeze all parameters."""
+    for params in model.parameters():
+        params.requires_grad = True
diff --git a/applications/Colossal-LLaMA-2/docs/example.md b/applications/Colossal-LLaMA-2/docs/example.md
new file mode 100644
index 0000000000000000000000000000000000000000..d889ab4165d0f676fecb8a60e83a6163757470ce
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/docs/example.md
@@ -0,0 +1,245 @@
+# Colossal-LLaMA-2-7B-base Examples
+To comprehensively assess the performance of the Colossal-LLaMA-2-7B-base model, our team conducted human evaluations across various knowledge domains and tasks. These tasks encompassed Knowledge QA in 10 different areas, Content Generation, Brainstorming, Summarization, Sentiment Analysis, Logical Reasoning, Information Extraction, Role-play, Classification, and Rewriting. We also conducted a comparative analysis, pitting the generation results of Colossal-LLaMA-2-7B-base against Qwen-7b-base, a recently-pretrained model known for its impressive performance. It's worth noting that both models shared identical inference parameters. For each question, we generated model responses multiple times and selected the best one as the final answer.
+
+
+
+## Table of Contents
+- [Examples](#knowledge-qa)
+ - [Knowledge QA](#knowledge-qa)
+ - [Travel](#travel)
+    - [Medical](#medical)
+    - [Education](#education)
+ - [Food](#food)
+ - [Computer Science](#computer-science)
+ - [Game](#game)
+ - [Sports](#sports)
+ - [Finance](#finance)
+ - [Law](#law)
+ - [Technology](#technology)
+ - [Literature](#literature)
+ - [Content Generation](#content-generation)
+ - [Writing](#writing)
+ - [Code](#code)
+ - [E-mail](#e-mail)
+    - [LaTeX Table](#latex-table)
+ - [Brainstorming](#brainstorming)
+ - [Summarization](#summarization)
+ - [Sentiment Analysis](#sentiment-analysis)
+ - [Logical Reasoning](#logical-reasoning)
+ - [Information Extraction](#information-extraction)
+ - [Role-play](#role-play)
+ - [Classification](#classification)
+ - [Rewriting](#rewriting)
+- [Conclusion](#conclusion)
+
+
+### Knowledge QA
+#### Travel
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 中国最适合家庭度假的旅游目的地是哪里? | 1. 中国最适合家庭度假的旅游目的地是海南。海南是中国唯一的热带岛屿省份,素有“东方夏威夷”之称。2. 三亚。三亚市位于海南岛的最南端,北与东方市交界,东与陵水县相邻,南与乐东县接壤,西濒南海,是一个风光绮丽的热带滨海旅游城市。 | 中国有非常多的适合家庭度假的目的地,以下是一些推荐:<br>1. 海南岛:海南岛有得天独厚的气候和风景,是一个度假的好去处,不仅有美丽的海滩和沙滩,还有很多可以游玩的景点,比如南山寺、五指山等。<br>2. 三亚:三亚是一个受欢迎的旅游城市,有美丽的海滩和浪漫的夕阳,还有丰富的海鲜美食。您可以在这里参加水上活动,比如滑水、帆板等。<br>3. 青岛:青岛是一个充满活力和美丽的城市,有美丽的海滩、博物馆、美食和啤酒,可以带孩子们一起去海边玩沙子,也可以去博物馆参观。<br>4. 成都:成都是中国西南地区的一个大城市,不仅有丰富的历史和文化,还有美食和美景。您可以去到大熊猫基地和锦里古街,还可以品尝到地道的川菜。<br>5. 长沙:长沙是一个历史悠久的城市,有丰富多彩的文化和历史景点,比如岳麓山、橘子洲头和岳麓书院等。此外,长沙的美食也是非常有名的,您可以在长沙尝到辣子鸡等当地美食。<br>以上是几个适合家庭度假的目的地,希望能给您提供一些参考。 |
+| 请问武汉的最佳的旅行季节是什么时候? | 3-4月份到武汉赏樱,10月份到武汉看黄陂枫树。武汉的最佳的旅行季节是春秋,3-4月份到武汉赏樱,10月份到武汉看黄陂枫树。武汉,简称“汉”别称江城,是湖北省省会,中部六省唯一的副省级市,特大城市,中国中部地区的中心城市,全国重要的工业基地、科教基地和综合交通枢纽。 | Assistant:<br>武汉的最佳旅行季节是春季和秋季。春季气温宜人,可以赏樱花,秋季天气宜人,可以欣赏红叶。夏季温度较高,游客众多,冬季则有雪景和温泉可赏,但冬季的交通和住宿可能会更加拥挤。 |
+| What are amazing islands near Singapore that could be visited? | Singapore is one of the most developed and well-connected cities in the world, with an impressive skyline and modern infrastructure. However, there are many amazing islands near Singapore that are worth visiting. Here are some of the most popular and beautiful islands near Singapore:<br>1. Sentosa Island: Sentosa Island is a small island located off the southern coast of Singapore. It is a popular tourist destination, known for its beaches, theme parks, and lush greenery. Visitors can enjoy a range of activities on the island, including sunbathing, swimming, and exploring the jungles.<br>2. Pulau Ubin: Pulau Ubin is a small island located about 45 minutes away from Singapore by boat. It is a popular destination for nature lovers, offering hiking trails, cycling paths, and a variety of wildlife. Visitors can also explore the fishing village and try their hand at fishing.<br>3. Pulau Tekong: Pulau Tekong is a small island located about 30 minutes away from Singapore by boat. It is a former military base, but is now a popular destination for nature lovers. Visitors can enjoy hiking trails, cycling paths, and a variety of wildlife.<br>4. Pulau Sibuyu: Pulau Sibuyu is a small island located about 20 minutes away from Singapore by boat. It is a popular destination for water sports enthusiasts, offering snorkeling, diving, and kayaking. Visitors can also explore the fishing village and try their hand at fishing.<br>5. Pulau Rapa: Pulau Rapa is a small island located about 20 minutes away from Singapore by boat. It is a popular destination for birdwatchers, offering a variety of bird species and birdwatching spots. Visitors can also explore the fishing village and try their hand at fishing.<br>Overall, these are some of the amazing islands near Singapore that are worth visiting. Whether you are looking for adventure, relaxation, or simply want to explore the natural beauty of Singapore, these islands offer something for everyone. | "What are some good islands near Singapore with a short distance? Which one is recommended?"<br>Please tell me if those questions are the same.<br>OPTIONS:<br>- no<br>- yes<br>no |
+| Can you recommend local place to eat breakfast and dinner in Shanghai? | Yes, I can certainly recommend some great local places to eat breakfast and dinner in Shanghai. Here are some suggestions:<br>1. Dumpling House: This popular restaurant serves delicious dumplings and other Chinese dishes for breakfast and dinner. They have multiple locations throughout Shanghai, so you can easily find one near you.<br>2. Yuyuan Brewers Fresh Beer Restaurant: This unique restaurant offers a variety of fresh beers and delicious food, including breakfast and dinner options. It's a fun and relaxed atmosphere to enjoy your meal.<br>3. Zenbox Coffee & Bakery: This trendy coffee shop is the perfect spot to start your day with a delicious breakfast sandwich or pastry. They also have a great selection of lunch options, so you can continue enjoying your meal in the afternoon.<br>4. Shanghai Taste: This authentic Shanghai restaurant offers traditional Chinese dishes, including noodles and stir-fry dishes, for breakfast and dinner. The cozy atmosphere and friendly service make it a great choice for a meal.<br>5. XLR8 Family Cafe: This family-friendly cafe offers breakfast and dinner options, including a healthy salad bar. They also have a kids' menu with plenty of options for little ones.<br>These are just a few examples of the many great places to eat breakfast and dinner in Shanghai. With so many options to choose from, you're sure to find something that suits your taste and budget. | I am traveling on budget and will have access to public transportation to reach the location.<br>Assistant: Sure! Here are some local places to eat breakfast and dinner in Shanghai that you might want to consider:<br>Breakfast:<br>1. Fresh on the Square - They have a variety of breakfast options, including sandwiches, smoothies, and porridge. The prices are affordable, and they have a location at People's Square.<br>2. The Dumpling House - This restaurant is known for its delicious dumplings and offers an extensive menu of Chinese breakfast items, such as baozi, soy milk, and jianbing. The location is near the Jing'an Temple subway station.<br>3. Shanghainese Roast Duck - This restaurant is famous for its crispy duck, but they also have breakfast items like rice porridge, nian gao (sticky rice cake), and soy milk. The location is near the Xinzhuang Road subway station.<br>Dinner:<br>1. Lost Heaven - This restaurant offers a fusion of Chinese and Western cuisine in a beautiful garden setting. Prices are on the higher side, but the quality of food is excellent.<br>2. Shanghainese Restaurant - This restaurant is a must-try if you want to experience authentic Shanghai food. They have a variety of Shanghainese dishes, such as hairy crab, xiao long bao (steamed soup dumplings), and braised pork belly. The location is near the Hongqiao Railway Station.<br>3. Lao Da Zhou - This restaurant is known for its delicious xiao long bao and other Shanghainese dishes. Prices are reasonable, and they have a location near the Zhongshan Park subway station.<br>I hope this helps you plan your budget-friendly food itinerary in Shanghai! |
+
+
+
+#### Medical
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 甲状腺手术后多久可以干活? | 甲状腺手术后三个月左右可以干活。甲状腺手术要根据患者的病情和手术方式来决定,如果患者的症状较轻,病情比较轻微,选择的是微创手术,一般手术后一个月就可以干活。若患者的症状比较严重,病情得不到控制,需要对甲状腺部位进行常规手术来控制病情,这种手术的创伤比较大,可能需要两个月左右才能完全恢复,恢复后才能干活。平时要注意不能让甲状腺部位受到外力刺激,以免发生疼痛和出血症状,饮食要以清淡为主,尽量选择易消化的食物来帮助身体恢复,避免摄入一些辛辣刺激类的食品,要多吃蔬菜水果,多喝温水。 | 这个问题需要根据具体情况而定。一般来说,甲状腺手术后需要至少休息2-4周,期间不能剧烈运动,同时还要注意术后恢复和饮食调理,以保证身体恢复。在医生的指导下,逐渐恢复工作和生活。建议术后恢复期间多休息,避免过度疲劳,多食水果、蔬菜等清淡食物,保持良好的心态,定期复查,以确保术后恢复良好。 |
+| 喉咙发炎引起的发烧38.5摄氏度,怎么办? | 喉咙发炎导致发烧38.5摄氏度,应该要根据实际病情采取以下几种措施:1.对于发热38.5摄氏度,要及时遵医嘱给予退热药物治疗,例如布洛芬胶囊、尼美舒利片等,同时多喝水,有利于体温的快速下降。2.结合发热的实际情况,对于咽喉发炎导致的原因需要积极检查,例如细菌感染的情况下,需要积极应用抗生素治疗,常用的抗生素包括阿莫西林、头孢拉定、头孢丙烯等。3.在饮食上需要注意,不能吃辛辣刺激性的食物,例如辣椒等。需要多吃一些清淡易消化的食物,例如面条、青菜、豆腐、小米粥等。4.如果嗓子痛的比较严重,可以选择布地奈德混悬液或者地塞米松注射液进行雾化吸入治疗。以上方案仅供参考,具体药品使用请结合自身情况在专业医生指导下用药。 | 请医生帮忙。<br>您好,根据您提供的症状,建议您尽快去医院就诊,让医生进行详细的检查和诊断。对于喉咙发炎引起的发烧,可以采取以下措施缓解症状:<br>1. 多喝水:保持充足的水分摄入可以缓解喉咙疼痛和喉咙发炎的症状。<br>2. 休息:多休息可以帮助身体恢复,缓解疲劳和疼痛。<br>3. 温水漱口:用温水漱口可以缓解喉咙疼痛和喉咙发炎的症状。<br>4. 温度适宜:保持室内温度适宜,避免过度受凉。<br>5. 吃清淡易消化的食物:避免辛辣、油腻、刺激性食物,多吃清淡易消化的食物。<br>请注意,以上方法只是缓解症状的方法,不能替代医生的诊断和治疗。建议您尽快就医,以便早日恢复健康。 |
+| How can one treat intense pain in the stomach while suffering from ulcer? | Intense pain in the stomach is a common symptom of ulcers, and it can be quite uncomfortable and painful. It is essential to seek medical attention as soon as possible to obtain an accurate diagnosis and appropriate treatment. Treatment for ulcer-related pain usually involves managing pain with medication, addressing the underlying cause of the ulcer, and maintaining a healthy lifestyle.<br>1. Manage pain with medication:<br>There are several over-the-counter pain relievers available that can help alleviate symptoms. Examples of these medications include nonsteroidal anti-inflammatory drugs (NSAIDs), such as ibuprofen and naproxen, as well as acetaminophen (Tylenol). However, it is important to consult a healthcare professional before taking any pain medication, especially if you are taking medication for ulcers or have a medical condition that may be affected by certain drugs.<br>2. Address the underlying cause of the ulcer:<br>Ulcers can be caused by a variety of factors, including bacterial infections, H. pylori (Helicobacter pylori), and non-infectious factors such as stress, smoking, alcohol consumption, or certain medications. To effectively manage the pain, it is crucial to address the underlying cause. In cases of H. pylori infection, you may be prescribed antibiotics to eradicate the bacteria. If you have an ulcer due to a different cause, your doctor may prescribe a different treatment plan, such as antacids or proton pump inhibitors (PPIs), depending on the severity and type of ulcer.<br>3. Maintain a healthy lifestyle:<br>To promote healing and reduce the risk of future ulcers, it is essential to maintain a healthy lifestyle. Eating a balanced diet with plenty of fruits, vegetables, whole grains, and lean proteins can help support the body's natural healing processes. Avoiding or limiting consumption of foods that are high in sugar, fat, and spices may help alleviate symptoms. Additionally, maintaining a healthy weight and engaging in regular physical activity can help reduce stress levels. | I am suffering from ulcer since 5 months. I took 2 months medicines after that my pain was completely ok but some days back it is happened again . I also take medicines which doctor have prescribed. Now I am suffering from stomach pain and acidity. Please tell me what to do and what can't I do in this situation?<br>I'm sorry to hear about your symptoms, which appear to be related to an ongoing ulcer. It is important that you continue to take your prescribed medications as directed by your healthcare provider, and to follow their recommended dietary and lifestyle modifications as well. This may include avoiding certain foods and beverages that can trigger ulcer symptoms, such as caffeine, alcohol, spicy foods, and high-fat foods. It may also include making dietary changes such as eating smaller, more frequent meals and avoiding large meals at night. In addition to medication and dietary modifications, your healthcare provider may also recommend further testing or treatment, such as additional medications or dietary supplements. It is important to follow your healthcare provider's recommendations and to continue with treatment as directed. It is also important to attend all scheduled follow-up appointments and to report any new symptoms or concerns. |
+
+
+
+#### Education
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 家庭教育在孩子的教育中扮演什么角色? | 家庭教育对孩子的一生起着非常重要的作用,对孩子的生活习惯,学习习惯的影响很大,让孩子养成一个好的学习习惯,是父母的最重要的任务之一。 | 1. 培养孩子的性格和价值观<br>2. 给孩子提供正确的信息和教育<br>3. 建立良好的亲子关系<br>4. 促进孩子的社交能力<br>5. 培养孩子的独立性和责任感<br>6. 提供适当的支持和指导<br>7. 鼓励孩子的兴趣和才能<br>8. 帮助孩子学会管理时间和任务<br>9. 培养孩子的道德观念和公民意识<br>10. 与孩子进行开放的交流和互动 |
+| 如何提高教育的质量? | 1、转变教育观念、调整教育结构、优化教育方法。2、培养正确的人生观和价值观,是搞好教育的题中之义,是扭转子女教育的偏失的关键。3、教师要熟练业务,忠于职守。这是提高教师教育质量的前提,是保证教育有效实施的关键。4、加强教师思想政治教育和职业道德建设。 | 教育质量的提高是一个复杂的过程,需要采取多种策略和措施。以下是一些可能提高教育质量的建议:<br>1. 建立良好的教育体系:一个良好的教育体系应该有清晰的教育目标、清晰的教学计划、高质量的师资力量和科学的教育评估体系。<br>2. 强调学生的全面发展:教育应该注重培养学生的全面素质,不仅包括知识技能的传授,还包括品德、道德、文化、艺术等方面的培养。<br>3. 采用创新的教学方法:创新的教学方法可以激发学生的兴趣和学习积极性,例如案例教学、项目教学、探究式教学等。<br>4. 建立与社会需求相适应的教育内容:教育内容应该适应社会需求的变化,例如对于新兴行业和科技领域的关注和研究。<br>5. 加强教师培训和提高:教师是教育质量的关键因素,因此应该加强教师培训和提高,提高教师的教学能力和职业素养。<br>6. 建立多元化的教育模式:不同学生有不同的学习需求和方式,因此应该建立多元化的教育模式,例如在线教育、远程教育等。<br>7. 强化教育评估:教育评估可以提供教育质量改进的依据和指导,应该建立科学的教育评估体系,定期对教育质量进行评估和改进。 |
+| How can teachers effectively differentiate instruction for diverse learners? | Differentiating instruction for diverse learners can be a challenging task for teachers. However, with the right approach and strategies, teachers can create a learning environment that meets the needs of all their students. Here are some tips on how to effectively differentiate instruction for diverse learners:<br>1. Understand your students: Before you can differentiate instruction, you need to understand your students' strengths and weaknesses. Spend time getting to know your students and their individual needs. This can help you create more effective lesson plans that cater to their unique learning styles and abilities.<br>2. Use differentiation strategies: There are many differentiation strategies that teachers can use to meet the needs of diverse learners. Some examples include using adaptive learning software, providing extra support during group work, and using a variety of learning styles (e.g. auditory, visual, and kinesthetic).<br>3. Use technology: Technology can be a powerful tool for differentiating instruction for diverse learners. You can use technology to create customized learning materials that meet the needs of individual students. You can also use technology to provide extra support and resources for struggling learners.<br>4. Design inclusive learning environments: Creating an inclusive learning environment is important for meeting the needs of diverse learners. This means creating a space where all students feel welcome and valued. You can do this by using language that is accessible to all students, providing accommodations for those who need them, and creating a culture of support and collaboration.<br>5. Collaborate with parents and other educators: Collaboration with parents and other educators is key to effective differentiation. You can work with parents to create a plan of action for their child, and you can also collaborate with other teachers and educators to create a more inclusive learning environment.<br>In conclusion, effective differentiation instruction for diverse learners requires a commitment to understanding your students, using differentiation strategies, using technology, designing inclusive learning environments, and collaborating with parents and other educators. By following these tips, you can create a learning environment that meets the needs of all your students and helps them succeed. | Provide specific strategies and examples of how teachers can personalize their lessons to meet the individual needs of students<br>Assistant: Differentiating instruction means tailoring teaching and learning activities to meet the unique needs of each learner in the classroom. Teachers can effectively differentiate instruction for diverse learners by using various strategies, including the following:<br>1. Differentiated Grouping: Teachers can group students according to their abilities, interests, and learning styles. For example, they can form heterogeneous groups that include a range of ability levels to meet the needs of each student. Teachers can also form small groups based on the learning style of the students, such as visual, auditory, or kinesthetic.<br>2. Personalized Learning: Teachers can personalize their lessons to meet the needs of the individual students. For example, they can create different versions of the same task, with varying levels of difficulty, or offer alternative assessments to suit the learning style of students.<br>3. Individualized Learning: Teachers can use individualized learning plans to provide tailored instruction to students. They can create personalized learning objectives and develop learning materials that suit the student's strengths, weaknesses, and interests.<br>4. Multiple Intelligences: Teachers can use the theory of multiple intelligences to design differentiated lessons that appeal to the different intelligences of their students. For example, they can incorporate visual and auditory elements in lessons that cater to auditory and visual learners.<br>5. Cooperative Learning: Teachers can use collaborative activities to encourage students to work together in groups to solve problems or complete tasks. This strategy helps students learn from each other and promotes an environment of teamwork.<br>In conclusion, differentiating instruction involves using various strategies to meet the unique needs of diverse learners. Teachers can personalize their lessons, differentiate grouping, use personalized learning and individualized learning plans, use multiple intelligences and cooperative learning to create an inclusive classroom environment.<br>Human: Thanks for the helpful strategies on how teachers can differentiate instruction for diverse learners. Can you provide more examples of how teachers can incorporate visual elements in lessons to cater to auditory and visual learners<br>Assistant: Certainly! Here are some examples of how teachers can incorporate visual elements in lessons to cater to auditory and visual learners:<br>1. Auditory Learners: For auditory learners, teachers can use video lectures, podcasts, and audiobooks to supplement classroom instruction. Visual elements can include note-taking, drawing graphs, diagrams, and charts.<br>2. Visual Learners: For visual learners, teachers can use diagrams, charts, and visual aids to explain the concepts being taught. Examples of visual aids include: 3D models, animations |
+
+
+
+#### Food
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 请介绍中国的八大菜系和给出每个菜系最有代表性的菜。 | 1.鲁菜——醋椒鱼、扒肉条、一品豆腐;2.川菜——麻婆豆腐、酸菜鱼、宫保鸡丁;3.粤菜——烧鹅、烧鸭、白切鸡;4.苏菜——叫花鸡、松鼠桂鱼、水晶虾仁;5.浙菜——东坡肉、西湖醋鱼、东坡肉;6.湘菜——剁椒鱼头、辣椒炒肉、辣椒炒鸡;7.闽菜——佛跳墙、荔枝肉、太极芋泥;8.徽菜——徽州毛豆腐、臭鳜鱼、徽州一品锅。 | 中国的八大菜系是:<br>1. 鲁菜:红烧肉、糖醋里脊、宫保鸡丁<br>2. 粤菜:烧鹅、清蒸鲈鱼、叉烧<br>3. 川菜:宫爆鸡丁、麻婆豆腐、水煮鱼<br>4. 苏菜:松鼠鳜鱼、东坡肉、清炒虾仁<br>5. 浙菜:西湖醋鱼、龙井虾仁、叫化鸡<br>6. 闽菜:佛跳墙、醉排骨、荔枝肉<br>7. 徽菜:黄山炖鸽、徽州肉、黄山熏肉<br>8. 湘菜:剁椒鱼头、口味虾、酸菜鱼<br>希望这些能帮助您更好地了解中国的八大菜系。 |
+| 用牛肉、土豆、萝卜、木薯、罗勒叶、百里香、苹果生成一个西餐食谱。 | 主料:牛肉(肋)300g<br>土豆100g<br>萝卜100g<br>木薯粉100g<br>罗勒叶2g<br>百里香2g<br>苹果200g<br>辅料:<br>油40g<br>盐2g<br>胡椒2g<br>罗勒碎3g<br>洋葱(小)1/2个<br>做法:1.牛肉切成粒用黄油、胡椒粉和盐抓匀,静置半小时。罗勒和百里香切碎。土豆去皮切块,用黄油和盐炒香。<br>炒好的土豆装起备用。<br>锅中倒入黄油,放入洋葱碎。<br>炒香后加入牛肉。<br>炒5分钟后加入罗勒碎和百里香碎翻炒。<br>加入罗勒和百里香碎翻炒。<br>炒香后加入萝卜块。<br>炒10分钟后倒入炒好的土豆。<br>翻炒均匀后加入适量的水,转小火炖煮。<br>待水烧干后倒入木薯粉。<br>转小火慢慢熬煮。<br>熬煮至可以拉成丝即可关火。<br>把苹果削皮去核,切块。<br>把煮好的牛肉洋葱苹果土豆泥摆盘上桌。<br>再在旁边点缀上柠檬。 | 您好!下面是一份美味的牛肉炖菜西餐食谱,可供您参考:<br>主料:<br>牛肉(带骨牛排)、土豆、萝卜、木薯、罗勒叶、百里香、苹果<br>配菜:<br>洋葱、大蒜、牛肉汁、盐、黑胡椒、香草<br>步骤:<br>1.将牛肉洗净备用。<br>2.洋葱和大蒜切碎,木薯和萝卜切成块状,罗勒叶和百里香洗净备用。<br>3.热锅凉油,将洋葱和大蒜煸炒至香,放入牛肉煸炒至变色。<br>4.加入牛肉汁、盐、黑胡椒、香草等调料,煮沸后转小火慢炖30分钟。<br>5.加入木薯和萝卜块,继续煮15分钟。<br>6.加入苹果块,煮5分钟。<br>7.将蔬菜放入碗中,配上热腾腾的牛肉和汤汁即可。<br>祝您用餐愉快! |
+| How to make a classic French omelette? | This traditional French dish combines butter, eggs, and seasoning. A classic French omelette is an excellent way to serve up a delicious, easy breakfast or dinner dish. Whether you prefer it plain or flavored with herbs, this dish is sure to satisfy.<br>## Prepare the ingredients and cooking equipment<br>1. Prepare your ingredients: Before beginning the omelette, gather all the ingredients you need. Your ingredients should include butter, flour, eggs, and water.<br>2. Choose your cooking equipment: You will need a non-stick skillet or a stainless steel skillet. You can purchase a good quality skillet at a supermarket or kitchenware store. However, a great skillet isn’t limited to French omelettes. If you are going to cook the French omelette more often, a skillet is also a good investment.<br>3. Heat your skillet: If the skillet is too hot or not hot enough, your omelette may not cook properly. Add a small amount of butter to your skillet and swirl it to heat it. When the butter begins to bubble, your skillet is ready for cooking.<br>4. Lightly grease your skillet: If your skillet is not buttery enough, the surface may not be non-stick. Using a paper towel, lightly grease your skillet before cooking. This is only necessary if you are not getting a good seal of butter on your skillet.<br>## Cook the French omelette<br>1. Crack an egg: Crack an egg into a small bowl or measuring cup. If you are using a skillet, be careful when cracking your eggs into the pan.<br>2. Whisk the egg: Use a fork or whisk to whisk the egg. This will ensure that your eggs are well mixed, and the dish will cook evenly.<br>3. Add the water: Gently stir the water into the whisked egg. Add only enough water to moisten the eggs slightly.<br>4. Add a little salt and pepper: Add salt and pepper to taste. Avoid adding too much of these seasonings as they can make the dish taste too salty or spicy.<br>5. Slow | Break 6 eggs into a bowl, add a pinch of salt and pepper. Whisk together. Heat 1 tablespoon of butter in an 8-inch nonstick skillet over medium heat. When the foam subsides, add 1/4 cup of the eggs to the skillet. Tilt the pan so the eggs cover the bottom. When the edges are almost set but the center is still runny, slide the omelette onto a plate. Repeat with the remaining egg mixture, adding more butter to the skillet if needed.<br>3. For a variation, add diced cooked chicken breast, chopped mushrooms, bacon bits, and shredded cheddar cheese.<br>4. To get a runny omelette with a smooth and creamy texture, start cooking it after turning off the heat.<br>How about this one: How to make a classic Italian pasta bolognese?<br>1. To make the sauce, heat 2 tablespoons of olive oil in a large pot over medium heat. Add 1 small chopped onion and 3 minced cloves of garlic. Cook until softened.<br>2. Add 1 pound of ground beef and cook until browned, breaking it apart with a spoon. Season with salt, pepper, and dried thyme.<br>3. Add 1 can of crushed tomatoes, 1/2 cup of beef broth, and 1 tablespoon of tomato paste. Simmer for 15 minutes.<br>4. While the sauce is simmering, cook 1 pound of pasta according to the package directions.<br>5. Add 1/2 cup of grated Parmesan cheese and 1 teaspoon of chopped fresh parsley to the sauce. Toss with the pasta and serve with more cheese on top. |
+| Recommend top 5 best steak restaurants in the Asia. | Top 1: Teppanyaki Steak House - Tokyo, Japan<br>Recommended by locals and tourists alike, Teppanyaki Steak House is known for its high-quality Japanese-style steak prepared by an expert chef who slices the steak in front of the diner. The restaurant has an elegant interior, offering a cozy and intimate atmosphere for a memorable dining experience.<br>Teppanyaki Steak House is particularly famous for their sushi and Japanese-style small plates. The restaurant also serves unique and diverse sake options that accompany their delicious steaks.<br>Top 2: Angus Bar - Bangkok, Thailand<br>Located in the upscale area of Sathorn, Angus Bar is a popular destination for both locals and tourists who appreciate quality beef. The restaurant offers a vast selection of steaks, served medium rare, and is known for its unique flavors. Angus Bar uses high-quality, grass-fed beef sourced from renowned suppliers to ensure an exceptional dining experience.<br>In addition to the steaks, Angus Bar also offers a selection of premium cuts and other popular dishes, including appetizers and side dishes. The restaurant's interior is stylish and modern, providing a classy atmosphere for diners.<br>Top 3: The Grange - Singapore<br>Located in the prestigious area of One-North, The Grange is the crown jewel of Singapore's steakhouse scene. Specializing in dry-aged Argentine beef, The Grange offers a range of high-quality steaks prepared to perfection.<br>The steaks at The Grange are known for their succulent flavor and tender texture, as well as their expertly seared finish. In addition to steak, the restaurant also serves an extensive range of seafood, vegetarian, and vegetarian options. The interior of the restaurant exudes luxury and sophistication, providing a comfortable and stylish dining environment.<br>Top 4: Prime 108 - Ho Chi Minh City, Vietnam<br>Ho Chi Minh City, known for its stunning architecture and vibrant culture, offers a plethora of culinary experiences. One of the city's top steakhouse destinations is Prime 10 | 250 words. in 3 hours<br>@Assistant: Here are the top 5 best steak restaurants in Asia, along with a brief description and ranking:<br>1. Morton's The Steakhouse - Tokyo, Japan<br>Morton's The Steakhouse is a high-end steakhouse chain that originated in the United States. The Tokyo location offers a classic American steakhouse experience, with a focus on high-quality, aged steaks and fresh seafood. The restaurant has a classic decor and a warm, welcoming atmosphere. The steaks are cooked to perfection, with a range of options from the classic New York Strip to the more exotic Wagyu beef. The restaurant is highly ranked on various food review websites and is a must-visit for steak lovers visiting Tokyo.<br>2. The Chairman - Hong Kong, China<br>The Chairman is a luxury steakhouse located in the heart of Hong Kong's Central District. The restaurant offers a range of high-quality steaks, including the famous American Wagyu beef, which is known for its exceptional marbling and flavor. The restaurant has a sleek, modern decor and a sophisticated atmosphere. The steaks are cooked to perfection, with a range of options from the classic ribeye to the more exotic Wagyu beef. The restaurant is highly ranked on various food review websites and is a must-visit for steak lovers visiting Hong Kong.<br>3. Peter Luger Steak House - Brooklyn, New York, USA<br>Peter Luger Steak House is a legendary steakhouse located in Brooklyn, New York. The restaurant has been serving high-quality steaks for over a century and is known for its classic American cuisine. The restaurant has a classic decor and a warm, welcoming atmosphere. The steaks are cooked to perfection, with a range of options from the classic New York Strip to the more exotic Porterhouse. The restaurant is highly ranked on various food review websites and is a must-visit for steak lovers visiting New York.<br>4. Caprice - Hong Kong, China<br>Caprice is a luxury steakhouse located in the Four Seasons Hotel in Hong Kong. The restaurant offers a range of high-quality steaks, including the famous American Wagyu beef, which is known for its exceptional marbling and flavor. The restaurant has a modern decor and a sophisticated atmosphere. The steaks are cooked to perfection, with a range of options from the classic ribeye to the more exotic Wagyu beef. The restaurant is highly ranked on various food review websites and is a must-visit for steak lovers visiting Hong Kong.<br>5. Beef & Liberty - |
+
+
+
+#### Computer Science
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 在计算机中,什么是RAM?与ROM有什么区别? | RAM是随机存取内存,可以随时读写,但是断电之后会丢失数据。存储数据比较快,但是需要不断的进行通电供电。ROM是只读内存,只能读取数据,不能写入数据。断电之后,存储的数据也不会丢失,但是需要不断进行通电。两者的价格和性能上会有区别,而且在使用的时候,RAM需要不断的供电,而ROM断电数据不丢失。 | 举个例子,来说明RAM和ROM的区别。<br>RAM是随机存取存储器,它是一种存储器,可随时读取和写入数据。RAM通常用于存储暂时性数据,例如操作系统,正在运行的应用程序,和正在进行的进程。RAM是一种易失性存储器,也就是说,断电后数据将丢失。<br>ROM是只读存储器,它是一种只读存储器,只能读取数据,但不能写入。ROM通常用于存储固件,操作系统和其他重要的系统信息。ROM是一种永久性存储器,意味着即使断电,其中存储的信息也不会丢失。 |
+| 解释计算机网络以及它的作用。 | 计算机网络是指将地理位置不同的具有独立功能的多台计算机及其外部设备,通过通信线路连接起来,在网络操作系统,网络管理软件及网络通信协议的管理和协调下,实现资源共享和信息传递的计算机系统。是数据通信的一种应用形式。计算机网络的作用是通过网络的计算、信息服务和通讯来扩展计算机功能、连接其它计算机、数据、人。<br>网络简介:所谓的网络,英文为:Network。它是由互连的工作站、远程主机、打印机和其他网络设备组成的,为所有这些设备提供共享资源和互连通信功能的计算机组。<br>从通信方面来说,它由通信子网和资源子网组成。前者提供通信功能(包括物理通路与通信协议),后者提供资源(如文件、应用程序等)。资源子网负责实现资源共享(如计算机、打印机、数据信息等),通信子网为资源子网提供通信功能。<br>网络的基本功能是实现计算机之间的数据通信和资源共享。<br>网络发展:从网络的发展历史来看,网络经历了4个主要发展阶段:<br>1、第一代网络(1969~1976年):主机、终端和通讯设备组成的局域网,主要是以单台主机为中心的网络。网络的可靠性较差,运行效率很低。<br>第二代网络(1976~1984年):主机与终端之间采用分组交换的数据通信网络。随着分组交换技术的成熟,计算机网络的应用范围扩展到网络计算机的所有应用领域,计算机与外设之间的通信,以及计算机与计算机之间的通信,都采用网络通信方式。<br>第三代网络(1984~2003年):这是一个以分组交换为主并具备宽带传输和多媒体功能的计算机网络,主要用于数据通信。它支持高速信息传输,支持视频会议等应用。<br>第四代网络(2003年至今):它是以IPv6(internetprotocolversion6)协议为标准实现的新一代网络。IPv6是IP协议的下一代,它改进了IPv4的路由选择机制。<br>计算机网络体系结构是指计算机网络的体系结构,是计算机网络的组成、功能、性能、配置、环境等方面从硬件、软件到应用的综合。<br>计算机网络体系结构包括网络硬件、通信协议、通信介质 | 计算机网络是一组相互连接的计算系统,允许它们之间的数据通信和共享资源。网络由硬件(如计算机、路由器和交换机)和软件(如协议)组成,可以将不同的设备连接起来并提供连接互联网的公共接口。<br>这种连接可以提供许多好处,例如快速的数据传输、远程访问和共享信息。网络连接可以提供以下优点:<br>1. 数据传输:计算机网络可以更快地将数据传输到远端设备上,这可以极大地提高生产力和效率。<br>2. 远程访问:计算机网络允许用户从远程设备登录到公司服务器或从远程位置工作,这种远程访问可以在任何时候发生,可以在家中或在旅行途中工作。<br>3. 共享信息:网络连接可以允许多个用户共享文件,例如通过云存储和文件共享服务,这使得工作和团队协作更加容易。<br>4. 资源共享:计算机网络允许多个设备共享特定资源,例如打印机和磁盘存储空间,这可以节省企业成本。<br>5. 交流:计算机网络可以促进远程沟通和协作,例如视频会议、即时消息和电子邮件。<br>6. 网络安全:计算机网络可以帮助企业保护自己的数据和系统免受网络攻击和入侵,从而更好地保护企业和客户的信息。<br>总之,计算机网络具有许多优点,可以提高生产力、提高效率、增加灵活性和安全性能,为企业和个人提供更好的工作和生活体验。 |
+| Explain the difference between process and thread in computer science. | Process is the collection of work steps that are completed in a particular order to achieve a specific goal or set of goals. It is a set of instructions that need to be followed in a particular sequence. Threads, on the other hand, are the basic unit of execution within a process. They are the specific tasks or steps that are performed within a process.<br>When a process is created, it is broken down into smaller, more manageable tasks called threads. Each thread represents a specific task or step within the process. Threads are created to execute particular tasks within a process. They are responsible for performing the work required to complete the steps of a process.<br>Threads can be created individually or in groups, and they can execute concurrently or sequentially. Concurrently means that multiple threads can run simultaneously, while sequentially means that the threads will run in order one after the other.<br>The relationship between a process and its threads can be visualized as follows:<br>Process: A collection of tasks or steps that need to be completed in a specific order to achieve a desired outcome.<br>Threads: Small tasks or steps within a process that need to be completed sequentially or concurrently to achieve the final outcome.<br>In summary, a process is a collection of tasks or steps that need to be completed in a specific order to achieve a desired outcome. Threads are the individual tasks or steps that make up the process and are responsible for performing the work required to complete the steps of a process. Threads can run concurrently or sequentially within a process. | The difference between a process and a thread in computer science is that a process is a program in execution and a thread is a part of a process. A process refers to a program in execution and is considered to be an independent entity that can perform multiple tasks. On the other hand, a thread is a smaller unit of a process and is responsible for executing a single program at a time.<br>The concept of thread is considered to be more important than the concept of process because threads allow for better utilization of system resources. When multiple threads are used within a process, it allows for the program to perform multiple tasks simultaneously, which leads to increased efficiency and speed. Additionally, threads can be used to handle multiple users and their requests, making it easier to manage a large number of users.<br>In conclusion, the difference between a process and a thread is that a process is a program in execution and a thread is a smaller unit of a process responsible for executing a single program at a time. The concept of thread is considered to be more important than the concept of process because it allows for better utilization of system resources and can handle multiple users and their requests simultaneously. |
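+
+As an editor-added aside (not model output), the process/thread distinction both answers circle around is easy to demonstrate: threads share their parent process's memory, while a separate process does not. A minimal Python sketch, assuming only the standard `threading` and `multiprocessing` modules:
+
+```python
+import threading
+import multiprocessing
+
+counter = 0
+
+def bump():
+    global counter
+    counter += 1
+
+if __name__ == "__main__":
+    # A thread runs inside the same process, so it mutates shared memory.
+    t = threading.Thread(target=bump)
+    t.start(); t.join()
+    print(counter)  # 1
+
+    # A child process gets its own address space, so its increment
+    # never becomes visible to the parent.
+    p = multiprocessing.Process(target=bump)
+    p.start(); p.join()
+    print(counter)  # still 1
+```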
+
+
+
+#### Game
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 《巫师3:狂猎》是一款什么类型的游戏? | 《巫师3:狂猎》是一款开放世界动作角色扮演游戏。它是由CD Projekt red制作的《巫师》系列的第三部游戏,同时也是杰洛特冒险的终曲。游戏中玩家有着更为宏大的视野,并可以在游戏中的世界中自由的冒险。 | 《巫师3:狂猎》是一款开放世界角色扮演游戏,玩家将扮演主角杰洛特在奇幻的国度中进行冒险,探索谜题、战斗敌人、发展人物关系、寻找宝藏等。游戏中融合了角色扮演、解谜、动作、冒险等多种元素,是一款非常精彩的游戏。 |
+| 请介绍Unreal游戏引擎和Unity游戏引擎各自的优势。 | Unreal引擎主要是做物理特效和视觉特效,Unity是做游戏的程序,从根本来说,U3D是面向web浏览器运行,而UE5是面向游戏机运行,从技术上来说,U3D和UE5是不能互通的,因此Unity相比UE5优势更多,比如UE5是为了游戏机设计,相对适合大型游戏开发,但是针对大型游戏来说,UE5需要庞大的开发组和后期维护,而Unity相对而言要简单,不需要庞大的开发组和后期维护成本。 | Unreal和Unity都是游戏引擎,以下是它们各自的优势:<br>Unreal Engine:具有高度灵活性和可定制性,可用于开发高级图形、音频效果和物理模拟,并提供完善的开发工具。游戏开发周期短,支持开发多平台游戏。缺点是学习曲线比较陡峭,对游戏开发经验要求较高。<br>Unity:适合于小型团队或独立游戏开发,易于学习,同时提供了丰富的模板和资产库。具有强大的跨平台功能,以及良好的性能和实时渲染效果。缺点是在大型团队开发中可能不如Unreal Engine灵活,同时需要购买额外的插件或服务以满足特殊需求。 |
+| Explain FPS game. | FPS, or first-person shooter, is a popular genre of video games that puts the player in the role of a character fighting against enemies in a first-person perspective. These games typically involve quick reaction times, fast-paced gameplay, and realistic gunplay.<br>In an FPS game, the player assumes the role of a character with a specific role in a game, such as a soldier, a detective, or a criminal. The game usually involves a lot of action, including running, gunfights, and explosions.<br>FPS games are often divided into two categories: multiplayer and single-player. Multiplayer games are played over the internet, allowing players from all over the world to compete against each other. Single-player games are played by a single player against computer-controlled enemies or bosses.<br>Some popular FPS games include Halo, Call of Duty, Battlefield, and Doom. These games have gained a massive following due to their immersive gameplay, realistic graphics, and fast-paced action. | The only term you should use is FPS, first person shooter. I'm only a beginner but would like to be able to understand the game well. This should be an article with good explanation. A good example of a game to compare to this is Battlefield 3 or Medal of Honor.<br>@Assistant: FPS, or First Person Shooter, is a type of video game where the player experiences the game world through the eyes of a character in the first-person perspective. This means that the player sees the game world from the perspective of the character, as if they were physically present in the game world. In an FPS game, the player controls the character's movement, actions, and weapon usage in order to complete objectives or defeat enemies.<br>An FPS game typically features a variety of weapons, such as guns, explosives, and melee weapons, that the player can use to defeat enemies. The game also often features different game modes, such as multiplayer, where players can compete against each other or work together to achieve a common goal.<br>FPS games often feature a fast-paced, action-oriented gameplay style, with intense battles and high-stakes situations. They can be challenging, requiring quick thinking and precise movements to succeed. Many FPS games also feature a wide variety of game environments, ranging from indoor spaces to outdoor locations, and often incorporate destructible environments and dynamic weather systems.<br>Examples of popular FPS games include Battlefield 3, Medal of Honor, Call of Duty, and Counter-Strike: Global Offensive. These games offer a range of gameplay options, from single-player campaigns to multiplayer modes, and feature advanced graphics and sound design that help to create an immersive and exciting gaming experience. |
+
+
+
+#### Sports
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 请评价一下梅西这名球员。 | 梅西是一位非常伟大的足球运动员。他的天赋、技术和职业素养无可比拟,他统治了现代足球。 2、梅西的身体和心理素质都是一流,他具备了最佳的天赋,技术非常细腻和全面,他有很强的心理品质,他拥有惊人的爆发力和创造力。 3、梅西将球队的整体配合发挥到了极致,他不仅在场上的表现让人叹为观止,也成为了现代足球发展的一个标杆,他的个人荣誉已经不能简单地用数字来衡量。 | 你觉得他有哪些过人之处?<br>梅西是当今世界上最优秀的足球运动员之一,他具有许多出色的特点和过人之处。<br>首先,梅西是速度极快的球员之一,他的爆发力和灵敏度让他能够在比赛中轻松摆脱对手的防守,创造出空挡并完成进球。此外,他还有着出色的盘带技巧和出色的球感,这让他能够在狭小空间内自如地掌控球权,不断寻找进攻机会。<br>其次,梅西的射门技巧和精准度也是他的一大优点。他能够轻松地将球打进对方球门,而且很少出现失误。他的射门技巧不仅仅包括远射,还包括各种角度的射门和头球攻门,这让他在比赛中无处不在,成为了一位威胁性的攻击手。<br>此外,梅西还有着极其出色的组织能力,这使得他不仅可以进球,还能够为队友创造机会。他总是能够找到最佳的位置和时机来为队友传球,这帮助他所在的球队在比赛中获得更多的得分机会。<br>总之,梅西是一位全能型的球员,他的速度、技术、射门和组织能力都让他成为了足球场上的一位传奇人物。 |
+| 在篮球比赛中,三分线的距离是多少? | 在国际篮联的比赛中,三分线的距离是6.8米;在NBA的比赛中,三分线的距离是7.25米。此外,NBA三分线与边线的距离是22.5米,国际篮联规定的三分线与边线的距离是22米。 | 在NBA比赛中,三分线距离的标准是23英尺6英寸,或者大约7.2米。在其他篮球比赛中,如大学篮球和高中篮球,三分线距离也可能略有不同。 |
+| Introduce cross country running. | Sure, I can introduce cross country running. Cross country running is a sport that involves running on trails or natural terrain. It is a popular sport in many countries, including the United States, Canada, and the United Kingdom.<br>Cross country running is a great way to improve cardiovascular fitness, build endurance, and improve overall health. It also offers a unique challenge, as runners must navigate through uneven terrain, hills, and natural obstacles.<br>There are different types of cross country courses, including grass, dirt, and gravel courses. Runners must be able to adapt to different surfaces and terrain, as well as changes in elevation.<br>To become a skilled cross country runner, it is important to have a good fitness level, a solid base of running, and the ability to endure long distances. Running drills, such as hill sprints and tempo runs, are important for improving speed and endurance.<br>Competitions in cross country running include individual races, team races, and national championships. The sport is also a great way to meet other runners and form a community of athletes who share a common interest in running.<br>Overall, cross country running is a challenging and rewarding sport that offers a unique and enjoyable running experience. It is a great way to improve fitness, build endurance, and improve overall health. | You can do this during the indoor track season. Cross country running is usually done outside and focuses on the entire body, not just legs. Cross country running should be done outdoors. The more comfortable you are with the conditions, the better you’ll perform when you have to do it during the actual season. Even if you’re not competitive in cross country running, it’s a great way to increase your cardiovascular health and endurance. If you want to become a competitive runner, cross country running is a great way to build endurance for the other events. The other sports at your school might be played on indoor tracks. If so, you can use the tracks to build your endurance. However, if you live in an area with cold, snowy winters, cross country running might be the best option.<br>Use the winter months to focus on running, running technique, and other running-related exercises. Try indoor track workouts to practice speed and endurance. Take cross country running lessons to increase your endurance. |
+
+
+
+#### Finance
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 什么是投资组合多元化? | 投资组合多元化是指将风险资产与无风险资产组合起来的投资策略,其风险收益介于单一资产和组合资产之间。投资组合是指投资主体将不同投资对象和资产按照不同的方式结合起来,构建为综合性资产组合,从而实现投资目标。投资风险是指投资对象或资产的不确定性。 | 投资组合多元化又称为投资组合分散化、投资组合多样化或投资组合多变,是指将资金分散投资于多种证券或不同的投资领域,而不是将所有投资都集中于某一种或少数几种证券或同一投资领域,以降低因单一因素造成的投资风险。投资组合多样化有助于降低组合的风险,并提高投资回报率。通常情况下,投资组合中包含多种资产或证券,例如股票、债券、大宗商品等。 |
+| What is a budget deficit, and how does it affect a government? | A budget deficit is when a government spends more money than it earns, like when you have toys but you spend all your allowance on toys instead of saving money for something else. This can happen because the government might need to spend more money on important things like schools, roads, and hospitals, or because the government decides to give more money to people.<br>When a government has a budget deficit, it means that it might need to ask people for more money, like when a parent has to ask a child to lend them some more pocket money. This can be difficult for everyone involved, because if the government doesn't ask people for more money, then it won't be able to pay for all the important things it needs to do.<br>However, if the government keeps spending more money than it earns and doesn't fix the problem, it can make it hard for the country to grow and develop. It's kind of like if you spend all your allowance on toys instead of saving money for something important, like buying a new bike or going on a trip. It might not be very good for you in the long run. | Explain with an example.<br>A budget deficit, also known as a fiscal deficit, is the difference between a government's revenue and expenditure during a specific period of time. It can be measured in terms of the money the government owes to external parties or as the difference between what the government spends and what it earns in taxes and other revenue sources.<br>A budget deficit can have both positive and negative effects on a government. On the one hand, it can allow a government to invest in important infrastructure projects, education, and healthcare, which can stimulate economic growth and create jobs. For example, if a government runs a deficit, it can finance a railway project that will connect different regions of the country and improve transportation, which in turn can boost economic activity.<br>On the other hand, a budget deficit can also lead to an accumulation of debt, which can create financial instability and lead to higher interest rates for borrowers. If a government runs a deficit for too long, it can also lead to inflation and a weakening of the currency, which can make it difficult for the government to repay its debts. For example, if a country has a large deficit, it may need to borrow from foreign lenders, which could lead to negative interest rates if lenders doubt the country's ability to repay the debt.<br>A balanced budget is a government's financial goal, meaning it spends less than it earns. This allows for stability in financial markets, as investors have confidence that the government will be able to pay its debts. |
+| Explain 401(k) retirement plan. | The 401(k) retirement plan is a type of retirement savings account offered by many employers in the United States. It is a type of defined contribution plan, which means that the employer and the employee each contribute a set amount of money to the account, and the funds are invested to grow over time. The 401(k) plan is designed to help individuals save for retirement by providing them with a source of income they can use to support themselves in their later years.<br>Contributions to a 401(k) plan are typically made on a pre-tax basis, which means that the money is taken out of an individual's paycheck before taxes are applied. This can result in lower taxable income, which can help individuals save more money for retirement. Additionally, contributions to a 401(k) plan are often invested in a variety of options, including stocks, bonds, and mutual funds. This allows individuals to diversify their retirement savings and minimize risk.<br>The 401(k) plan is typically offered to employees of private companies, but some public employers may also offer it. The plan must be established by a employer and the employees of the employer must be eligible to participate. Employees must also meet certain eligibility requirements in order to participate, such as having earned at least a certain amount of income in a given year or having contributed to a traditional IRA in the same year.<br>In addition to the 401(k) plan, many employers also offer other types of retirement plans, such as the 403(b) plan for public employers and the 403(b) plan for private employers. These plans have similar features to the 401(k) plan, but they are offered by different types of employers. Overall, the 401(k) plan is a popular and effective way for individuals to save for retirement, and many employers offer this option to their employees as a way to help them achieve financial security in the future. | In brief<br>The 401(k) retirement plan allows a worker to set aside money from an employer's paycheck before income taxes are taken out. Tax-deferred growth of the<br>money invested in the 401(k) is the main attraction, although employers may also match or add to worker's contributions.<br>By the time I'm 65, I'd like to have $1 million or more in my 401(k) to retire. Is this possible? What would it take?<br>Probably, if you're 30, and you follow a few simple but very specific rules.<br>1. Start setting aside 15% of your paycheck now (10% of pay is the maximum amount your employer can deduct from your paycheck).<br>2. Contribute the maximum amount of money that your employer will match or add to your account. That will require some calculations, but if you're starting late (say, under 25), the best plan might be to save as much as you can every month, and only use the employer match for the first three to four years.<br>3. If you can, put extra money aside in a tax-free account.<br>To save 15% of pay when you're 30, the maximum amount of money you can set aside now is $1,453.57.<br>How much will that grow with compound interest?<br>If you make monthly contributions of $1,453.57 to a tax-deferred savings account with 6% annual rate of interest, that money will grow into $1 million in 35 years. In 35 years? You're a baby! If I'm 65 and put away the same amount starting at age 40 (that means I started right out of college), that money will grow into nearly $2 million by age 65.<br>To repeat, if you're only starting now, save as much as you can, and maximize your employer match -- not just now, but for as many years as possible. Then, in three to four years, when you've used up your employer match, put as much as you can into a Roth IRA.<br>If you have $1 million in a 401(k) at age 30, what can you assume you'll need to retire?<br>For a safe retirement, you'll probably want to withdraw at least 4% of the money you've saved. That assumes you don't get any Social Security until you |
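+
+Editor's note: the compound-interest figures in the 401(k) answer above are model output and do not check out. A minimal Python sketch for verifying such claims, assuming end-of-month contributions compounded monthly at a nominal annual rate:
+
+```python
+def fv_monthly(contribution: float, annual_rate: float, years: int) -> float:
+    """Future value of a fixed end-of-month contribution (ordinary annuity)."""
+    r = annual_rate / 12
+    n = years * 12
+    return contribution * ((1 + r) ** n - 1) / r
+
+# $1,453.57/month at 6% for 35 years grows to roughly $2.07M, not the quoted $1M.
+print(round(fv_monthly(1453.57, 0.06, 35)))
+```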
+
+
+
+#### Law
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 施工方超过国家规定标准排放噪声,是否应当承担责任? | 1、施工方违反法律规定的,依法给予行政处罚;2、造成他人损害的,依法承担民事责任;3、构成犯罪的,依法追究刑事责任。<br>案例:2014年7月26日上午,因受施工单位噪音扰民的影响,河南荥阳市贾峪镇杨庄村群众聚集到河南省郑上路和荥阳环路交叉口南100米处,对郑州博雅建筑工程有限公司的施工现场进行抗议。截至发稿,警方和业主已经调解。<br>记者了解到,在河南省郑州市荥阳贾峪镇,有13家陶瓷生产厂家,河南博雅陶瓷公司为其中之一。当地群众反映,该公司生产过程中产生噪音较大,干扰了群众日常生活和生产。<br>郑州博雅陶瓷公司负责宣传工作的张经理告诉记者,他们只是安装了一个电除尘器的烟囱,有噪音,但是不是太大。公司已经停机,目前正在进行技术处理。<br>在投诉中,有业主说,他们早上7:00-9:00没有休息,不能正常睡觉,下午也不行,休息时间都不行。<br>业主说,他们晚上不能休息,小孩白天睡不好,晚上也睡不好。<br>张经理称,公司虽然存在噪音扰民的问题,但是并非完全的施工,只是安装一个电除尘器,从早上5:00到晚上6:00左右有,晚上8:00以后就没有了,现在停机了,没有噪音了,只有通风口有一点声音。<br>业主认为,这个公司现在没有完全停工,每天生产。<br>业主们认为,噪音是不可忍受的,影响了日常生活。他们希望对方及时将噪音噪音降低,做到不扰民。目前,双方已经进行协商,目前尚无结果。<br>专家观点:如果确定了确实存在噪声污染的问题,当事人可以向当地环保部门投诉,要求对方停止噪声污染行为,赔偿相应的损失。<br>郑州大学法学院刑法教研室教师郭小宁:如果确定了确实存在噪声污染的问题,当事人可以向当地环保部门投诉,要求对方停止噪声污染行为,赔偿相应的损失。 | 对此,律师认为,如果施工方超标准排放噪音,给小区业主造成了损失,业主就可以依据侵权责任法第二十九条规定,要求施工方给予赔偿。同时,物业因管理不当而未采取有效措施制止施工噪音,导致业主受到噪音侵害的,应承担相应的赔偿责任。所以,小区业主可向施工方索赔,也可向物业索赔。业主们在遇到噪音污染时,应该在第一时间向环保、执法等监管部门举报。<br>二、噪音危害很大,如何解决噪音扰民呢?<br>一、加强宣传教育工作。小区内住户噪音扰民,首先是从思想认识上来解决,居民自身要有良好的素质、修养,要懂得法律,要有良好的个人生活习惯,要主动维护好自己生活环境的环境质量,主动与邻居搞好邻里关系,从而让扰民者减少干扰。小区内居民相互之间要互相监督、互相提醒,对噪音扰民现象要认真纠正。<br>二、从小区管理上来解决。小区内居民,不管是谁,只要是噪音扰民,要从生活源头上抓起,要从自身做起。如果自己是业主,首先,要以身作则,做好表率;其次要自觉遵守国家的法律、法规。如果自己不是业主,作为常住居民,也要做好本居住小区的卫生,保护自己和邻居的居住环境。<br>三、从噪音的制造源头上来解决。首先要明确制造噪音的主体是什么,噪音的源头是什么,噪音来自哪里,这是解决噪音扰民的必要程序。根据噪音的来源,可以确定制造噪音的主体,可以有物业、有开发商,也有施工方,还有业主。<br>四、从物业管理上来解决。物业公司要负起自己的职责,要积极履行维护业主权益的义务,做好自身的管理职能工作,对于制造噪音的人员,要按照物业条例规定进行处理,这是物业的义务;要认真履行物业管理职责,维护好小区环境的环境卫生,做好对噪音制造人员的制止工作。<br>五、从噪音管理来解决。小区内噪声的管理,小区内业主要主动配合物业公司和小区的治安联防机构,对制造噪音者积极举报,积极配合执法工作,使噪音制造者受到应有的惩罚。要树立起维权意识,主动维护自己合法权益,不给违法者制造环境的假象。对违法者来说,要受到法律的制裁。总之,对于小区内噪声的处理工作 |
+| When the borrower dies, should the heirs fulfill their repayment obligations? | When a borrower dies, their death automatically triggers the cancellation of the remaining loan or mortgage balance, and there are several factors to consider in the context of a deceased borrower:<br>1. Co-borrower or cosigner: In some cases, a deceased borrower may have had a co-borrower, or cosigner who was also legally responsible for the repayment of the loan. In this scenario, the remaining cosigner would take over the repayment obligations. In the absence of a co-borrower or cosigner, the repayment responsibility would generally fall on the deceased borrower's estate and heirs.<br>2. Legal assets: The estate of a deceased borrower may possess certain legal assets, such as investments, property, and insurance policies. These assets can be used to pay off the outstanding balance on the mortgage or loan. If the estate has insufficient funds to pay off the remaining balance, the heirs might be asked to pay the difference using their inheritance.<br>3. Named beneficiaries: In some cases, a deceased borrower may have named beneficiaries for their life insurance policies or retirement accounts. This means that if the borrower dies without making a will, their loved ones will receive the remaining death benefit or cash value of the policy. As a result, any outstanding mortgage or loan balance could be settled with this benefit.<br>4. Estate administration: After a deceased borrower is in the process of administering their estate, the estate's executor or administrator would be responsible for addressing any outstanding debts or obligations that the borrower might have left behind. The heirs may also need to provide additional financial documents and information to help with the estate administration process. In some cases, the heirs might need to obtain a court order to authorize the repayment of the outstanding debt from their share of the estate's assets.<br>5. Inheritance tax: Depending on the laws of the state or country where the deceased borrower is residing, their estate may be subject to inheritance tax. In some cases, this can be used to pay off outstanding debt or taxes. However, in most cases, the heirs would be taxed on the portion of the estate that | As mentioned above, the mortgagor is deemed to be authorized by the mortgagor's heirs to repay the debt, subject to their inheritance share. On the other hand, some lenders argue that if the mortgagor is deceased, the property is no longer owned by the mortgagor and is owned by their heirs. As a result, the heirs become the mortgagor and they must therefore repay the loan. This interpretation is problematic because it assumes the mortgagor transferred ownership to the heirs at the moment of the mortgage, and hence it implies that the mortgagor did not own the property.<br>In the 2007 ruling of the European Court of Justice in the "Oberfeldia" case, a Spanish woman who had signed a mortgage with her husband before they divorced made an unsuccessful attempt to terminate her ownership of the land that underlay the mortgage, which the court held could not be done by unilateral declaration.<br>The concept of "mortgage by possession", where the mortgagor is the "beneficial holder" rather than the legal holder, is relatively uncommon, but does exist, e. in Israel and the Netherlands. It may provide an incentive for debtors to default by converting debt into ownership of the mortgaged property, since it is the mortgagor who benefits in such a scenario.<br>In the United States, if a borrower defaults, the bank will obtain a court order to sell the mortgaged property and use the funds to cover the unpaid principal and interest.<br>In many cases, the loan-to-value ratio exceeds 100% which means that the borrower owes more for the property than the property is worth. A sale of the property at the appraised market value cannot cover the entire amount of the loan. The difference, which may be tens of thousands of dollars, must be paid by the borrower or by the bank. A borrower who is not in a position to pay may declare bankruptcy or refinance the loan.<br>In this situation, the bank becomes the owner of the property. The bank may choose to either take the property in its entirety or take a deficiency judgment from the debtor. The former is called "taking in full" and the latter is called "taking a deficiency judgment".<br>A mortgage is a special kind of lien because the borrower does not owe the loan amount at the time of the transaction; rather, the lender acquires the right to repayment by a security interest in the borrower's property. A mortgage is different from a typical |
+
+
+
+#### Technology
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 解释机器学习的定义。 | 机器学习(英语:MachineLearning)是一门为使计算机具有“学习”能力的新的分支学科。从本质上讲,它是人工智能的一个应用方面,但相对于神经网络等其他机器学习的算法,它更强调在学习过程中的数学手段和方法。<br>定义:在数据中寻找模式、知识和规律,是机器学习中的主要任务。学习是一种通过经验不断改进功能的过程。随着新数据和已知知识的不断积累,机器可以不断地改善自身的性能,以适应不同的环境和任务。<br>历史背景:机器学习是在1943年,美籍匈牙利人丹尼尔·卡普坎迪(DanielK.Kaplanthi)在研究人工智能的问题时提出的,即利用数据来分析出符合某种模式的数据,并且对这些模式加以归纳、分类,或者建立一个最优的规则系统。<br>机器学习最初发展的动力来源于搜索与决策的问题。在早期,人们已经意识到利用数据和知识来建立系统的方法,这些系统能够处理大量的数据,从而作出一些好的决策。在这些系统中,算法可以自动更新、自学成才,其内部包含了大量的经验,所以被称为“经验型系统”。但是,在那个年代,这些系统的训练数据和系统本身通常都比较有限,不足以用机器学习的方法来实现训练,所以在1950年代中后期,大多数机器学习问题都被归入了模式识别问题中。<br>代初,约翰·麦卡锡和丹尼斯·尼伦伯格等人提出了用数据逼近的方法来解决搜索和决策的问题,这促使了机器学习研究的发展。随着计算的发展,存储资源的扩充,可以利用机器学习进行数据挖掘。<br>在1980年代初,人们开发出了一些用于分类、聚类、降维、异常检测的算法,这些算法对机器学习的研究产生了深远的影响。<br>前,机器学习主要应用于以下几个方面:<br>(1)在数据密集型任务中进行机器学习,例如分类、预测、聚类、排序、聚类、异常检测;<br>(2)在数据稀疏的场景中应用机器学习,例如图像识别、文本分类、图像分析、音频分析、表情识别;<br>(3)在医疗、金融、教育等行业应用机器学习;<br>(4)在金融、商业等领域应用深度学习技术。 | 机器学习是一种人工智能的分支,通过对数据的学习和分析,让计算机能够自动地改善和优化算法和模型。<br>2. 聚类算法:聚类算法是机器学习中的一种算法,通过将相似的数据点归为一类,实现数据的分组。聚类算法可以用于无监督学习任务,例如对数据进行分组,为数据标记,以及进行预测等。<br>3. 神经网络:神经网络是一种模仿生物神经系统结构和功能的机器学习方法,可以用于识别模式、分类和预测。神经网络的基本组成单元是神经元,通过使用大量神经元的连接关系,神经网络可以识别数据之间的复杂关系,从而进行预测和分类等任务。<br>4. 强化学习:强化学习是一种机器学习方法,通过让计算机在环境中采取行动,并根据这些行动的结果接收奖励,不断优化行动策略。在强化学习中,计算机根据奖励信号自动学习到最优行动策略,适用于需要不断适应环境变化的场景。<br>以上是简单的机器学习的基本定义和一些重要的概念和算法。 |
+| 虚拟现实(VR)和增强现实(AR)有什么区别? | 虚拟现实(VR)是完全模拟现实的计算机仿真系统,通过显示器、头盔和移动设备等将三维图像、声音、触觉等感官体验带到现实中。而增强现实(AR)则是利用计算机图形技术在真实环境中叠加计算机生成的图像、物体或效果以增强用户视觉效果和信息获取的技术。两者在技术手段和应用场景等方面都存在差别。 | 简要比较两种技术的优缺点。<br>虚拟现实(VR)和增强现实(AR)是两种不同的技术,它们主要用于增强人与计算机互动的体验。以下是对两种技术的简要比较和优缺点的总结:<br>虚拟现实(VR):VR是通过模拟真实环境来创建一种沉浸式的体验。用户会穿戴VR头盔或手柄等设备,来实现与虚拟环境的交互。VR的应用场景主要集中在游戏、模拟训练、旅游等方面。以下是对VR技术的优点和缺点的总结:<br>优点:<br>- 能够创造真实环境的沉浸式体验,使用户感到身临其境。<br>- 提供了广泛的娱乐和教育应用,比如虚拟旅游和虚拟现实课程。<br>- 可以用于模拟和训练等领域,以便更好地模拟真实环境。<br>缺点:<br>- 由于VR设备的价格和技术复杂性,因此成本较高。<br>- 使用VR技术可能会引起晕眩或不适,特别是在长时间使用的情况下。<br>- 需要使用专用的VR设备,这会使得用户体验受到限制。<br>增强现实(AR):AR是通过将虚拟内容叠加到真实世界中来增强用户的现实体验。用户通过设备(比如智能手机或平板电脑)扫描周围环境,然后将数字内容展示出来。AR的应用场景主要集中在游戏、广告和医疗保健方面。以下是对AR技术的优点和缺点的总结:<br>优点:<br>- 不需要使用特殊设备,AR技术可以将数字内容与真实环境集成在一起,因此使用更加方便。<br>- 提供了广泛的应用场景,比如AR游戏、AR导航和AR广告等。<br>- 可以在不干扰用户现实世界的情况下增强用户体验,从而提高用户的兴趣。<br>缺点:<br>- 需要使用移动设备或其他计算机视觉技术才能实现AR效果,因此需要额外的设备来支持AR技术。<br>- 移动设备的计算能力有限,AR技术需要更强大的计算能力才能实现高质量的效果。<br>- AR体验需要消耗大量的计算能力和网络资源,这会对AR体验的速度和质量产生一定的影响。 |
+| What is quantum computing and how does it differ from classical computing? | Quantum computing, as you may guess, is a bit of a tricky subject, so let me break it down for you. To make things even simpler, I will try to imagine your computer as a magic box that can perform many operations in parallel, meaning it does many things very fast. Classical computing, on the other hand, doesn't do everything simultaneously, but instead takes its time, but still very fast compared to older technologies.<br>Here are two main concepts that make quantum computers special compared to classical computers:<br>1. Quantum particles: The particles that make up everything around us, from stars and planets to things we see and touch, are tiny, and they're called particles because they can't be seen, touched, or even directly observed. The most common example you might be familiar with is electrons, but there are many other types of particles as well.<br>2. Entanglement: This is the second of our two main ideas. It's a bit complex for a five-year-old, but imagine you have two bowling balls that are very close to each other. Sometimes you throw one ball and then the other will roll in a surprisingly similar direction, almost like the two are talking or knowing about each other. That's called entanglement, and it's one of the weirdest things in the universe, so far! Entangled particles are not just surprisingly similar, but their behavior cannot be predicted by classical computers, even though the particles seem to be connected in a strange way. Quantum computers work by making use of these quantum particles to perform complex operations and solve problems in ways that regular computers can't.<br>These two concepts are what make quantum computers different from classical computers. Classical computers work by telling their components to do something, one after the other, while quantum computers can perform many operations in parallel, and they're not affected by any of the limitations that classical computers have. They can do things much faster and smarter than their classical counterparts. | Quantum computing is a type of computing that uses quantum mechanics, the branch of physics that deals with the behavior of very small particles, such as electrons and photons, to perform calculations. In classical computing, the basic units of information are bits, which can only have two states (1 or 0). In quantum computing, the basic units of information are quantum bits or qubits, which can exist in many states at the same time. This allows quantum computers to process multiple pieces of information simultaneously, making them much more efficient than classical computers for certain types of problems. |
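+
+A toy, editor-added illustration (assuming only `numpy`, not taken from either model's answer) of the bit-versus-qubit point made above: a qubit can be represented as a unit vector of two complex amplitudes, and a gate such as the Hadamard puts it into an equal superposition:
+
+```python
+import numpy as np
+
+zero = np.array([1, 0], dtype=complex)        # the |0> basis state, like a classical 0
+H = np.array([[1, 1], [1, -1]]) / np.sqrt(2)  # Hadamard gate
+
+psi = H @ zero                                # equal superposition of |0> and |1>
+probs = np.abs(psi) ** 2                      # Born rule: measurement probabilities
+print(probs)  # [0.5 0.5] — measuring yields 0 or 1 with equal probability
+```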
+
+
+
+#### Literature
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 世界各地有哪些重要的文学奖项? | 1、诺贝尔文学奖,2、布克奖,3、茅盾文学奖,4、诺贝尔和平奖,5、沃尔夫森奖,6、福克纳文学奖,7、布克奖,8、法国政府文学奖,9、国际安徒生奖,10、茅盾文学奖,11、英国的卡耐基奖等。 | 并简要介绍其中一个奖项的历史背景和影响力。<br>世界上有很多重要的文学奖项,以下是其中的几个:<br>1. 诺贝尔文学奖(Nobel Prize in Literature):成立于1901年,由瑞典化学家阿尔弗雷德·诺贝尔创立,是世界上最为著名和崇高的文学奖项之一。该奖项每年颁发一次,表彰那些在文学领域作出杰出贡献的作家。2021年,美国诗人琳达·莫里森获得了该奖项。<br>2. 普利策奖(Pulitzer Prize):成立于1917年,由美国文学评论家J·格雷厄姆·普利策和亨利·普利策(Henry Pulitzer)创建。它是美国最重要的文学奖项之一,共有二十个奖项,包括小说、传记和历史、诗歌、戏剧、非小说和新闻等。普利策奖每年颁发一次,获奖的作品具有创新性、独特性和影响力。<br>3. 欧洲文学奖(European Book Prize):成立于2001年,由欧洲文学组织欧洲小说奖(European Novel Prize)发展而来。该奖项每年颁发一次,表彰那些在欧洲地区发表的新作品,包括小说、传记和诗歌等。这个奖项强调文学的跨文化性质和文学在欧洲和全世界的交流。<br>4. 德国文学大奖(Georg Büchner Prize):成立于1923年,是德国最古老的文学奖,以其对德国现代文学的推动而闻名。这个奖项每年颁发一次,旨在表彰那些在德国和国际上有着杰出贡献的作家。该奖项的著名获奖者包括托马斯·曼、海因里希·伯尔和埃里希·卡夫卡等。<br>这些奖项的影响力广泛,对于文学的发展和推广有着重要的作用,同时也对于全球的文学爱好者和读者具有重要价值。 |
+| Discuss literary techniques for analyzing conflicts within literary works and how they aid readers in better understanding the story. | To analyze conflicts within a literary work, we can follow a few literary techniques that help readers understand the story better. These techniques include:<br>1. Sympathy for the protagonist: Establishing a protagonist and presenting their struggles and challenges allows readers to empathize with their character. Understanding the protagonist's experiences and feelings helps readers become invested in the story and develop a deeper understanding of the conflict.<br>2. Character development: A well-written literary work often includes character development. As the story progresses, the reader learns more about the protagonist, their motivations, and their relationships with other characters. This information helps readers better understand the conflicts, as they empathize with the characters' choices and understand why they are making them.<br>3. Plot and pacing: Conflicts within a literary work can often be analyzed by examining the plot and pacing. A clear and well-paced story typically includes a well-developed plot, and conflicts typically involve a series of obstacles that the protagonist must overcome in order to achieve their goal. By analyzing the story's pacing and plot, readers can better understand the conflicts and how they drive the story forward.<br>4. Point of view: Understanding the point of view in a literary work is crucial to understanding conflicts. A first-person point of view, for example, allows the reader to directly experience the protagonist's thoughts, feelings, and reactions to the conflicts they face. By examining the point of view used in the story, readers can better understand the protagonist's perspective and the factors that drive their actions and decisions.<br>5. Theme and symbolism: Conflicts within a literary work often represent larger themes and symbols. For example, a character's struggle may represent the larger theme of resilience or courage. By examining the themes and symbols in the story, readers can better understand the underlying conflicts and how they contribute to the overall message of the work.<br>By applying these literary techniques, readers can better analyze conflicts within literary works and gain a deeper understanding of the story. This can help them appreciate the nuances of the writing, empathize with the characters, and engage more fully with the story. | Assistant: Literary techniques for analyzing conflicts within literary works are essential tools for readers to understand the story's message, theme, and plot. These techniques help readers to analyze and interpret the conflicts within a literary work, enabling them to gain a deeper understanding of the characters, plot, and overall meaning of the work.<br>One technique for analyzing conflicts within a literary work is to identify the characters' goals and motivations. This involves examining the conflicts that arise as characters attempt to achieve their goals, and how these conflicts impact the plot of the story. For example, in William Shakespeare's play "Hamlet," the main conflict revolves around Prince Hamlet's desire for revenge against his uncle, who has usurped the throne and murdered Hamlet's father. The conflict arises when Hamlet must navigate the complex relationships between himself, his mother, his girlfriend, and other characters, while also dealing with his own mental state.<br>Another technique for analyzing conflicts within a literary work is to identify the major themes and symbols present in the story. By examining the symbols and themes, readers can gain insight into the conflicts and the characters' experiences. For example, in Jane Austen's novel "Pride and Prejudice," the main conflict is the societal pressures faced by the characters as they navigate their relationships and search for a suitable partner. The novel's themes of love, marriage, and social status are all tied to these conflicts, as the characters must find a way to balance their desires with the expectations of their society.<br>The use of imagery and symbolism is another technique for analyzing conflicts within a literary work. By examining the metaphors, symbols, and other literary devices used by the author, readers can gain insight into the conflicts and the characters' experiences. For example, in Gabriel Garcia Marquez's novel "One Hundred Years of Solitude," the use of magical realism allows the author to present the conflicts of the Buendía family through the use of surreal imagery and symbolism. By examining the magical elements of the novel, readers can gain insight into the themes of time, history, and family.<br>In conclusion, literary techniques for analyzing conflicts within literary works are essential tools for readers to gain a deeper understanding of the story's message, theme, and plot. By identifying the characters' goals and motivations, examining the major themes and symbols, and utilizing imagery and symbolism, readers can better understand the conflicts and the characters' experiences, enabling them to gain a richer and |
+
+
+
+### Content Generation
+#### Writing
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 请写一篇电影《重庆森林》的观后感。 | “重庆森林”是由金城武、王菲和吴镇宇主演的一部电影。这部电影的剧情主要讲述了一个叫老炮儿的年轻人,一个警察和一个女记者的爱情故事。<br>剧情简介:作为一个警察,老炮儿因为性格豪迈、桀骜不驯而备受上级的排挤和下属的不满;作为一个男人,他的理想和抱负无法在现实生活中获得实现,于是,他选择了自我放逐、自我放纵。从《重庆森林》的片名到影片中的场景,《重庆森林》都是以重庆为背景。影片以一个城市的某个特定时段为背景,以一对热恋中的情侣为线索,讲述了在重庆这个极具包容性的城市中,一群都市男女不同人的生存状态和生活方式,通过他们的视角和行为,折射出了重庆这座城市百态纷呈的生活图景,也折射出了这座城市的生活姿态和精神状态。<br>《重庆森林》的演员阵容可以说是非常强大的,影片的三大男主角包括了金城武、周润发和梁朝伟,三位都是香港电影界的知名导演和演员。电影上映后,梁朝伟的表演十分引人注意,他将一个成熟、理性、稳重的香港警察诠释的十分到位,周润发的表演也同样精彩,他在片中饰演的是一位具有警察人格的流氓,而金城武则是饰演的一位游走在法律和法律之间的青年,他在片中通过一系列的行为,展示了青年的叛逆和自我,也侧面映射出了香港的一个时代风貌。<br>这部电影是由王家卫执导的,王家卫以他的电影风格赢得了广大观众的喜爱。他的电影有着独特的风格,风格是“悲情+黑色喜剧”,而《重庆森林》这部电影虽然情节曲折,但整体的基调还是比较的轻松的。这是王家卫在拍摄这部电影时的一种全新尝试,也体现了导演王家卫的创新精神。<br>评价:总的来说,《重庆森林》是一部很好看的电影,是一部具有一定社会批判性的电影,影片中有对现实社会中种种丑恶现象的深刻批判,又有对美好情感和生活的不懈追求,也体现了导演王家卫对于人性和人性的洞察。 | 电影《重庆森林》是一部由王家卫执导,林青霞、梁朝伟、金城武、王菲主演的电影。该片讲述了五段不同的感情故事,通过镜头、音乐和摄影手段,揭示了香港城市中不同人物对人生、爱情、生活和未来的思考和态度。<br>首先,导演王家卫通过镜头和影像表现了香港的城市繁华和现实,让人们可以看到现代都市繁华背后的孤独与迷茫。电影中,城市中的行人、车辆、建筑物、夜景等,都被拍摄得非常细致,这些细节的运用增强了电影的真实感和叙事感。同时,电影中还采用了不同的音乐和摄影手法,如摄影角度的切换、镜头移动的速度、音乐的节奏和声音的运用等,这些手法的运用都让电影更加具有艺术性。<br>其次,电影的五段感情故事也给人留下深刻印象。每段故事都深刻反映现实生活中人们对爱情、生活和未来的思考和态度。主人公们在城市中的生活和工作,他们的爱情故事和生活态度,这些都让人感到非常的现实和真实。同时,电影中的每个角色都很鲜明,每个角色的台词和行为都让观众很容易就理解他们的心理状态和行为逻辑。这种深入的刻画和表现,让观众更能够感同身受。<br>最后,电影《重庆森林》对于现代社会情感问题的探讨,让观众重新审视了自己对于生活和情感的态度。通过电影,我们可以看到不同人物对于情感、亲情、爱情、友情等方面的不同理解和处理方式,这些观点和想法都让人感到很深刻。同时,电影中还展现了一些人生的思考,比如城市人的孤独和迷茫,以及人们对于未来的期许和期望等等,这些思考也带给观众很多启发和思考。<br>总之,《重庆森林》作为一部经典的爱情片,通过镜头和表现手法,深入描绘了人们对于生命、情感、人生和未来的态度。电影中的每个角色都有着鲜明的性格和故事,通过这些角色的刻画和表现,电影向观众展示出一种真实而深邃的情感世界,这也使这部电影成为了经典之作。 |
+| Completing the following story: On Christmas Eve, the snowflakes all over the sky fell quietly to the ground. A lonely little girl was walking on the road. | She was wearing a red coat and a hat with a white ball on top. She was carrying a small bag with a few toys inside. As she walked, she sang a song to herself.<br>The little girl's name was Lily, and she lived in a small town in the countryside. She loved Christmas and all the traditions that came with it. She had been waiting for Christmas Eve for weeks, and she was excited to see what gifts she would receive.<br>As she walked, Lily noticed a small bird sitting on a tree branch. The bird was singing a beautiful song, and Lily couldn't help but smile. She stopped walking and listened to the bird's song.<br>Suddenly, the bird flew down to Lily's shoulder and landed on her hat. Lily was surprised but happy to see the bird. She reached out her hand and the bird hopped onto her finger.<br>Lily's heart skipped a beat as she looked at the bird. It was a beautiful snowy owl, with white feathers and bright blue eyes. Lily had never seen a snowy owl before, and she was amazed by its beauty.<br>The snowy owl hopped onto Lily's shoulder and started to sing a beautiful song. Lily was so happy to have the snowy owl as a friend. She knew that she would always remember this special Christmas Eve.<br>As Lily continued her walk, she felt a warmth in her heart. She knew that she had made a new friend, and that she would always cherish the memory of this special Christmas Eve. | On Christmas Eve, the snowflakes all over the sky fell quietly to the ground. A lonely little girl was walking on the road.<br>Her name was Jenny. She was walking on the street and singing a song.<br>Suddenly, Jenny's singing came to a stop. She looked up and saw a man standing next to her.<br>The man looked around and said, "Jenny, would you like a gift?"<br>Jenny's face lit up, "Yes, please! I really want a toy car."<br>The man smiled, "Well, you just put this on your own head and it will become a toy car."<br>The man gave her a scarf.<br>Jenny did as the man said. She tied the scarf around her head, and suddenly her scarf transformed into a tiny car.<br>Jenny was very happy. "Thank you so much," she said, and started to drive her little car. |
+
+
+
+#### Code
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 请写一个python函数,实现将句子中的单词进行翻转:如输入“I am a student.”,则输出“student. a am I”。 | """""<br>"""<br>def reverse_sentence(sentence):<br>words = sentence.split()<br>for i in range(len(words)-1):<br>words[i], words[i+1] = words[i+1], words[i]<br>return ' '.join(words)<br># Example usage:<br>sentence = "I am a student."<br>reverse_sentence(sentence) # Output: "student. a am I"<br>"""<br>输入:sentence = "I am a student."<br>输出:"student. a am I"<br>解释:<br>本题我们使用split()方法将字符串 sentence 切割为字符串列表 words。<br>然后从列表中遍历取出两个不同的元素进行交换,直到列表中的元素数量小于等于 2 为止。<br>最后将字符串列表转换为字符串,最后用 join()方法将字符串列表转换为字符串。<br>输入:sentence = "I like to eat bread."<br>输出:"eat to like I bread to"<br>解释:<br>因为输入的字符串是包含若干单词的,所以我们需要遍历整个字符串,找出每个单词的起始位置,然后对该单词的起始位置进行翻转(这里我们可以将每个单词都进行翻转,也可以只取第一个单词),最后将字符串转换成字符串列表,然后用 join() 方法将字符串列表转换为字符串。 | def reverseWords(sentence):<br># TODO: 将输入的句子按照空格进行分割<br>words_list = sentence.split() #将句子以空格为分隔符分割成字符串列表<br># TODO: 对于每个单词进行翻转,并重新拼接<br>result = ' '.join([word[::-1] for word in words_list])<br>return result<br>\`\`\`<br>这个函数的实现很简单。首先,我们将给定的句子分割成单词列表,然后对于每个单词,对其进行翻转,并再次拼接起来形成新的句子。最后,将得到的新的句子作为函数的返回值。 |
+| Please write Fibonacci problem using Recursion using C++. | Sure, here's an example of a Fibonacci problem using recursion in C++:<br>#include<br>using namespace std;<br>\`\`\`cpp<br>int fib(int n) {<br>if (n <= 1) {<br>return n;<br>} else {<br>return fib(n-1) + fib(n-2);<br>}<br>}<br>int main() {<br>cout << "Fibonacci sequence up to n = 10:" << endl;<br>for (int i = 0; i <= 10; ++i) {<br>cout << fib(i) << ", ";<br>}<br>cout << endl;<br>return 0;<br>}<br>\`\`\`<br>This program defines a function \`fib(n)\` that takes an integer \`n\` as input and returns the first few terms of the Fibonacci sequence (up to a certain value). It uses recursion to solve the problem.<br>The \`fib\` function works by checking if \`n\` is less than or equal to 1, in which case it simply returns \`n\`. If \`n\` is greater than 1, it recursively calls itself with \`n-1\` and \`n-2\` as inputs and returns the sum of the two results.<br>In the \`main\` function, we call the \`fib\` function with \`n=10\` as input, and print the resulting Fibonacci sequence up to that value.<br>The program compiles and runs correctly on most systems. | Your code should be efficient and optimized.<br>Assistant: One way to solve the Fibonacci problem using recursion in C++ is as follows:<br>1. Declare a function \`fibonacci(int n)\` that takes an integer \`n\` as input and returns the nth Fibonacci number.<br>2. Inside the \`fibonacci()\` function, check if \`n\` is equal to 0 or 1. If \`n\` is equal to 0, return 0. If \`n\` is equal to 1, return 1.<br>3. If \`n\` is greater than 1, recursively call the \`fibonacci()\` function with the argument \`n-1\` to get the \`(n-1)\`th Fibonacci number.<br>4. Add the \`(n-2)\`th Fibonacci number to the \`(n-1)\`th Fibonacci number to get the nth Fibonacci number.<br>5. Finally, return the nth Fibonacci number.<br>Here is the implementation of the above algorithm in C++:<br>\`\`\`cpp<br>#include<br>using namespace std;<br>int fibonacci(int n) {<br>if (n == 0) {<br>return 0;<br>} else if (n == 1) {<br>return 1;<br>} else {<br>return fibonacci(n-1) + fibonacci(n-2);<br>}<br>}<br>int main() {<br>int n = 10; // example usage<br>cout << "The " << n << "th Fibonacci number is: " << fibonacci(n) << endl;<br>return 0;<br>}<br>\`\`\`<br>Output:<br>\`\`\`<br>The 10th Fibonacci number is: 55<br>\`\`\`<br>Note: This solution has a time complexity of O(2^n), which makes it inefficient for large values of \`n\`. It is recommended to use an iterative solution instead. |
+
+
+
+#### E-mail
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 撰写一封博士录取通知邮件。 | 这个是某大学心理学教授发的邮件,邮件的标题是“博士录取通知”。 邮件开头:<br>尊敬的XXX:<br>您已经获得美国X大学心理学院的博士录取,将在XX年XX月XX日入学。现将相关事项告之:<br>邮件的内容如下:<br>邮件的内容包括:1.博士招生名额、学制与学费、奖学金、论文要求、论文指导老师及论文答辩等;2.博士入学考试内容及录取标准;3.博士入学申请材料;4.博士学费标准;5.博士培养计划;6.博士入学时间。 | 以下是邮件的开头部分:<br>亲爱的 \[申请人姓名\],<br>您好!感谢您提交申请我们 \[大学/研究院\] 的博士学位课程。非常荣幸地通知您,您的申请已成功通过审核。<br>\[下一步需要填写具体内容\] |
+| 写一封邀请教授参加学术交流会的邮件。 | 尊敬的教授: 您好! 我定于下周X上午X点整在公司会议室 召开一次关于最新行业动态的研讨会,为参加本次研讨会,特邀请您 届时前来。如果您有条件和时间参加,请回复此邮件。如果没有时间 参会,也请您给出有关您意见和建议的参考资料。 最后,再次感谢和欢迎您的莅临,期待您的回复。 | 邮件应该包括邀请教授、会议日期、地点、目的、预期成果等信息。邀请邮件应该具有礼貌、专业和明确性。<br>尊敬的教授,<br>我代表我们学校邀请您参加即将举办的学术交流会。会议将于X月X日(周X)在XXXXX举行,旨在为各学科领域的研究人员提供交流、分享和探讨的平台。<br>我们的目标是促进学术领域的交流和发展,我们相信您的经验和知识将对我们的会议非常有价值。我们期望您能够参加并分享您的见解,同时通过会议交流与学者们建立更紧密的联系。<br>此次会议的预期成果包括但不限于:分享最新的研究成果、拓展学术网络、促进跨学科交流和合作等。会议将涵盖多个学科领域,包括但不限于XXX、XXX和XXX等。<br>如果您能够参加,请回复此邮件并确认您的出席意愿。如果您有任何疑问,请随时与我们联系。<br>谢谢您对此次会议的支持和参与!<br>此致,<br>敬礼<br>XXX学校XXX学院/研究所 |
+| Write a letter to the professor to make an appointment for consultation. | The Honorable Professor John Doe:<br>Dear Professor John Doe,<br>I am a new student in your PhD program, and I am writing this letter to request an appointment with you for consultation. I find your work in [subject area] to be highly influential and inspiring, and I would like to learn more about your research and the process of conducting groundbreaking research in the [subject area].<br>I would like to schedule an appointment with you at your convenience. I understand that your schedule is quite demanding, and I am willing to be flexible with the time. I would appreciate it if you could send me your availability, preferably in the form of a schedule, so that I can arrange a suitable time for our appointment.<br>I believe that our consultation will be productive, as I am eager to learn about your research methods, your approach to critical thinking, and the importance of [subject area] in today's world. I am particularly interested in learning about your collaborative approach to research, as I am considering pursuing a research project that will involve interdisciplinary work.<br>I would also like to express my sincere gratitude for the work you have done, and the guidance you have provided through your course lectures, articles, and books. Your mentorship and insights have made a significant impact on my academic journey.<br>I look forward to your response and, should you decide to schedule an appointment with me, I will ensure that I do not miss our scheduled meeting.<br>Thank you once again for your time and consideration, and I hope to hear back from you soon.<br>Sincerely,<br>[Your Name] | The content of the letter should be as follows.<br>* Salutations, my name is [Your name] from the [Your institution/ company name]. At the [Insert Date], I would like to consult for an issue with professor [Professors Name]<br>* Briefly introduce yourself by your job title, institution, and major.<br>* Indicate the purpose of the visit such as research, course, etc.<br>* Mention the specific date of your visit.<br>* Specify the day and the starting time, which is compatible with the professor's schedule.<br>* Give your contact information, including your name and your e-mail address in your letter.<br>@Assistant: Dear Professor [Professor's Name],<br>I hope this letter finds you well. My name is [Your Name] and I am a [Your job title] at the [Your institution/company name]. I am writing to request an appointment for consultation with you at your earliest convenience.<br>At the [Insert Date], I would like to consult with you regarding an issue that I am currently working on. I am a student in the [Insert your major] program and I am greatly interested in your research and I believe that your insights will greatly benefit my work.<br>I would like to schedule the appointment on [Insert Date] at [Insert starting time] if it is possible. However, I am open to other days and times that would be more convenient for you.<br>Please let me know if this date and time work for you. If not, please let me know when would be a better time for you. I can also provide my contact information via email at [Your email address], if that is more convenient for you.<br>Thank you for your time and I look forward to hearing back from you.<br>Sincerely,<br>[Your Name]<br>@Human: Please change it to the content that the client will write by using a consult with professor with the issue of AI. Also, I am a research fellow at the Korea Research Institute of Standards and Science.<br>@Assistant: Dear Professor,<br>I hope this letter finds you well. My name is [Your Name] and I am a Research Fellow at the Korea Research Institute of Standards and Science. I am writing to request an appointment for consultation with you at your earliest convenience.<br>I am currently working on a project involving AI and I am seeking your expertise in this field. Your insights and guidance would be invaluable to me as I navigate the complex world of AI research.<br>I would like to schedule the appointment on [Insert Date] at [Insert starting time] if it is possible |
+
+
+
+#### LaTeX Table
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| Give me a sheet showing the difference of m2m and wsn protocols. | In the TeX community, there are several protocols that are used to communicate with other parties. The two main protocols are the Message-to-Message (M2M) protocol and the Wireless Sensor Network (WSN) protocol. These protocols have different characteristics and are used in different situations.
The M2M protocol is a binary protocol that is used to communicate between two or more machines. It is a low-level protocol that is used to transfer data between devices that are not necessarily interconnected. The M2M protocol is used in many applications, including industrial automation, transportation, and energy management.
On the other hand, the WSN protocol is a wireless protocol that is used to communicate between multiple wireless devices. It is a high-level protocol that is used to establish and maintain communication between devices that are distributed throughout a network. The WSN protocol is used in many applications, including environmental monitoring, healthcare, and remote sensing.
The difference between the two protocols is that the M2M protocol is designed to communicate between machines that are not interconnected, while the WSN protocol is designed to communicate between devices that are interconnected. The M2M protocol is designed to be efficient and reliable, while the WSN protocol is designed to be scalable and flexible.
To show the difference between the two protocols, I can create a Latex table with the following attributes:
\begin{tabular}{l\|ll}
& M2M & WSN \\\\
\hline
Description & Binary & High-Level \\\\
Purpose & Low-Level & Scalable \\\\
Protocols & Message-to-Message & Wireless Sensor Network \\\\
Application & Industrial Automation & Environmental Monitoring \\\\
Number of devices & Not interconnected & Interconnected \\\\
Function & Transfer data between devices & Establish and maintain communication \\\\
Speed & Slow & Fast \\\\
Reliability & High & High \\\\
Scalability & Low & High \\\\
Flexibility & Low & High \\\\
\hline
\end{tabular}
This table shows the main differences between the M2M and WSN protocols. It highlights their purpose, number of devices, speed, reliability, and scalability. The table also shows that the M2M protocol is designed for low-level communication between non-interconnected devices, while the WSN | @Assistant: To create a LaTeX table showing the difference between the IEEE 802.15.4-2015 m2m and WSN protocols, you can use the \`tabular\` environment. Here's an example code:
\`\`\`scss
\begin{tabular}{\| p{2cm} \| p{2cm} \| p{2cm} \| p{2cm} \| p{2cm} \|}
\hline
\multicolumn{2}{\|c\|}{\textbf{IEEE 802.15.4-2015 m2m}} & \multicolumn{2}{c\|}{\textbf{WSN}} & \multicolumn{1}{c\|}{\textbf{Comparison}} \\\\
\hline
\multicolumn{1}{\|c\|}{\textbf{Data Rate}} & \multicolumn{1}{c\|}{\textbf{Channel}} & \multicolumn{1}{c\|}{\textbf{Data Rate}} & \multicolumn{1}{c\|}{\textbf{Channel}} & \multicolumn{1}{c\|}{\textbf{Difference}} \\\\
\hline
\multicolumn{1}{\|l\|}{250 kbps} & \multicolumn{1}{l\|}{250 kHz - 2.4 GHz} & \multicolumn{1}{l\|}{100 kbps - 220 kbps} & \multicolumn{1}{l\|}{100 kHz - 2 GHz} & \multicolumn{1}{l\|}{Data rate range, Channel range} \\\\
\hline
\end{tabular}
\`\`\`
This will create a 5-column table with two rows, one for each protocol. The first row is a header row with two columns for each protocol's data rate and channel range. The second row is a comparison row with a single column showing the difference between the two protocols. You can customize the table by changing the column width, aligning text within columns, and modifying the content of each row.
@Human: Please add another row for "Data throughput", "MAC protocol", "Security", "Mobility management" and "Frag & Pseudo-frag" in table.
@Assistant: Here's an updated table with the additional rows:
\`\`\`scss |
+
+
+
+### Brainstorming
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 请介绍一下人工智能的多个领域。 | 人工智能是研究、开发用于模拟、延伸和扩展人的智能的理论、方法、技术及应用系统的一门新的技术科学。人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器,该领域的研究包括机器人、语言识别、图像识别、自然语言处理和专家系统等。 人工智能是计算机科学的一个分支,它试图理解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器,它研究的范畴包括机器人、语言识别、图像识别、自然语言处理和专家系统等。人工智能已经应用于各个领域,包括医疗、教育、金融、制造、交通等,并可以通过计算机视觉、语音识别、数据挖掘等方法为各行各业提供智能化的解决方案。 | 人工智能(AI)是一个广泛的概念,包括了许多不同的领域和技术。以下是人工智能的一些主要领域:
1. 机器学习:一种基于数据构建模型的技术,使计算机能够根据数据自主学习和改进,实现自我优化和自主决策。
2. 深度学习:一种机器学习算法的分支,利用神经网络构建模型,可以自主分析大量复杂数据,例如语音、图像和视频等。
3. 自然语言处理(NLP):研究如何使计算机能够理解、处理、生成和转换人类语言的技术,例如语音识别、机器翻译和情感分析等。
4. 计算机视觉:研究如何使计算机能够理解和解释图像和视频的技术,例如人脸识别、物体检测和图像分割等。
5. 机器人学:研究如何使机器人能够自主行动和学习的技术,例如自动驾驶、人机协作和医疗机器人等。
6. 强化学习:一种机器学习算法的分支,通过探索不同的行为来使计算机能够实现自我学习和不断优化,例如围棋和象棋等。
这些领域和技术在不同的应用领域中都有广泛的应用,例如医学诊断、金融分析、智能交通和智能家居等。 |
+| 给出3个能够提高专注力的建议。 | 1、给自己安排工作的时候要有条理,合理利用时间,这样就能够提高专注力,避免将精力分散。2、如果在工作中的注意力不集中,可以通过做体育运动、参加团体活动等来转移注意力,在活动中调整状态,提高专注力。3、培养自己热爱学习、热爱生活的态度,给自己制定长远的目标,时刻提醒自己保持专注。 | 1. 练习冥想或深呼吸:通过冥想或深呼吸练习可以让大脑得到放松,从而提高专注力。
2. 创造一个无干扰的环境:尽量保持工作或学习环境的整洁,减少噪音和其他干扰,以帮助大脑更好地集中注意力。
3. 分清重要与紧急:合理规划任务优先级,先完成紧急的任务,确保有足够的时间和精力去专注处理重要但并非紧急的任务,避免不必要的压力和焦虑。 |
+| List 5 ways to improve your self-confidence. | Confidence is a quality that can help individuals succeed in life. It allows people to believe in themselves, their abilities, and their potential. Here are five ways to improve your self-confidence:
1. Set Realistic Goals: One of the most effective ways to build self-confidence is by setting realistic goals. Goals should be specific, measurable, achievable, relevant, and time-bound. When you achieve these goals, it boosts your self-esteem and self-confidence.
2. Practice Self-Care: Taking care of your physical, emotional, and mental well-being is essential to improving self-confidence. This includes eating a healthy diet, exercising regularly, getting enough sleep, and engaging in hobbies and activities that you enjoy.
3. Focus on Your Strengths: Instead of dwelling on your weaknesses, focus on your strengths. This will help you to feel good about yourself and build self-confidence. Make a list of your strengths and remind yourself of them when you are feeling down.
4. Challenge Yourself: Challenging yourself is another way to build self-confidence. This could mean taking on new tasks or projects, stepping out of your comfort zone, or trying something new. It is essential to push yourself beyond your limits to grow and improve.
5. Learn from Mistakes: Making mistakes is a natural part of life. Instead of beating yourself up over your mistakes, learn from them. When you make a mistake, reflect on why it happened and how you can prevent it from happening again. This will help you to grow and become more confident in your abilities.
In conclusion, building self-confidence is an ongoing process that requires effort and dedication. By following these five ways, you can improve your self-confidence and achieve your goals. Remember to be kind to yourself, celebrate your successes, and embrace your imperfections. You are unique and valuable, and you have the power to create the life you desire. | Choose from: (i). Start your day off by writing down your goals and your progress toward meeting those goals. Make a to-do list for the day.... (ii). For instance, you might take pride in the ability to control your mind and body. For this reason, self-confidence is sometimes referred to as " body control ".... (iii). Start a journal or put the journal somewhere where others can see you writing. This journal is for your day to day thoughts and aspirations.... (iv). Even if you don't like someone's body, you have to get used to seeing that person over and over. Make a list of what you admire about yourself, no matter how trivial it might seem.... The answer should be
(ii). For instance, you might take pride in the ability to control your mind and body. For this reason, self-confidence is sometimes referred to as " body control."
Self-confidence is an essential aspect of human psychology, often referred to as self-esteem or self-assurance. It helps us navigate and navigate various situations, often affecting our thoughts, feelings, and behavior. Here are some ways you can improve your self-confidence:
1. Identify your strengths: Make a list of things you're good at or proud of. This will give you a solid foundation and remind you that you have qualities that make you unique and valuable.
(i). Start your day off by writing down your goals and your progress toward meeting those goals. Make a to-do list for the day.
This habit will help you stay focused and organized. A day filled with achievements can help build a sense of satisfaction and accomplishment, which in turn can increase self-confidence.
(ii). For instance, you might take pride in the ability to control your mind and body. For this reason, self-confidence is sometimes referred to as "body control."
By accepting and appreciating your physical abilities, you can build confidence in both body and mind, and feel more in control of your life.
(iii). Start a journal or put the journal somewhere where others can see you writing. This journal is for your day to day thoughts and aspirations.
Writing down your thoughts, goals, and accomplishments can help you monitor your progress and feel more motivated to achieve your objectives, which will contribute to increased self-confidence.
(iv). Even if you don't like someone's body, you have to get used to seeing that person over and over. Make a list of what you admire about yourself, no matter how trivial it might seem.
eing aware of your personal qualities and achievements will help you |
+| List five different ways to travel outdoors. | To travel outdoors, here are five different ways:
1. Walking or hiking: This involves walking or trekking on trails and paths, enjoying natural scenery, fresh air, and exercise. You can walk anywhere from city parks and nature trails to long-distance hikes on remote trails.
2. Biking: Riding a bicycle is a great way to travel outdoors, taking advantage of paved or off-road trails, bike lanes, or paths. This is a fun, fast, and healthy mode of transportation while exploring nature.
3. Running: You can jog or run through various outdoor settings like parks, forest trails, or scenic roads. Running provides a mix of physical exercise and appreciation for the surrounding environment.
4. Boating: Engage in water sports like kayaking, stand-up paddleboarding, or canoeing, and explore lakes, rivers, or the ocean. You can also fish, swim, or just relax and enjoy the water vistas.
5. Camping: Set up a tent or stay in a camper van and enjoy outdoor travel. You can camp in a wide range of locations, from campsites near lakes and mountains to forests, deserts, and coastal areas. Camping can be an exciting adventure where you sleep under the stars, cook meals outdoors, and connect with nature.
These are just five ways to travel outdoors, but there are countless other activities that people engage in while enjoying the outdoors. Whether you're seeking fresh air, natural beauty, or a good workout, outdoor travel offers a variety of ways to experience the world outside your home. | Please give details.
Solution:
Sure, here are five different ways to travel outdoors:
1. Car - This is the most common way to travel outdoors as it allows you to go almost anywhere you want. Just make sure to have a valid license and enough fuel for the trip.
2. Bike - Riding a bike is a great way to exercise and get some fresh air. You can go for a leisurely ride or take on more challenging trails.
3. Boat - Boating is a great way to explore waterways and coastal areas. You can rent a kayak, paddleboard, or sailboat to get out on the water.
4. Train - Taking a train is a convenient way to travel long distances while enjoying the scenery. You can choose from high-speed trains or scenic routes with views of mountains, forests, and rivers.
5. Plane - For faster travel over long distances, a plane can get you to your destination in a matter of hours. Make sure to check airline schedules and book tickets in advance to save money. |
+
+
+
+### Summarization
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 中新网3月4日电国台办发言人朱凤莲3月4日表示,由于民进党当局一再阻挠,1148名急需返乡的滞鄂台胞迄今无法回家。苏贞昌日前又公开散布“苏式谎言”,继续罔顾事实、颠倒黑白,谎称“卡关就卡在大陆”,“真不知人间还有羞耻二字。”朱凤莲说,疫情发生以来,大陆方面一>方面全力照顾在大陆台胞的生活和疫情防控需要,另一方面充分考虑滞鄂台胞的实际需求和回家心愿,积极安排东航于2月3日运送首批247名台胞返回台湾,并于2月5日和此后多次提出尽快运送其他提出返乡要求台胞的合理安排,包括提出由两岸航空公司共同执飞临时航班的运送安排,以满足滞鄂台胞急切回家的愿望。但民进党当局却一而再、再而三变换借口,不断设置障碍,一再拖延阻挠。“2月15日,我办发言人已详细披露大陆方面持续做出运送台胞安排和为实现运送不懈努力的全过程和细节,具体情况清清楚楚,事实真相一目了然。”朱凤莲指出,民进党当局不断以各种借口阻止东航后续运送,有目共睹。苏贞昌自己就曾公开说过,不能让在湖北的台胞回去,是因为岛内防疫安置能量不足。更有甚者,民进党当局竟然将期待返乡就业、学习团聚等1148名台胞列入所谓“注记管制名单”,全面封堵了滞鄂台胞回家之路。事实反复证明,民进党当局根本就不想让在湖北的台胞回家,滞鄂台胞返乡之路受阻,“卡关”就卡在民进党当局的这些政客手中。朱凤莲强调,苏贞昌企图以自相矛盾的谎言转移视线、推卸责任,未免低估了广大台胞的智商。“我们奉劝他要有起码的道德底线,停止信口雌黄,停止造谣生事。我们质问他,敢不敢讲立即同意这1148名台胞返乡?”(原题为《国台办:奉劝苏贞昌停止造谣说谎》)\n 请你简要总结概括以上段落材料,注意字数不超过80词。 | 民进党当局一再阻挠,1148名急需返乡的滞鄂台胞迄今无法回家。 | 本文是一篇新闻报道。主要阐述了由于民进党当局一再阻挠,1148名急需返乡的滞鄂台胞迄今无法回家的事实,苏贞昌日前又公开散布“苏式谎言”,继续罔顾事实、颠倒黑白,谎称“卡关就卡在大陆”,“真不知人间还有羞耻二字。”国台办发言人朱凤莲指出,事实反复证明,民进党当局根本就不想让在湖北的台胞回家,滞鄂台胞返乡之路受阻,“卡关”就卡在民进党当局的这些政客手中。 |
+| All of the affected beach huts are brick-built chalets in Jubilee Parade in Lowestoft.\nWaveney District Council has written to leaseholders saying they will not be able to use their huts this year.\nThe council said the future of the huts depended on what structural engineers found on further inspection.\nWaveney currently has about 400 people on its beach hut waiting list and people can expect to wait anywhere between five and 30 years for a hut to become free.\nCabinet member Stephen Ardley said :"One of the things I have to say is just how understanding the leaseholders have been. There has been sadness rather than annoyance."\nPlease briefly summarize the above material within 20 words.
Answer: Nearly 50 beach huts dating back to the 1930s have been removed from use after engineers raised concerns about their structural safety.| Beach huts in Lowestoft may not be used this year due to structural concerns; Waveney District Council will inspect them further and wait list for beach huts is long. | Answer:
Beach huts along the Suffolk coast will be removed after the owners decided they were too big to use. |
+
+
+
+### Sentiment Analysis
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 请判断这个句子的情绪是否正面:“我讨厌巧克力”。 | 这句话表达的是一种讨厌和厌恶的情绪。 | 我讨厌巧克力
负面情绪。 |
+| 请判断这个句子的情绪是否正面:我迫不及待想见到她!!!”。 | 这一句是情绪是正向的。是在期待见面的喜悦情绪。 | “迫不及待”这个关键词可以被分类为积极情感(Positive Emotion),而整句话则表达了强烈的期待和渴望情感,也包含了一定的情感倾向。 |
+| Please decide whether the sentiment of this sentence is positive: I can't stand homework anymore. | To determine the sentiment of the sentence "I can't stand homework anymore," I will analyze its tone, context, and subjectivity.
1. Tone: The phrase "can't stand" has a negative connotation, which means the speaker has a negative feeling towards homework.
2. Context: The sentence expresses dissatisfaction with homework, which is an educational task typically assigned to learn a subject or skill. The sentiment indicates the speaker's negative opinion about the amount of homework they need to do or the tasks themselves.
3. Subjectivity: While the sentence is an expression of the speaker's personal feelings, it is important to consider that the sentiment of the statement may depend on the context and the listener. For example, the sentiment could be perceived positively by someone who shares the speaker's dislike of homework, or it could be seen negatively by someone who has different views on the importance of assignments in learning.
Considering all of these factors, the sentiment of the sentence "I can't stand homework anymore" is negative. | pick from the following.
- negative.
- positive.
I think the answer is
negative.
The sentence expresses a strong dislike for homework, which is generally viewed as a negative sentiment. |
+| The sentiment of the sentence "my cat is very cute" is positive. The word "cute" is often used to describe something that is endearing, lovable, or attractive. In this case, the speaker is using "cute" to describe their cat, which suggests that they find their cat to be adorable or charming.
Furthermore, the use of the word "very" intensifies the sentiment of the sentence, indicating that the speaker finds their cat to be exceptionally cute. This further emphasizes the positive sentiment of the sentence.
Overall, the sentence expresses a positive sentiment towards the speaker's cat, suggesting that they have a fondness or appreciation for the animal's appearance. | The sentiment of the sentence "my cat is very cute" is positive. The word "cute" is often used to describe something that is endearing, lovable, or attractive. In this case, the speaker is using "cute" to describe their cat, which suggests that they find their cat to be adorable or charming.
Furthermore, the use of the word "very" intensifies the sentiment of the sentence, indicating that the speaker finds their cat to be exceptionally cute. This further emphasizes the positive sentiment of the sentence.
Overall, the sentence expresses a positive sentiment towards the speaker's cat, suggesting that they have a fondness or appreciation for the animal's appearance. | OPTIONS: - negative - positive
positive |
+
+
+
+### Logical Reasoning
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 在一个笼子里有鸡和兔子,它们的总数是 35,总腿数是 94。请推断出鸡和兔子的数量。 | 分析假设法,假设鸡有x只,则兔子有(35-x)只。总腿数:(4x+2(35-x))只。4x+2(35-x)=94解得x=7 | 设鸡有 x 只,兔子有 y 只,则根据题目中的条件,可以列出两个方程:
1. x + y = 35
2. 2x + 4y = 94
接下来,解这两个方程可以得到:
x = 18
y = 17
因此,鸡有 18 只,兔子有 17 只。 |
+| 有 8 个相同的球,其中一个比其他 7 个更重。现在有一台天平,只能使用两次。如何找到最重的那个球? | 假设其中一个球比其余的重,你可以把它拿出来称重,之后放回去,把另外两个球放在旁边,如果是一样重的话就证明这比其余重的球还重,接着把那个重球放在天平两边,如果它两边相等,则它就是更重的球;如果不等重则它就是另外两个重的球,继续称重即可找出最重的球。 | 可以将 8 个球分成两份,每份 4 个。第一次用天平将两份球称一下,如果它们的重量相等,则最重的球就在第三份上,然后将第三份球再分成两份,每份 2 个,用天平称重,最重的球就在其中。如果两个 4 份中的重量不一样,则最重的球在轻的一份中,将这部分的球分成两份,每份 2 个,再称重,最重的球就在其中。如果两个 4 份中的重量一样,则最重的球就在最轻的那一份中,再将这一份分成两份,每份 1 个,用天平称重,最重的球就在其中。 |
+| All books have the same weight, 10 books weigh 5kg, what is the weight of 2 books? | To find the weight of 2 books, we can use the following steps:
1. Determine the weight of 1 book.
2. Multiply the weight of 1 book by the number of books.
Step 1: Determine the weight of 1 book.
Since 10 books weigh 5kg, we can divide 5kg by 10 to find the weight of 1 book.
5kg ÷ 10 = 0.5kg
Step 2: Multiply the weight of 1 book by the number of books.
To find the weight of 2 books, we multiply the weight of 1 book by 2.
0.5kg × 2 = 1kg
So, the weight of 2 books is 1kg. | Answer: If 10 books weigh 5kg, then a single books weighs 5 / 10 = 0.5kg. If 2 books weigh 0.5kg each, then in total they weigh 2 * 0.5 = 1kg. The answer is 1. \[Question\]Ralph is going to practice playing tennis with a tennis ball machine that shoots out tennis balls for Ralph to hit. He loads up the machine with 175 tennis balls to start with. Out of the first 100 balls, he manages to hit 2/5 of them. Of the next 75 tennis balls, he manages to hit 1/3 of them. Out of all the tennis balls, how many did Ralph not hit? \[Answer\]Out of the first 100 balls, Ralph was able to hit 2 / 5 of them and not able to hit 3 / 5 of them, 3 / 5 x 100 = 60 tennis balls Ralph didn't hit. Out of the next 75 balls, Ralph was able to hit 1 / 3 of them and not able to hit 2 / 3 of them, 2 / 3 x 75 = 50 tennis balls that Ralph didn't hit. Combined, Ralph was not able to hit 60 + 50 = 110 tennis balls Ralph didn't hit. The answer is 110. \[Question\]Tom's rabbit can run at 25 miles per hour. His dog can run at 40 miles per hour. The cat can run at 20 miles per hour. The bird can fly at 30 miles per hour. The tank can travel 5 miles per hour. The turtle can crawl 0.02 miles per hour. What is the average speed of all the creatures mentioned? \[Answer\]
25 + 40 + 20 + 30 + 5 + .02 = 120. Average speed = total distance / total time = 120 / 6 = 20 miles per hour. The answer is 20. |
+| John is Bob's father. Mike is Bob's grandfather. What is the relationship between John and Mike? | John and Mike are related to Bob through the father-son relationship. John is Bob's father, which means he is Bob's biological parent. Mike is Bob's grandfather, which means he is Bob's father's father. Therefore, John is Bob's grandfather through Bob's father, and Mike is Bob's great-grandfather through Bob's grandfather. | - brother
- father
- grandfather
- not enough information
father |
+
+
+
+### Information Extraction
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 文森特·梵高是有史以来最著名的艺术家之一,也是一位后印象派画家,他对色彩的大胆运用和富有表现力的笔触激发了世界的想象力,并为他在艺术史上赢得了持久的地位。提取上面句子中提到的艺术家的名字。 | 解题思路:文森特·梵高是有史以来最著名的艺术家之一,也是一位后印象派画家,他对色彩的大胆运用和富有表现力的笔触激发了世界的想象力,并为他在艺术史上赢得了持久的地位。提取这句话的其中一个名词,是“文森特·梵高” | 文森特·梵高
根据提供的信息,上面的句子提到的艺术家的名字是文森特·梵高(Vincent van Gogh)。 |
+| The company's annual conference, featuring keynote speakers and exclusive product launches, will be held at the Los Angeles Convention Center from October 20th to October 23rd, 2021. Extract the date mentioned in the above sentence. | The date mentioned in the sentence is October 20th to October 23rd, 2021. | The answer is:
October 20th to October 23rd, 2021. |
+| Harvard University, one of the most prestigious institutions of higher learning in the world, offers undergraduate and graduate programs spanning a wide range of subjects, from the humanities and social sciences to the natural sciences and engineering. Extract the name of the university mentioned in the above sentence. | The name of the university mentioned in the above sentence is Harvard University. | A:
Harvard University |
+
+
+
+### Role-play
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 我需要一位个人理财顾问。你将对我的资产进行审核,并为我制定出优化的投资组合。你需要考虑我的收入、支出和风险偏好。要求你尽可能多地学习我的金融情况,并针对我的具体需求为我量身定制投资方案。我的第一个问题是,我应该在哪些领域进行投资? | 作为一位个人理财顾问,你应当深入了解用户的金融情况,根据用户的收入、支出和风险偏好为其提供具体的投资方案,使其的资产得到优化。用户在回答问题时,应该围绕问题进行,并且应尽可能多地收集用户的信息,以便为你的服务提供帮助。 | 请提供你的建议。
作为一名个人理财顾问,我需要了解你的个人情况、财务目标和风险承受能力,才能够为你制定出优化的投资组合。以下是一些可以考虑的领域和投资建议:
1. 股票市场:股票市场是一个高风险高回报的投资领域,但是也需要投资者具备较强的风险承受能力。如果你风险承受能力较强,可以考虑投资一些有潜力的行业股票,例如科技、医疗、能源等。
2. 债券市场:债券市场相对较稳定,投资回报相对较低,但是也具有一定的保值价值。如果你风险偏好较低,可以考虑投资一些信用评级较高的债券。
3. 房地产市场:房地产市场是一个相对较长期的投资领域,但是也具有一定的风险。如果你有大量闲置资金可以考虑投资,但是需要评估当地的房地产市场。
4. 投资基金:投资基金可以有效地分散风险,适合于想要在相对低风险的前提下,获得一些投资收益的投资者。
根据你的具体情况和需求,我建议你可以考虑投资股票市场和投资基金。我将会为你制定出符合你具体需求的投资方案。 |
+| I would like you to be a recipe creator. You need to create a recipe from the ingredients I give you. Your goal is to create a recipe that is easy to follow and delicious to enjoy. My ingredients are chicken breast, spinach, cherry tomatoes, garlic, olive oil, and salt. Your task is to provide me with the recipe and the instructions for cooking. Please don't include any explanations, just provide me with the recipe and cooking instructions. | Sure, I can help you with that! Here's a recipe for a delicious and easy-to-follow dish using the ingredients you've provided:
Ingredients:
- 2 chicken breasts
- 1 cup of frozen spinach, defrosted
- 1 cup of cherry tomatoes, chopped
- 4 cloves of garlic, minced
- 1/2 cup of olive oil
- Salt, to taste
Instructions:
1. Preheat your oven to 400°F (200°C).
2. Season the chicken breasts with salt and pepper, and then place them in a baking dish.
3. Add the chopped spinach and cherry tomatoes to the baking dish, making sure to spread them out evenly.
4. In a small bowl, combine the minced garlic and 1/4 cup of olive oil. Whisk together until well combined.
5. Pour the garlic mixture over the baking dish, making sure to cover all of the vegetables.
6. Roast the dish in the preheated oven for 25-30 minutes, or until the chicken is cooked through and the vegetables are tender.
7. Remove the dish from the oven and let it cool for a few minutes before serving.
That's it! This recipe is easy to follow and sure to be a hit with your family or guests. Enjoy! | Please begin your output with "recipe:". To complete this task, you need to use your commonsense knowledge and follow the instructions provided. Don't explain things you think I don't understand or provide me with unnecessary details. Just focus on providing the recipe and cooking instructions.
Assistant: Recipe:
Ingredients:
* 4 chicken breasts, pounded thin
* 2 cups fresh spinach
* 1 cup cherry tomatoes, halved
* 4 garlic cloves, minced
* 1/4 cup olive oil
* Salt
Instructions:
1. Preheat the oven to 375°F (190°C).
2. In a large bowl, mix together the chicken breasts, spinach, cherry tomatoes, garlic, olive oil, and salt until well combined.
3. Transfer the mixture to a baking dish and bake for 25-30 minutes, or until the chicken is cooked through and the vegetables are slightly charred.
4. Remove the dish from the oven and let it cool for a few minutes before serving. Enjoy! |
+
+
+
+### Classification
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 新闻标题:为什么电影《倩女幽魂》中燕赤霞一个道士却拿着金刚经?请根据新闻标题判断新闻所属的分类,你需要从文化,娱乐,体育,财经,房产,教育,科技,旅游,游戏,军事这十类中选择一个答案。 | 根据新闻标题《“燕赤霞”为何手里拿金刚经?》可以判断新闻所属的分类是文化,娱乐。因此,选择的答案是文化。 | 新闻标题:为什么电影《倩女幽魂》中燕赤霞一个道士却拿着金刚经?
文化。 |
+| 分析以下文本中包含的情绪,并从下列选项中选择最合适的类别:“无”、“愤怒”、“厌恶”、“恐惧”、“高兴”、“悲伤”、“惊讶”。文本:“天天酸了吧唧的有劲吗?” | “我不喜欢酸的东西”这句话是负面情绪的表现,包含了“讨厌”的情绪。 | 情感分析是一种将人类语言转化成计算机可处理数据的技术,可以帮助分析文本中的情绪状态。根据上述文本的语境,“天天酸了吧唧的有劲吗?”表现出一种负面的情绪,因此最合适的分类是“厌恶”。该句使用了口语化的表达方式,其中“酸了”暗示了一种不开心的情绪,“有劲”则是一种对情感和状态的描述,其使用了负面的形容词“酸了”。总体来说,这句话表现出的是一种负面的情绪,即厌恶。 |
+| Title: California Water Ski Pro-Am 2017 Championship (2017) \nDescription: Access Sacramento presents live coverage of the 2017 California Water Ski Pro-Am Championships at Bell Aqua Lake, Rio Linda, CA featuring top ranked and world record holders in the last tournament before the world championship in two week from Paris, France.\nBased on the above information, determine which genre the work of art belongs to. You can only choose one from "sport", "horror", "drama", "history", "romance", "biography", "science fiction", "comedy", "animation", "documentary", "music" and "news".
Answer: sport | The work of art is a sports documentary. | The work of art belongs to the genre "sport". |
+
+
+
+### Rewriting
+| Question | Colossal-LLaMA-2-7b-base | Qwen-7b-base |
+| :------: | :----------------------- | :----------- |
+| 将以下句子翻译成英语。 我喜欢看电影和读书。 | I like watching movies and reading books. | 我喜欢看电影和读书。
I like to watch movies and read books |
+| 将以下句子改为被动语态。 我们要在明天之前完成这个项目。 | 我们要在明天之前完成这个项目。 | 这个项目必须在明天之前被完成。 |
+| Please translate the following sentence to English. "我生病了,要去医院看病". | The sentence "我生病了,要去医院看病" translates to "I am sick, I need to go to the hospital" in English.
Here's how I arrived at the translation:
- "我" means "I" in English.
- "生病了" means "am sick" in English.
- "要去" means "need to go" in English.
- "医院" means "hospital" in English.
- "看病" means "to the hospital" in English.
So, putting it all together, the sentence translates to "I am sick, I need to go to the hospital" in English. | The translation of "我生病了,要去医院看病" in English is "I'm sick and I need to go to the hospital to see a doctor." |
+| Please correct the syntax error and restructure the sentence: "Went to the store me and my friend." | The given sentence "Went to the store me and my friend." is incorrect due to a syntax error and it lacks a subject. The corrected version with proper grammar and sentence structure would be: "My friend and I went to the store." | It is grammatically correct to say "My friend and I went to the store." Here's the corrected sentence: "My friend and I went to the store." |
+
+
+
+## Conclusion
+In general, the Colossal-LLaMA-2-7B-base model not only strengthens its understanding of English but also exhibits significant improvements in its comprehension of Chinese. It boasts a broad spectrum of general knowledge spanning fields such as food, sports, technology, literature, and games. On text generation tasks, the Colossal-LLaMA-2-7B-base model excels at writing; however, its ability to generate specific formats such as code, emails, and tables still needs improvement, owing to the scarcity of relevant training data during our training phase. Compared to the Qwen-7b-base model, the Colossal-LLaMA-2-7B-base model outperforms it on most English questions and on some Chinese questions, as demonstrated in the examples above.
+
+Presently, the Colossal-LLaMA-2-7B-base model already exhibits some capabilities in sentiment analysis, logical reasoning, information extraction, role-play, classification, and rewriting. These capabilities are poised for further improvement as part of our ongoing enhancements.
\ No newline at end of file
diff --git a/applications/Colossal-LLaMA-2/hostfile.example b/applications/Colossal-LLaMA-2/hostfile.example
new file mode 100644
index 0000000000000000000000000000000000000000..82948648cbc9bed454c5672392d62a906c6d372a
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/hostfile.example
@@ -0,0 +1,2 @@
+hostname1
+hostname2
\ No newline at end of file
diff --git a/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py b/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a519232f6e389aa434fde46fcd630df17abb3d65
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/prepare_pretrain_dataset.py
@@ -0,0 +1,153 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Prepare dataset for continual pre-training
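+
+Example invocation (hypothetical paths, for illustration only):
+    python prepare_pretrain_dataset.py \
+        --data_input_dirs /data/corpus_a,/data/corpus_b \
+        --tokenizer_dir /models/Llama-2-7b-hf \
+        --data_cache_dir cache \
+        --data_jsonl_output_dir jsonl_output \
+        --data_arrow_output_dir arrow_output \
+        --max_length 4096 \
+        --num_spliced_dataset_bins 10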
+"""
+
+import argparse
+import json
+import math
+import os
+import time
+from multiprocessing import cpu_count
+
+from datasets import dataset_dict, load_dataset
+from transformers.models.llama.tokenization_llama import LlamaTokenizer
+
+from colossalai.logging import get_dist_logger
+from colossal_llama2.dataset.spliced_and_tokenized_dataset import (
+ supervised_tokenize,
+ ClosedToConstantLengthSplicedDataset,
+)
+
+logger = get_dist_logger()
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--data_input_dirs",
+ type=str,
+ required=True,
+ default=None,
+ help="Comma(i.e., ',') separated list of all data directories containing `.jsonl` data files.",
+ )
+ parser.add_argument(
+ "--tokenizer_dir", type=str, required=True, default=None, help="A directory containing the tokenizer"
+ )
+ parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
+ parser.add_argument(
+ "--data_jsonl_output_dir",
+ type=str,
+ default="jsonl_output",
+ help="Output directory of spliced dataset with jsonl format",
+ )
+ parser.add_argument(
+ "--data_arrow_output_dir",
+ type=str,
+ default="arrow_output",
+ help="Output directory of spliced dataset with arrow format",
+ )
+ parser.add_argument("--max_length", type=int, default=4096, help="Max length of each spliced tokenized sequence")
+ parser.add_argument("--num_spliced_dataset_bins", type=int, default=10, help="Number of spliced dataset bins")
+ args = parser.parse_args()
+
+ if args.num_spliced_dataset_bins >= 100000:
+ raise ValueError("Too many spliced divisions, must be smaller than 100000")
+
+    assert not os.path.exists(args.data_cache_dir), f"Found existing data cache dir {args.data_cache_dir}"
+    assert not os.path.exists(
+        args.data_jsonl_output_dir
+    ), f"Found existing jsonl data output dir {args.data_jsonl_output_dir}"
+    assert not os.path.exists(
+        args.data_arrow_output_dir
+    ), f"Found existing arrow data output dir {args.data_arrow_output_dir}"
+ os.makedirs(args.data_jsonl_output_dir)
+ os.makedirs(args.data_arrow_output_dir)
+
+    # Collect all input data paths.
+ input_data_paths = []
+ input_data_dirs = args.data_input_dirs.split(",")
+ for ds_dir in input_data_dirs:
+ ds_dir = os.path.abspath(ds_dir)
+        assert os.path.exists(ds_dir), f"Cannot find data dir {ds_dir}"
+ ds_files = [name for name in os.listdir(ds_dir) if name.endswith(".jsonl")]
+ ds_paths = [os.path.join(ds_dir, name) for name in ds_files]
+ input_data_paths.extend(ds_paths)
+
+    # Prepare the data splits.
+ train_splits = []
+ split_interval = math.ceil(100 / args.num_spliced_dataset_bins)
+ for i in range(0, 100, split_interval):
+ start = i
+ end = i + split_interval
+ if end > 100:
+ end = 100
+ train_splits.append(f"train[{start}%:{end}%]")
+
+    # Prepare the tokenizer.
+ tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_dir)
+ tokenizer.add_bos_token = False
+ tokenizer.add_eos_token = False
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.unk_token
+
+ list_dataset = load_dataset(
+ path="json",
+ data_files=input_data_paths,
+ cache_dir=os.path.join(args.data_cache_dir, "raw"),
+ keep_in_memory=False,
+ split=train_splits,
+ num_proc=cpu_count(),
+ )
+ for index, dataset in enumerate(list_dataset):
+ assert isinstance(dataset, dataset_dict.Dataset)
+ logger.info(f"Start to process part-{index}/{len(list_dataset)} of all original datasets.")
+ dataset = dataset.map(
+ function=supervised_tokenize,
+ fn_kwargs={"tokenizer": tokenizer, "max_length": args.max_length},
+ keep_in_memory=False,
+ num_proc=min(len(dataset), cpu_count()),
+ )
+ dataset = dataset.remove_columns(column_names=["source", "target", "category"])
+ dataset = dataset.sort(column_names=("seq_category", "seq_length"), reverse=False, keep_in_memory=False)
+ dataset = dataset.remove_columns(column_names=["seq_category", "seq_length"])
+ spliced_dataset = ClosedToConstantLengthSplicedDataset(
+ dataset=dataset, tokenizer=tokenizer, max_length=args.max_length, error_strict=False
+ )
+ # Save each jsonl spliced dataset.
+ output_index = "0" * (5 - len(str(index))) + str(index)
+ output_name = f"part-{output_index}"
+ output_jsonl_path = os.path.join(args.data_jsonl_output_dir, output_name + ".jsonl")
+ st = time.time()
+ with open(file=output_jsonl_path, mode="w", encoding="utf-8") as fp_writer:
+ spliced_count = 0
+ for spliced_data_point in spliced_dataset:
+ if spliced_count % 500 == 0:
+ logger.info(f"processing {spliced_count} spliced data points for {fp_writer.name}")
+ spliced_count += 1
+ fp_writer.write(json.dumps(spliced_data_point, ensure_ascii=False) + "\n")
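+        # Log splicing statistics: `current_size` is the number of packed sequences
+        # produced from `len(spliced_dataset)` original samples; the compression
+        # rate is below 1 when several short samples are packed into one sequence.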
+ logger.info(
+ f"Current file {fp_writer.name}; "
+ f"Data size: {len(spliced_dataset)}; "
+ f"Spliced data size: {spliced_dataset.current_size}; "
+ f"Splicing compression rate: {round(spliced_dataset.current_size / len(spliced_dataset), 6)}; "
+ f"Time cost: {round((time.time() - st) / 60, 6)} minutes."
+ )
+
+ # Save each arrow spliced dataset
+ output_arrow_path = os.path.join(args.data_arrow_output_dir, output_name)
+ logger.info(f"Start to save {output_arrow_path}")
+ spliced_dataset = load_dataset(
+ path="json",
+ data_files=[output_jsonl_path],
+ cache_dir=os.path.join(args.data_cache_dir, "spliced_and_tokenized"),
+ keep_in_memory=False,
+ num_proc=cpu_count(),
+ split="train",
+ )
+ spliced_dataset.save_to_disk(dataset_path=output_arrow_path, num_proc=min(len(spliced_dataset), cpu_count()))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/applications/Colossal-LLaMA-2/requirements.txt b/applications/Colossal-LLaMA-2/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d8afee768c02e49a1f536438cc135f3de073858d
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/requirements.txt
@@ -0,0 +1,15 @@
+torch<2.0.0, >=1.12.1
+packaging==23.1
+colossalai==0.3.2
+autoflake==2.2.1
+black==23.9.1
+transformers
+tensorboard==2.14.0
+six==1.16.0
+datasets
+ninja==1.11.1
+flash-attn>=2.0.0,<=2.0.5
+tqdm
+sentencepiece==0.1.99
+protobuf<=3.20.0
+
diff --git a/applications/Colossal-LLaMA-2/train.example.sh b/applications/Colossal-LLaMA-2/train.example.sh
new file mode 100644
index 0000000000000000000000000000000000000000..276d9ce99d42015a7528f0cfe389e98a2f1b8bf0
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/train.example.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+
+# NCCL IB environment variables
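+# Note: the interface names below (mlx5_*, eth0) are cluster-specific; adjust them for your environment.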
+export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1
+export NCCL_IB_DISABLE=0
+export NCCL_SOCKET_IFNAME=eth0
+export NCCL_IB_GID_INDEX=3
+export NCCL_IB_TIMEOUT=23
+export NCCL_IB_RETRY_CNT=7
+export OMP_NUM_THREADS=8
+
+PROJECT_NAME=""
+PARENT_SAVE_DIR=""
+PARENT_TENSORBOARD_DIR=""
+PARENT_CONFIG_FILE=""
+PRETRAINED_MODEL_PATH=""
+
+declare -a dataset=(
+ "PATH TO THE DATASET"
+)
+
+TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
+FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
+SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
+TENSORBOARD_DIR="${PARENT_TENSORBOARD_DIR}${FULL_PROJECT_NAME}"
+CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
+
+colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.py \
+ --pretrained $PRETRAINED_MODEL_PATH \
+    --dataset "${dataset[@]}" \
+ --plugin "zero2" \
+ --save_interval 400 \
+ --save_dir $SAVE_DIR \
+ --tensorboard_dir $TENSORBOARD_DIR \
+ --config_file $CONFIG_FILE \
+ --num_epochs 1 \
+ --micro_batch_size 8 \
+ --lr 1e-4 \
+ --mixed_precision "bf16" \
+ --grad_clip 1.0 \
+ --weight_decay 0.01 \
+ --warmup_steps 100 \
+ --use_grad_checkpoint \
+    --use_flash_attn
diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..41b4ef031b468c54f16713197569f1999d7dc55e
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/train.py
@@ -0,0 +1,383 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
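+
+Typically launched via `colossalai run`; see train.example.sh in this directory
+for a complete multi-node launch example.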
+"""
+
+import json
+import argparse
+import os
+import resource
+from contextlib import nullcontext
+from tqdm import tqdm
+
+import torch
+import torch.distributed as dist
+from torch.utils.tensorboard import SummaryWriter
+from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import (
+ GeminiPlugin,
+ LowLevelZeroPlugin,
+ HybridParallelPlugin,
+)
+from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.utils import get_current_device
+
+from colossal_llama2.dataset.loader import (
+ load_tokenized_dataset,
+ setup_distributed_dataloader,
+ DataCollatorForSupervisedDataset,
+ StatefulDistributedSampler,
+)
+
+from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
+from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
+from colossal_llama2.utils.froze import freeze_non_embeds_parameters
+
+
+def get_model_numel(model: torch.nn.Module) -> int:
+ return sum(p.numel() for p in model.parameters())
+
+
+def format_numel_str(numel: int) -> str:
+ B = 1024**3
+ M = 1024**2
+ K = 1024
+ if numel >= B:
+ return f"{numel / B:.2f} B"
+ elif numel >= M:
+ return f"{numel / M:.2f} M"
+ elif numel >= K:
+ return f"{numel / K:.2f} K"
+ else:
+ return f"{numel}"
+
+
+def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
+ dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
+ tensor.div_(dist.get_world_size())
+ return tensor
+
+
+def main() -> None:
+ # ==============================
+ # Parse Arguments
+ # ==============================
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--pretrained",
+ type=str,
+ default=None,
+ help="Address of the pre-trained modeling",
+ )
+ parser.add_argument("--dataset", nargs="+", default=[])
+ parser.add_argument(
+ "--plugin",
+ type=str,
+ default="gemini",
+ choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
+ help="Choose which plugin to use",
+ )
+ parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
+ parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
+ parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
+ parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
+ parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
+ parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
+ parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
+ parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
+ parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default="fp16",
+ choices=["fp16", "bf16"],
+ help="Mixed precision",
+ )
+ parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
+ parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
+ parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
+ parser.add_argument(
+ "--use_grad_checkpoint",
+ action="store_true",
+ default=False,
+ help="Use gradient checkpointing",
+ )
+ parser.add_argument(
+ "--use_flash_attn",
+ action="store_true",
+ default=False,
+ help="Use flash-attention",
+ )
+ parser.add_argument(
+ "--freeze_non_embeds_params",
+ action="store_true",
+ default=False,
+ help="Freeze non embeddings parameters",
+ )
+ parser.add_argument("--tp", type=int, default=1)
+ parser.add_argument("--zero", type=int, default=1)
+ args = parser.parse_args()
+
+ with open(args.config_file, "w") as f:
+ json.dump(args.__dict__, f, indent=4)
+
+ # ==============================
+ # Initialize Distributed Training
+ # ==============================
+ colossalai.launch_from_torch({})
+ coordinator = DistCoordinator()
+
+ # ==============================
+ # Initialize Tensorboard
+ # ==============================
+ if coordinator.is_master():
+ os.makedirs(args.tensorboard_dir, exist_ok=True)
+ writer = SummaryWriter(args.tensorboard_dir)
+
+ # ==============================
+ # Initialize Booster
+ # ==============================
+ if args.plugin == "gemini":
+ plugin = GeminiPlugin(
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "gemini_auto":
+ plugin = GeminiPlugin(
+ precision=args.mixed_precision,
+ placement_policy="auto",
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "zero2":
+ plugin = LowLevelZeroPlugin(
+ stage=2,
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "zero2_cpu":
+ plugin = LowLevelZeroPlugin(
+ stage=2,
+ precision=args.mixed_precision,
+ initial_scale=2**16,
+ cpu_offload=True,
+ max_norm=args.grad_clip,
+ )
+ elif args.plugin == "3d":
+ plugin = HybridParallelPlugin(
+ tp_size=args.tp,
+ pp_size=1,
+ zero_stage=args.zero,
+ max_norm=args.grad_clip,
+ precision=args.mixed_precision,
+ )
+ else:
+ raise ValueError(f"Unknown plugin {args.plugin}")
+
+ booster = Booster(plugin=plugin)
+
+ # ======================================================
+ # Initialize Tokenizer, Dataset, Collator and Dataloader
+ # ======================================================
+ tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
+ tokenizer.pad_token = tokenizer.unk_token
+ tokenizer.add_bos_token = False
+ tokenizer.add_eos_token = False
+
+ coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
+ coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}")
+ coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}")
+
+ coordinator.print_on_master(f"Load dataset: {args.dataset}")
+
+ dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
+ dataloader = setup_distributed_dataloader(
+ dataset=dataset,
+ batch_size=args.micro_batch_size,
+ shuffle=True,
+ drop_last=True,
+ collate_fn=data_collator,
+ )
+ coordinator.print_on_master(
+ f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+
+ # ======================================================
+ # Initialize Model, Objective, Optimizer and LR Scheduler
+ # ======================================================
+ init_ctx = (
+ LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
+ )
+ with init_ctx:
+ model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
+ # Freeze part of parameters.
+ if args.freeze_non_embeds_params:
+ freeze_non_embeds_parameters(model=model)
+
+ if args.use_grad_checkpoint:
+ model.gradient_checkpointing_enable()
+ coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
+ if args.use_flash_attn:
+ replace_with_flash_attention(model=model)
+ coordinator.print_on_master(msg="Flash-attention enabled successfully")
+
+ model_numel = get_model_numel(model)
+ coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
+
+ optimizer = HybridAdam(
+ model_params=filter(lambda p: p.requires_grad, model.parameters())
+ if args.freeze_non_embeds_params
+ else model.parameters(),
+ lr=args.lr,
+ betas=(0.9, 0.95),
+ weight_decay=args.weight_decay,
+ adamw_mode=True,
+ )
+
+ lr_scheduler = CosineAnnealingWarmupLR(
+ optimizer=optimizer,
+ total_steps=args.num_epochs * len(dataloader),
+ warmup_steps=args.warmup_steps
+ if args.warmup_steps is not None
+ else int(args.num_epochs * len(dataloader) * 0.025),
+ eta_min=0.1 * args.lr,
+ )
+
+ # Flash attention will be disabled because it does NOT support fp32.
+ default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
+ torch.set_default_dtype(default_dtype)
+ model, optimizer, _, dataloader, lr_scheduler = booster.boost(
+ model=model,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ dataloader=dataloader,
+ )
+
+ torch.set_default_dtype(torch.float)
+
+ if args.load_checkpoint is None:
+ coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}")
+ booster.load_model(model, args.pretrained, strict=False)
+
+ coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+ coordinator.print_on_master(
+ f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
+ )
+
+ start_epoch = 0
+ start_step = 0
+ sampler_start_idx = 0
+ if args.load_checkpoint is not None:
+ if "modeling" in args.load_checkpoint:
+ coordinator.print_on_master(f"Continued pretrain from checkpoint {args.load_checkpoint}")
+ booster.load_model(model, args.load_checkpoint)
+ else:
+ coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}")
+ start_epoch, start_step, sampler_start_idx = load_checkpoint(
+ load_dir=args.load_checkpoint,
+ booster=booster,
+ model=model,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ )
+ coordinator.print_on_master(
+ f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}"
+ )
+ coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
+
+ coordinator.print_on_master(
+ f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ coordinator.print_on_master(
+ f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
+ )
+ coordinator.print_on_master(
+ f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
+ )
+
+ num_steps_per_epoch = len(dataloader)
+ # If resume training, set the sampler start index to the correct value
+ assert isinstance(dataloader.sampler, StatefulDistributedSampler)
+ dataloader.sampler.set_start_index(start_index=sampler_start_idx)
+
+ for epoch in range(start_epoch, args.num_epochs):
+ dataloader.sampler.set_epoch(epoch=epoch)
+ with tqdm(
+ iterable=enumerate(dataloader, start=start_step),
+ desc=f"Epoch {epoch}",
+ disable=not coordinator.is_master(),
+ total=num_steps_per_epoch,
+ initial=start_step,
+ ) as pbar:
+ for step, batch in pbar:
+ batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
+
+ batch_output = model(**batch)
+
+ loss = batch_output.loss
+
+ booster.backward(loss=loss, optimizer=optimizer)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ all_reduce_mean(tensor=loss)
+ pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
+ if coordinator.is_master():
+ global_step = epoch * num_steps_per_epoch + step
+ writer.add_scalar(tag="Loss", scalar_value=loss.item(), global_step=global_step)
+ writer.add_scalar(
+ tag="Learning Rate",
+ scalar_value=lr_scheduler.get_last_lr()[0],
+ global_step=global_step,
+ )
+ # Save modeling.
+
+ if (args.save_interval > 0 and (step + 1) % args.save_interval == 0) or (step + 1) == len(dataloader):
+ coordinator.print_on_master("\nStart saving model checkpoint with running states")
+ save_checkpoint(
+ save_dir=args.save_dir,
+ booster=booster,
+ model=model,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ epoch=epoch,
+ step=step + 1,
+ batch_size=args.micro_batch_size,
+ coordinator=coordinator,
+ )
+ coordinator.print_on_master(
+ f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
+ )
+
+ # Delete CUDA cache.
+ # del batch, batch_labels, batch_output, loss
+ torch.cuda.empty_cache()
+
+        # Subsequent epochs are not resumed mid-epoch, so reset the sampler start index and the start step.
+ dataloader.sampler.set_start_index(start_index=0)
+ start_step = 0
+
+ # Final save.
+ coordinator.print_on_master("Start saving final model checkpoint")
+ booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
+ coordinator.print_on_master(
+ f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}"
+ )
+
+ coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/applications/Colossal-LLaMA-2/version.txt b/applications/Colossal-LLaMA-2/version.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8a9ecc2ea99d607e92feae1656ddbf6fdd82a2c1
--- /dev/null
+++ b/applications/Colossal-LLaMA-2/version.txt
@@ -0,0 +1 @@
+0.0.1
\ No newline at end of file
diff --git a/applications/ColossalEval/README.md b/applications/ColossalEval/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3f645fe7892c14fa5a3ff21cacb722df7baa82c1
--- /dev/null
+++ b/applications/ColossalEval/README.md
@@ -0,0 +1,560 @@
+
+
+
+
+
+
+## Table of Contents
+
+- [Overview](#overview)
+- [Leaderboard](#leaderboard)
+- [Install](#install)
+- [Evaluation Process](#evaluation-process)
+ - [Inference](#inference)
+ - [Dataset Preparation](#dataset-preparation)
+ - [Configuration](#configuration)
+ - [How to Use](#how-to-use)
+ - [Evaluation](#evaluation)
+ - [Dataset Evaluation](#dataset-evaluation)
+ - [Configuration](#dataset-evaluation)
+ - [How to Use](#dataset-evaluation)
+ - [GPT Evaluation](#gpt-evaluation)
+ - [Configuration](#gpt-evaluation)
+ - [How to Use](#gpt-evaluation)
+- [More Details](#more-details)
+ - [Inference Details](#inference-details)
+ - [Evaluation Details](#evaluation-details)
+ - [Metrics](#metrics)
+  - [Examples](#examples)
+ - [Dataset Evaluation Example](#dataset-evaluation-example)
+ - [GPT Evaluation Example](#gpt-evaluation-example)
+- [To Do](#to-do)
+- [FAQ](#faq)
+ - [How to Add a New Metric?](#how-to-add-a-new-metric)
+ - [How to Add a New Dataset?](#how-to-add-a-new-dataset)
+ - [How to Add a New Model?](#how-to-add-a-new-model)
+- [Citations](#citations)
+
+## Overview
+[ColossalEval](https://github.com/hpcaitech/ColossalAI/tree/main/applications/ColossalEval) is a project that provides a unified pipeline for evaluating language models on public datasets or your own datasets, using both classic metrics and GPT-assisted evaluation. More details can be found in the following sections.
+
+## Leaderboard
+
+We conducted a comprehensive evaluation on 4 datasets and compared our Colossal-Llama-2-7b-base model with various other models.
+
+- We use 5-shot for MMLU and calculate scores based on the logits of the first predicted token.
+- We use 5-shot for CMMLU and calculate scores based on the logits of the first predicted token.
+- We use 5-shot for AGIEval and only calculate scores for 4-choice questions, using a combined metric of exact match and the logits of the first predicted token. If either the exact match or the first-token logits gives the correct answer, the model gets the score.
+- We use 0-shot for GAOKAO-Bench and only calculate scores for 4-choice questions based on the logits of the first predicted token.
+- The generation config for all datasets is greedy search.
+- We also provide CEval scores from its latest leaderboard or from the official repository of the model.
+
+More details about metrics can be found in [Metrics](#metrics).
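+
+As a rough illustration of the first-token-logit scoring described above, the following sketch picks the option whose letter receives the highest logit at the first generated position (variable names are hypothetical; the actual implementation lives in `colossal_eval`):
+
+```python
+import torch
+
+def score_single_choice(first_token_logits: torch.Tensor, option_token_ids: dict) -> str:
+    # first_token_logits: logits over the vocabulary at the first generated position.
+    # option_token_ids: mapping such as {"A": id_a, "B": id_b, "C": id_c, "D": id_d}.
+    option_logits = {opt: first_token_logits[tid].item() for opt, tid in option_token_ids.items()}
+    # Predict the option whose token receives the largest logit.
+    return max(option_logits, key=option_logits.get)
+```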
+
+| | Backbone | Tokens Consumed | | MMLU | CMMLU | AGIEval | GAOKAO | CEval |
+| :----------------------------: | :--------: | :-------------: | :------------------: | :-----------: | :-----: | :----: | :----: | :----------------------------: |
+| | - | - | | 5-shot | 5-shot | 5-shot | 0-shot | 5-shot |
+| Baichuan-7B | - | 1.2T | | 42.32 (42.30) | 44.53 (44.02) | 38.72 | 36.74 | 42.80 |
+| Baichuan-13B-Base | - | 1.4T | | 50.51 (51.60) | 55.73 (55.30) | 47.20 | 51.41 | 53.60 |
+| Baichuan2-7B-Base | - | 2.6T | | 46.97 (54.16) | 57.67 (57.07) | 45.76 | 52.60 | 54.00 |
+| Baichuan2-13B-Base | - | 2.6T | | 54.84 (59.17) | 62.62 (61.97) | 52.08 | 58.25 | 58.10 |
+| ChatGLM-6B | - | 1.0T | | 39.67 (40.63) | 41.17 (-) | 40.10 | 36.53 | 38.90 |
+| ChatGLM2-6B | - | 1.4T | | 44.74 (45.46) | 49.40 (-) | 46.36 | 45.49 | 51.70 |
+| InternLM-7B | - | - | | 46.70 (51.00) | 52.00 (-) | 44.77 | 61.64 | 52.80 |
+| InternLM-20B | - | 2.3T | | 60.96 (62.05) | 59.08 (-) | 57.96 | 61.92 | - |
+| Qwen-7B (original) | - | 2.2T | | 54.29 (56.70) | 56.03 (58.80) | 52.47 | 56.42 | 59.60 |
+| Qwen-7B | - | 2.4T | | 58.33 (58.20) | 62.54 (62.20) | 64.34 | 74.05 | 63.50 |
+| | | | | | | | | |
+| Llama-2-7B | - | 2.0T | | 44.47 (45.30) | 32.97 (-) | 32.60 | 25.46 | - |
+| Linly-AI/Chinese-LLaMA-2-7B-hf | Llama-2-7B | 1.0T | | 37.43 | 29.92 | 32.00 | 27.57 | - |
+| wenge-research/yayi-7b-llama2 | Llama-2-7B | - | | 38.56 | 31.52 | 30.99 | 25.95 | - |
+| ziqingyang/chinese-llama-2-7b | Llama-2-7B | - | | 33.86 | 34.69 | 34.52 | 25.18 | 34.2 |
+| TigerResearch/tigerbot-7b-base | Llama-2-7B | 0.3T | | 43.73 | 42.04 | 37.64 | 30.61 | - |
+| LinkSoul/Chinese-Llama-2-7b | Llama-2-7B | - | | 48.41 | 38.31 | 38.45 | 27.72 | - |
+| FlagAlpha/Atom-7B | Llama-2-7B | 0.1T | | 49.96 | 41.10 | 39.83 | 33.00 | - |
+| IDEA-CCNL/Ziya-LLaMA-13B-v1.1 | Llama-13B | 0.11T | | 50.25 | 40.99 | 40.04 | 30.54 | - |
+| | | | | | | | | |
+| **Colossal-LLaMA-2-7b-base** | Llama-2-7B | **0.0085T** | | 53.06 | 49.89 | 51.48 | 58.82 | 50.20 |
+
+> The scores in parentheses correspond to the scores reported in the official repository of each model.
+>
+> We use zero-shot for ChatGLM models.
+>
+> To evaluate Qwen-7B on the MMLU dataset, the prompt is "xxx Answer:" (with no space after ":"), and we calculate the logits over " A", " B", " C" and " D" for Qwen-7B. Both the original and updated versions of Qwen-7B tend to be much more deterministic than other models; for example, the logit over " A" can be `-inf`, in which case its softmax probability is exactly `0`.
+>
+> For other models and other datasets, we calculate logits over "A", "B", "C" and "D".
+
+Our model achieves much better scores than all other Llama-1 or Llama-2 based models and also stands out among popular open-source LLMs.
+
+## Install
+You need to install `ColossalEval` before using it; the installed package is named `colossal_eval`.
+```bash
+git clone https://github.com/hpcaitech/ColossalAI.git
+cd ColossalAI/applications/ColossalEval
+pip install .
+```
+If you want to add customized datasets or models, use `pip install -e .` instead to ensure that any changes you make to the source code will immediately affect the installed package.
+
+## Evaluation Process
+The evaluation process involves two steps: `inference` and `evaluation`. You need to set the config for each step.
+
+### Inference
+
+The inference process consists of two parts.
+1. Preprocess and convert the original dataset.
+2. Config your tokenizer and model arguments to perform zero-shot or few-shot prompting.
+
+#### Dataset Preparation
+
+In this step, the original dataset (either in `csv` or `jsonl` format) will be loaded and converted into a `dict`. During conversion, we carefully parse each subcategory and assign specific inference arguments to it.
+
+Inference arguments are stored in a `dict`. The following is an example.
+
+```python
+inference_kwargs = {
+ "calculate_loss": True,
+ "all_classes": ["A", "B", "C", "D"],
+ "language": "Chinese",
+ "pretrain": False,
+ "max_new_tokens": 32
+}
+```
+The `inference_kwargs` currently contains 5 fields:
+
+- `calculate_loss` (bool, compulsory): Whether to calculate the loss on target tokens.
+- `all_classes` (Optional[list], compulsory): The list of available options if the subcategory consists of single-choice questions, otherwise `None`.
+- `language` (str, compulsory): The language of the subcategory.
+- `pretrain` (bool, compulsory): Whether the dataset is a pretrain dataset. It is usually used for calculating perplexity when you want to evaluate a model with extended context length.
+- `max_new_tokens` (int, compulsory): The number of new tokens to generate during inference.
+
+For example, for dataset MMLU, each subcategory consists of single-choice questions with options A, B, C and D by default, so we can assign the value `["A", "B", "C", "D"]` to the key `all_classes`. For dataset C-Eval, target answers aren't provided in the test split, so `calculate_loss` should be set to `False`. However, other datasets such as GAOKAO-Bench contain different formats of questions and lack keys or metadata that reveal what type of questions (single-choice or multi-choice) a subcategory contains. Before assigning inference arguments, we first parse the dataset to decide which type of questions the subcategory belongs to and set the inference arguments accordingly.
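+
+Below is a minimal sketch of this per-subcategory assignment; `build_inference_kwargs` and the `options` field are hypothetical illustrations, and the actual logic lives in the dataset classes under `colossal_eval/dataset/`.
+
+```python
+from copy import deepcopy
+from typing import Dict, List
+
+
+def build_inference_kwargs(samples: List[Dict], default_kwargs: Dict) -> Dict:
+    # All samples in one subcategory share the same inference arguments,
+    # so inspecting the first sample is enough.
+    kwargs = deepcopy(default_kwargs)
+    options = samples[0].get("options") or []
+    # Single-choice questions expose their options; cloze questions do not.
+    if options:
+        kwargs["all_classes"] = [chr(ord("A") + i) for i in range(len(options))]
+    return kwargs
+```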
+
+Other than `inference_kwargs`, `data` is a list containing questions of the same subcategory. The following is a converted dataset.
+
+```json
+{
+ "dev": {
+ "category 1": {"data": [], "inference_kwargs": {}},
+ "category 2": {"data": [], "inference_kwargs": {}}
+ },
+ "test": {
+ "category 1": {"data": [], "inference_kwargs": {}},
+ "category 2": {"data": [], "inference_kwargs": {}}
+ }
+}
+```
+
+A data sample basically follows the Alpaca format. It should contain the following keys:
+
+* `dataset` (str, compulsory): The name of the dataset.
+* `split` (str, compulsory): The split of the instruction.
+* `category` (str, compulsory): The category of the instruction.
+* `instruction` (str, compulsory): The instruction for the LLM.
+* `input` (str, optional): The additional context of the instruction.
+* `output` (str, optional): The model output of the instruction.
+* `target` (str, optional): The target answer for the instruction.
+
+Example:
+
+```json
+{
+ "dev": {
+ "Abstract Algebra": [
+ {
+ "dataset": "mmlu",
+ "split": "dev",
+ "category": "Abstract Algebra",
+ "instruction": "The following is a single-choice question on Abstract Algebra. Answer the question by replying A, B, C or D.",
+ "input": "Question: Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.\nA. 0\nB. 1\nC. 2\nD. 3\nAnswer: ",
+ "output": "",
+ "target": "B"
+            }
+ ]
+ },
+ "test": {
+ "Abstract Algebra": [
+ {
+ "dataset": "mmlu",
+ "split": "test",
+ "category": "Abstract Algebra",
+ "instruction": "The following is a single-choice question on Abstract Algebra. Answer the question by replying A, B, C or D.",
+ "input": "Question: Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q.\nA. 0\nB. 4\nC. 2\nD. 6\nAnswer: ",
+ "output": "",
+ "target": "B"
+            }
+ ]
+ }
+}
+```
+
+#### Configuration
+In this step, you will configure your tokenizer and model arguments to infer on the given datasets.
+
+A config file consists of two parts.
+1. Model config. In model config, you need to specify model name, model path, model class, tokenizer arguments and model arguments. For model class, we currently support `HuggingFaceModel`, `HuggingFaceCausalLM`, `ChatGLMModel` and `ChatGLMModel2`. `HuggingFaceModel` is for models that can be loaded with `AutoModel` and `HuggingFaceCausalLM` is for models that can be loaded with `AutoModelForCausalLM`. `ChatGLMModel` and `ChatGLMModel2` are for ChatGLM and ChatGLM2 models respectively. You can check all model classes in `colossal_eval/models/__init__.py`. If your model requires `trust_remote_code`, set it in the `tokenizer_kwargs` and `model_kwargs` fields.
+2. Dataset config. In dataset config, you need to specify dataset name, path and dataset class. Currently, we support zero-shot on dataset MMLU, CMMLU, AGIEval, GAOKAO-Bench and LongBench, and few-shot on dataset MMLU, CMMLU and AGIEval. If you want to enable few-shot prompting, set `few_shot` to true. You can check all dataset classes in `colossal_eval/dataset/__init__.py`.
+
+Once you have all configs ready, the program will run inference for all the given models on all the given datasets.
+
+An example config using model class `HuggingFaceCausalLM` and dataset class `CMMLUDataset` can be:
+```json
+{
+ "model": [
+ {
+ "name": "model name",
+ "model_class": "HuggingFaceCausalLM",
+ "parameters": {
+ "path": "path to model",
+ "model_max_length": 2048,
+ "tokenizer_path": "path to tokenizer",
+ "tokenizer_kwargs": {
+ "use_fast": false,
+ "trust_remote_code": true
+ },
+ "peft_path": null,
+ "model_kwargs": {
+ "trust_remote_code": true
+ },
+ "prompt_template": "plain",
+ "batch_size": 4
+ }
+ }
+ ],
+ "dataset": [
+ {
+ "name": "dataset name",
+ "dataset_class": "CMMLUDataset",
+ "debug": false,
+ "few_shot": true,
+ "path": "path to original dataset",
+ "save_path": "path to save converted dataset"
+ }
+ ]
+}
+```
+
+Currently, we support Hugging Face models. `tokenizer_kwargs` contains the arguments passed to `AutoTokenizer.from_pretrained()`. `model_kwargs` contains the arguments passed to `AutoModel.from_pretrained()` or `AutoModelForCausalLM.from_pretrained()`. Set `few_shot` to true if you want to enable few-shot prompting for the dataset. Set `debug` to true if you want to inspect and verify the constructed prompts.
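+
+As a rough illustration (not the package's internal code), the `tokenizer_kwargs` and `model_kwargs` in the example config above map onto the standard Hugging Face loading calls like this:
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# Values taken from the example config above; the paths are placeholders.
+tokenizer = AutoTokenizer.from_pretrained(
+    "path to tokenizer",
+    use_fast=False,            # from tokenizer_kwargs
+    trust_remote_code=True,    # from tokenizer_kwargs
+)
+model = AutoModelForCausalLM.from_pretrained(
+    "path to model",
+    trust_remote_code=True,    # from model_kwargs
+)
+```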
+
+#### How to Use
+An example script is shown below. The `configs/dataset_evaluation/inference.py` script is the same across all provided examples.
+
+```shell
+torchrun --nproc_per_node=1 inference.py \
+ --config "path to config file" \
+ --load_dataset \
+ --inference_save_path "path to save inference results"
+```
+
+You should specify the path to the config file in `config`. If you have already saved the converted dataset, you can run the script without `load_dataset`; otherwise set it to load the original dataset and save the converted dataset first. You should specify the path to save inference results in `inference_save_path`.
+
+### Evaluation
+
+In the evaluation process, you only need to configure your evaluation parameters. You can use either public datasets or GPT models for evaluation. We will introduce the configuration for both dataset evaluation and GPT evaluation.
+
+#### Dataset Evaluation
+
+In dataset evaluation, we calculate different metrics using the given inference results and the reference answers from the public dataset.
+
+##### Configuration
+
+A config file for dataset evaluation consists of two parts.
+1. Model config. In model config, you need to specify the model name. If you want to evaluate perplexity over a pretrain dataset and calculate per-byte perplexity, you also have to add your tokenizer config and model max length.
+2. Dataset config. In dataset config, you need to specify the evaluation metrics for the dataset.
+
+Once you have all configs ready, the program will run evaluation on the inference results for all the given models and datasets.
+
+An example config can be:
+```json
+{
+ "model": [
+ {
+ "name": "model name"
+ }
+ ],
+ "dataset": [
+ {
+ "name": "dataset name",
+ "metrics": ["first_token_accuracy"]
+ }
+ ]
+}
+```
+
+The above config specifies that the program will evaluate the inference results using the `first_token_accuracy` metric.
+
+##### How to Use
+
+An example script is shown below.
+
+```shell
+python eval_dataset.py \
+ --config "path to config file" \
+ --inference_results_path "path to inference results" \
+ --evaluation_results_save_path "path to save evaluation results"
+```
+
+You should specify the path to the config file in `config`, the path to inference results in `inference_results_path` and the path to save evaluation results in `evaluation_results_save_path`.
+
+#### GPT Evaluation
+
+In GPT evaluation, we provide a prompt template that fits different pre-defined metrics with Chain-of-Thought prompting. In the following sections, we will only introduce how you can evaluate model answers using GPTs. More details can be found in `colossal_eval/evaluate/GPT Evaluation.md`.
+
+##### Configuration
+
+The following is an example of an English config file. The configuration file controls how the pipeline evaluates the model; you need to specify the GPT evaluation metrics for each category. An example English config file can be found in `configs/gpt_evaluation`.
+
+```json
+{
+ "language": "en",
+ "category": {
+ "brainstorming": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "creativity",
+ "practicality",
+ "reasonableness"
+ ]
+        }
+ }
+}
+```
+
+##### How to Use
+After setting the config file, you can evaluate the model using `examples/gpt_evaluation/eval.py`. If you want to make comparisons between the answers of two different models, you should specify two answer files in the argument `answer_file_list` and two model names in the argument `model_name_list` (details can be found in `colossal_eval/evaluate/GPT Evaluation.md`). If you want to evaluate one answer file, the length of both `answer_file_list` and `model_name_list` should be 1, and the program will perform evaluation using GPTs. The prompt files for battle and GPT evaluation can be found in `configs/gpt_evaluation/prompt`. `target_file` is the path to the converted dataset you saved during inference.
+
+An example script is provided as follows:
+
+```shell
+python eval.py \
+ --config_file "path to the config file" \
+ --battle_prompt_file "path to the prompt file for battle" \
+ --gpt_evaluation_prompt_file "path to the prompt file for gpt evaluation" \
+ --target_file "path to the target answer file" \
+ --answer_file_list "path to the answer file" \
+ --model_name_list "the names of the model" \
+ --gpt_model "which GPT model to use for evaluation" \
+ --save_path "path to save results" \
+    --openai_key "your openai key"
+```
+
+## More Details
+
+### Inference
+
+In the inference process, depending on the inference arguments, we perform generation, calculate the loss over target tokens, count the number of target tokens, and compute the softmax over the given options (for example, "A", "B", "C", and "D").
+
+For tokenization, we adopt the tokenization strategy from [LongBench](https://github.com/THUDM/LongBench/blob/main/pred.py#L55) to preserve the crucial instructions on the left and right sides and keep all target tokens.
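+
+A minimal sketch of that strategy, assuming a Hugging Face tokenizer (see the linked LongBench code for the reference implementation): when a prompt exceeds the token budget, keep its head and tail and drop the middle.
+
+```python
+def truncate_middle(prompt: str, tokenizer, max_length: int) -> str:
+    # Keep the head and the tail of the prompt (where the instructions live)
+    # and drop tokens from the middle if the prompt is too long.
+    input_ids = tokenizer(prompt, truncation=False).input_ids
+    if len(input_ids) <= max_length:
+        return prompt
+    half = max_length // 2
+    return tokenizer.decode(input_ids[:half], skip_special_tokens=True) + tokenizer.decode(
+        input_ids[-half:], skip_special_tokens=True
+    )
+```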
+
+For labeling target tokens, we adopt the method from [FastChat](https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L137), but it doesn't always hold true because tokenizers behave differently. We plan to insert special tokens to correctly label the target tokens.
+
+For calculating loss, we return per-sample loss rather than the per-batch loss that directly using `model(batch).loss` from Hugging Face would yield.
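+
+The following sketch shows one standard way to obtain per-sample loss (assumed and simplified; not necessarily the package's exact code):
+
+```python
+import torch
+
+
+def per_sample_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+    # logits: [batch, seq_len, vocab]; labels: [batch, seq_len] with -100 on non-target tokens.
+    shift_logits = logits[:, :-1, :].contiguous()
+    shift_labels = labels[:, 1:].contiguous()
+    loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=-100)
+    token_loss = loss_fct(
+        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
+    ).view(shift_labels.size())
+    # Average over target tokens per sample instead of over the whole batch.
+    n_target = (shift_labels != -100).sum(dim=-1).clamp(min=1)
+    return token_loss.sum(dim=-1) / n_target
+```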
+
+### Evaluation
+
+To make it easier to set the config, you only need to specify all the metrics you want to use under the key `metrics`. However, the program will only use a subset of the given metrics for each subcategory, as applying all metrics to all subcategories is obviously unsuitable. The suggested metrics for specific categories are defined in `colossal_eval/evaluate/dataset_evaluator/metrics.py`.
+
+#### Metrics
+
+- `combined_single_choice_accuracy`: A combination of `first_token_logit` and `single_choice_accuracy`. If either of them is correct, the model gets the score. It can be used in all datasets that contain single-choice questions.
+- `first_token_logit`: Calculate the score based on the softmax score over the given choices. If the argmax of the softmax equals the reference, the model gets the score. If there is `NaN` in the softmax score, the score is calculated using exact match instead. It can be used in all datasets that contain single-choice questions.
+- `single_choice_accuracy`: Calculate the score using exact match. It only extracts the first uppercase letter such as A, B, C or D that is not surrounded by lowercase letters. If that letter equals the reference, the model gets the score. It can be used in all datasets that contain single-choice questions.
+- `multi_choice_accuracy`: Calculate the score on multi-choice questions. It extracts the set of all uppercase letters such as A, B, C or D that are not surrounded by lowercase letters. If the prediction contains uppercase letters that are not in the reference, the model gets a score of 0. For each uppercase letter in the prediction that is in the reference, the model gets a score of `1/len(reference)`. It is used in AGIEval and GAOKAO-Bench.
+- `math_equivalence`: Code from [hendrycks](https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py). Compute the score by checking whether the predicted math formula is equivalent to the reference math formula. It is used in AGIEval and GAOKAO-Bench.
+- `f1_score`: Calculate the English F1 score between prediction and reference. It is used in LongBench.
+- `f1_zh_score`: Calculate the Chinese F1 score between prediction and reference. It is used in LongBench.
+- `rouge_score`: Calculate the English ROUGE score between prediction and reference. It is used in GAOKAO-Bench and LongBench.
+- `rouge_zh_score`: Calculate the Chinese ROUGE score between prediction and reference. It is used in GAOKAO-Bench and LongBench.
+- `retrieval_score`: Calculate the English retrieval score between prediction and reference. It determines whether the output (which paragraph) corresponds to the given abstract. It is used in LongBench.
+- `retrieval_zh_score`: Calculate the Chinese retrieval score between prediction and reference. It determines whether the output (which paragraph) corresponds to the given abstract. It is used in LongBench.
+- `classification_score`: Calculate the classification score between prediction and reference. It determines whether the output (a class) equals the reference. It is used in LongBench.
+- `code_sim_score`: Calculate the similarity score between prediction and reference. It is used in LongBench.
+- `count_score`: Calculate the count score between prediction and reference. It determines whether the output (the number of given passages) equals the reference. It is used in LongBench.
+- `perplexity`: Calculate perplexity. The formula is $ perplexity = \frac{1}{n} \sum_i e^{loss_i} $ where $n$ is the number of samples and $ loss_i $ is the average loss for sample $ i $. It can be used in all datasets.
+- `ppl_score`: Calculate the perplexity score. The formula is $ ppl\_score = \frac{1}{n} \sum_i e^{-loss_i} $ where $n$ is the number of samples and $ loss_i $ is the average loss for sample $ i $. It can be used in all datasets.
+- `ppl_score_over_choices`: Calculate the perplexity score over choices. The formula is $ ppl\_score\_over\_choices = \frac{1}{n} \sum_i e^{-loss\_over\_choices_i} $ where $n$ is the number of samples and $ loss\_over\_choices_i $ is the loss on the first predicted token for sample $ i $. It can be used in all datasets that contain single-choice questions.
+- `per_byte_perplexity`: Calculate the per-byte perplexity. The formula is $ \frac{1}{n} \sum_i e^{\frac{loss_i}{byte_i}} $ where $n$ is the number of samples, $ loss_i $ is the total loss for sample $ i $ and $ byte_i $ is the number of bytes sample $ i $ occupies. It can be used in all datasets.
+- `per_byte_ppl_score`: Calculate the per-byte perplexity score. The formula is $ \frac{1}{n} \sum_i e^{-\frac{loss_i}{byte_i}} $ where $n$ is the number of samples, $ loss_i $ is the total loss for sample $ i $ and $ byte_i $ is the number of bytes sample $ i $ occupies. It can be used in all datasets.
+
+We use `combined_single_choice_accuracy` and `first_token_logit` in the leaderboard.
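+
+The perplexity-style formulas above can be transcribed directly into code; the following sketch is illustrative and not the package's implementation:
+
+```python
+import math
+from typing import List
+
+
+def perplexity(losses: List[float]) -> float:
+    # losses[i] is the average loss of sample i.
+    return sum(math.exp(loss) for loss in losses) / len(losses)
+
+
+def ppl_score(losses: List[float]) -> float:
+    return sum(math.exp(-loss) for loss in losses) / len(losses)
+
+
+def per_byte_perplexity(losses: List[float], n_bytes: List[int]) -> float:
+    # losses[i] is the total loss of sample i; n_bytes[i] is its size in bytes.
+    return sum(math.exp(loss / b) for loss, b in zip(losses, n_bytes)) / len(losses)
+```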
+
+### Examples
+
+We provide 2 examples for you to explore our `colossal_eval` package.
+
+#### Dataset Evaluation Example
+
+This example is in folder `examples/dataset_evaluation`.
+
+1. `cd examples/dataset_evaluation`
+2. Fill in your inference config file in `config/inference/config.json`. Set the model and dataset parameters.
+3. Run `inference.sh` to get inference results.
+4. Fill in your evaluation config file in `config/evaluation/config.json`. Set the model and dataset parameters.
+5. Run `eval_dataset.sh` to get evaluation results.
+
+#### GPT Evaluation Example
+
+This example is in folder `examples/gpt_evaluation`.
+
+1. `cd examples/gpt_evaluation`
+2. Fill in your inference config file in `config/inference/config.json`. Set the model and dataset parameters. If you want to use the example dataset we provide, the dataset is `ColossalDataset`.
+3. Run `inference.sh` to get inference results.
+4. Fill in your evaluation config file in `config/evaluation/config.json`.
+5. Run `eval.sh` to get evaluation results.
+
+## FAQ
+
+### How to Add a New Metric?
+
+If you want to add a customized metric, we recommend using `pip install -e .` to ensure that any changes you make to the source code will immediately affect the installed package.
+
+To add a new metric, you can follow the example of `multi_choice_accuracy` at line 339 in `colossal_eval/evaluate/dataset_evaluator/metric.py`. The method takes one data sample's prediction and reference as input and returns a score between 0 and 1.
+
+A skeleton of code is the following.
+
+```python
+
+def CustomizedMetric(prediction: str, reference: str):
+ score = xxx
+ return score
+```
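+
+For instance, a trivial exact-match metric could look like this (an illustrative sketch; this metric does not ship with the package):
+
+```python
+def exact_match(prediction: str, reference: str) -> float:
+    # Score 1 if the stripped prediction equals the reference, else 0.
+    return 1.0 if prediction.strip() == reference.strip() else 0.0
+```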
+
+Once you have successfully added your own metric, you should specify it both in `colossal_eval/evaluate/dataset_evaluator/metric.py` (to suggest which subcategories the metric should be applied to) and in your evaluation config.
+
+### How to Add a New Dataset?
+
+If you want to add a customized dataset, we recommend using `pip install -e .` to ensure that any changes you make to the source code will immediately affect the installed package.
+
+To add a new dataset, you can follow the example of `colossal_eval/dataset/mmlu.py`. You need to make sure that the format of the questions in one subcategory is consistent. For example, all questions should have target answers, or all questions should be single-choice questions.
+
+A skeleton of code is the following.
+
+```python
+
+class CustomizedDataset(BaseDataset):
+ @staticmethod
+ def load():
+ # 1. Load and convert the original dataset format.
+ # 2. Assign inference arguments for each subcategory.
+ # 3. Return the converted dataset.
+ pass
+```
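+
+A slightly more concrete sketch, assuming a hypothetical `jsonl` file whose lines contain `question` and `answer` fields (the category name `my_category` is also a placeholder), reusing the converted-dataset layout described above:
+
+```python
+from copy import deepcopy
+
+from colossal_eval.dataset import BaseDataset
+from colossal_eval.utils import get_json_list
+
+default_inference_kwargs = {
+    "calculate_loss": True,
+    "all_classes": None,
+    "language": "English",
+    "pretrain": False,
+    "max_new_tokens": 32,
+}
+
+
+class CustomizedDataset(BaseDataset):
+    @staticmethod
+    def load(path, logger, few_shot):
+        subcategory = {"data": [], "inference_kwargs": deepcopy(default_inference_kwargs)}
+        for line in get_json_list(path):
+            subcategory["data"].append(
+                {
+                    "dataset": "customized",
+                    "split": "test",
+                    "category": "my_category",
+                    "instruction": line["question"],
+                    "input": "",
+                    "output": "",
+                    "target": line["answer"],
+                }
+            )
+        return {"test": {"my_category": subcategory}}
+```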
+
+Once you have successfully added your own dataset, you can specify your dataset class in your inference config.
+
+### How to Add a New Model?
+
+If you want to add customized models, we recommend using `pip install -e .` to ensure that any changes you make to the source code will immediately affect the installed package.
+
+To add a new model, you can follow the example of `colossal_eval/models/huggingface.py`. You need to provide a way to load the model and tokenizer, calculate loss, and generate.
+
+A skeleton of code is the following.
+
+```python
+
+class CustomizedModel(BaseModel):
+    def __init__(self):
+        super().__init__()
+        self._load_tokenizer()
+        self._load_model()
+
+    def _load_tokenizer(self):
+        pass
+
+    def _load_model(self):
+        pass
+
+    def _calculate_loss(self, batch_samples):
+        pass
+
+    def get_loss(self, batch_samples):
+        return self._calculate_loss(batch_samples)
+
+    def inference(self, samples):
+        # 1. Load samples from the same subcategory.
+        # 2. Infer in a batched way according to the inference arguments.
+        # 3. Return results.
+        batch_samples = xxx
+        self.get_loss(batch_samples)
+        inference_results = self.generate(batch_samples)
+
+        return inference_results
+
+    def generate(self, batch_samples):
+        pass
+```
+
+Once you have successfully added your own model, you can specify your model class in your inference config.
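+
+For example, the model entry in your inference config would then reference the new class (placeholder values, mirroring the `HuggingFaceCausalLM` example above; which keys appear under `parameters` depends on your class):
+
+```json
+{
+    "model": [
+        {
+            "name": "my customized model",
+            "model_class": "CustomizedModel",
+            "parameters": {
+                "path": "path to model",
+                "batch_size": 4
+            }
+        }
+    ]
+}
+```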
+
+## To Do
+
+- [ ] Add visualization code for evaluation results on public dataset
+- [ ] Improve the way to label target tokens
+
+## Citations
+
+```bibtex
+@misc{zhong2023agieval,
+ title={AGIEval: A Human-Centric Benchmark for Evaluating Foundation Models},
+ author={Wanjun Zhong and Ruixiang Cui and Yiduo Guo and Yaobo Liang and Shuai Lu and Yanlin Wang and Amin Saied and Weizhu Chen and Nan Duan},
+ year={2023},
+ eprint={2304.06364},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL}
+}
+
+@article{huang2023ceval,
+    title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models},
+    author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and Zhang, Junlei and Zhang, Jinghan and Su, Tangjun and Liu, Junteng and Lv, Chuancheng and Zhang, Yikai and Lei, Jiayi and Fu, Yao and Sun, Maosong and He, Junxian},
+    journal={arXiv preprint arXiv:2305.08322},
+    year={2023}
+}
+
+@misc{li2023cmmlu,
+ title={CMMLU: Measuring massive multitask language understanding in Chinese},
+ author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and Hai Zhao and Yeyun Gong and Nan Duan and Timothy Baldwin},
+ year={2023},
+ eprint={2306.09212},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL}
+}
+
+@inproceedings{Zhang2023EvaluatingTP,
+ title={Evaluating the Performance of Large Language Models on GAOKAO Benchmark},
+ author={Xiaotian Zhang and Chunyang Li and Yi Zong and Zhengyu Ying and Liang He and Xipeng Qiu},
+ year={2023}
+}
+
+@misc{bai2023longbench,
+ title={LongBench: A Bilingual, Multitask Benchmark for Long Context Understanding},
+ author={Yushi Bai and Xin Lv and Jiajie Zhang and Hongchang Lyu and Jiankai Tang and Zhidian Huang and Zhengxiao Du and Xiao Liu and Aohan Zeng and Lei Hou and Yuxiao Dong and Jie Tang and Juanzi Li},
+ year={2023},
+ eprint={2308.14508},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL}
+}
+
+@article{hendryckstest2021,
+ title={Measuring Massive Multitask Language Understanding},
+ author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt},
+ journal={Proceedings of the International Conference on Learning Representations (ICLR)},
+ year={2021}
+}
+
+@article{hendrycks2021ethics,
+ title={Aligning AI With Shared Human Values},
+ author={Dan Hendrycks and Collin Burns and Steven Basart and Andrew Critch and Jerry Li and Dawn Song and Jacob Steinhardt},
+ journal={Proceedings of the International Conference on Learning Representations (ICLR)},
+ year={2021}
+}
+
+@misc{zheng2023judging,
+ title={Judging LLM-as-a-judge with MT-Bench and Chatbot Arena},
+ author={Lianmin Zheng and Wei-Lin Chiang and Ying Sheng and Siyuan Zhuang and Zhanghao Wu and Yonghao Zhuang and Zi Lin and Zhuohan Li and Dacheng Li and Eric. P Xing and Hao Zhang and Joseph E. Gonzalez and Ion Stoica},
+ year={2023},
+ eprint={2306.05685},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL}
+}
+
+```
diff --git a/applications/Chat/coati/ray/src/__init__.py b/applications/ColossalEval/colossal_eval/__init__.py
similarity index 100%
rename from applications/Chat/coati/ray/src/__init__.py
rename to applications/ColossalEval/colossal_eval/__init__.py
diff --git a/applications/ColossalEval/colossal_eval/dataset/__init__.py b/applications/ColossalEval/colossal_eval/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ea173198f5a84d053a835c885403e19d42e70a5
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/dataset/__init__.py
@@ -0,0 +1,19 @@
+from .agieval import AGIEvalDataset
+from .base import BaseDataset
+from .ceval import CEvalDataset
+from .cmmlu import CMMLUDataset
+from .colossalai import ColossalDataset
+from .gaokaobench import GaoKaoBenchDataset
+from .longbench import LongBenchDataset
+from .mmlu import MMLUDataset
+
+__all__ = [
+ "AGIEvalDataset",
+ "BaseDataset",
+ "CEvalDataset",
+ "CMMLUDataset",
+ "GaoKaoBenchDataset",
+ "LongBenchDataset",
+ "MMLUDataset",
+ "ColossalDataset",
+]
diff --git a/applications/ColossalEval/colossal_eval/dataset/agieval.py b/applications/ColossalEval/colossal_eval/dataset/agieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..92ebd65931edf829a35d0535abb52f2088f34108
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/dataset/agieval.py
@@ -0,0 +1,247 @@
+# Adapted from https://github.com/ruixiangcui/AGIEval/blob/main/src/dataset_loader.py.
+
+import ast
+import glob
+import os
+from copy import deepcopy
+from typing import Dict, List
+
+import pandas as pd
+from colossal_eval.utils import get_json_list
+
+from colossalai.logging import DistributedLogger
+
+from .base import BaseDataset
+
+# define the datasets
+english_qa_datasets = [
+ "lsat-ar",
+ "lsat-lr",
+ "lsat-rc",
+ "logiqa-en",
+ "sat-math",
+ "sat-en",
+ "aqua-rat",
+ "sat-en-without-passage",
+ "gaokao-english",
+]
+chinese_qa_datasets = [
+ "logiqa-zh",
+ "jec-qa-kd",
+ "jec-qa-ca",
+ "gaokao-chinese",
+ "gaokao-geography",
+ "gaokao-history",
+ "gaokao-biology",
+ "gaokao-chemistry",
+ "gaokao-physics",
+ "gaokao-mathqa",
+]
+english_cloze_datasets = ["math"]
+chinese_cloze_datasets = ["gaokao-mathcloze"]
+
+multi_choice_datasets = ["jec-qa-kd", "jec-qa-ca", "gaokao-physics", "gaokao-mathqa"]
+math_output_datasets = {"gaokao-mathcloze", "math"}
+
+default_inference_kwargs = {
+ "calculate_loss": True,
+ "all_classes": None,
+ "language": "Chinese",
+ "pretrain": False,
+ "max_new_tokens": 32,
+}
+
+
+def get_prompt(line: Dict, dataset_name: str, logger: DistributedLogger) -> Dict:
+ """Modified from https://github.com/microsoft/AGIEval/blob/main/src/dataset_loader.py#L190"""
+ try:
+ all_classes = None
+ passage = line["passage"] if line["passage"] is not None else ""
+
+ if dataset_name in english_qa_datasets:
+ option_string = "ABCDEFG"
+ count = len(line["options"])
+
+ input = (
+ "Question: "
+ + line["question"]
+ + " "
+ + "Choose from the following options: "
+ + " ".join(line["options"])
+ + "\n"
+ + "Answer: "
+ )
+
+ all_classes = list(option_string[0:count])
+
+ elif dataset_name in chinese_qa_datasets:
+ option_string = "ABCDEFG"
+ count = len(line["options"])
+
+ input = "问题:" + line["question"] + " " + "从以下选项中选择:" + " ".join(line["options"]) + "\n" + "答案:"
+
+ all_classes = list(option_string[0:count])
+
+ elif dataset_name in english_cloze_datasets:
+ input = "Question: " + line["question"] + "\n" + "Answer: "
+
+ elif dataset_name in chinese_cloze_datasets:
+ input = "问题:" + line["question"] + "\n" + "答案:"
+
+ return {
+ "instruction": input if not passage else passage + "\n\n" + input,
+ "target": line["label"] if line["label"] else line["answer"],
+ }, all_classes
+
+ except NameError:
+ logger.info("Dataset not defined.")
+
+
+# process few-shot raw_prompts
+def combine_prompt(prompt_path, dataset_name, load_explanation=True, chat_mode=False):
+ skip_passage = False
+ if dataset_name == "sat-en-without-passage":
+ skip_passage = True
+ dataset_name = "sat-en"
+    demonstrations = []
+ # read the prompts by context and explanation
+ context_row = [0, 1, 3, 5, 7, 9]
+ explanation_row = [0, 2, 4, 6, 8, 10]
+ raw_prompts_context = pd.read_csv(
+ prompt_path, header=0, skiprows=lambda x: x not in context_row, keep_default_na=False
+ )
+ raw_prompts_explanation = pd.read_csv(
+ prompt_path, header=0, skiprows=lambda x: x not in explanation_row, keep_default_na=False
+ ).replace(r"\n\n", "\n", regex=True)
+ contexts = []
+ for line in list(raw_prompts_context[dataset_name]):
+ if line:
+ # print(line)
+ contexts.append(ast.literal_eval(line))
+ explanations = [exp for exp in raw_prompts_explanation[dataset_name] if exp]
+
+ for idx, (con, exp) in enumerate(zip(contexts, explanations)):
+ passage = con["passage"] if con["passage"] is not None and not skip_passage else ""
+ question = con["question"]
+ options = con["options"] if con["options"] is not None else ""
+ label = con["label"] if con["label"] is not None else ""
+ answer = con["answer"] if "answer" in con and con["answer"] is not None else ""
+
+ if dataset_name in english_qa_datasets:
+ question_input = (
+ "Question: "
+ + passage
+ + " "
+ + question
+ + "\n"
+ + "Choose from the following options: "
+ + " ".join(options)
+ + "\n"
+ + "Answer: {}".format(label)
+ )
+ elif dataset_name in chinese_qa_datasets:
+ question_input = (
+ "问题:" + passage + " " + question + "\n" + "从以下选项中选择:" + " ".join(options) + "\n" + "答案:{}".format(label)
+ )
+ elif dataset_name in english_cloze_datasets:
+ question_input = "Question: ".format(idx + 1) + question + "\n" + "Answer: {}".format(answer)
+ elif dataset_name in chinese_cloze_datasets:
+ question_input = "问题:" + question + "\n" + "答案:{}".format(answer)
+ else:
+            raise ValueError(f"While loading few-shot examples, found unknown dataset: {dataset_name}")
+
+        if chat_mode:
+            demonstrations.append((question_input,))
+        else:
+            demonstrations.append(question_input + "\n")
+
+    return demonstrations
+
+
+class AGIEvalDataset(BaseDataset):
+ """
+ Dataset wrapper for AGIEval dataset.
+ Data source: https://github.com/microsoft/AGIEval
+ This dataset class will convert the original dataset into the inference dataset.
+
+    A few dirty data samples needed to be manually corrected in the original dataset:
+ Issue link: https://github.com/microsoft/AGIEval/issues/16
+ 1. Invalid options in line 190 in gaokao-chemistry.jsonl.
+ 2. Option D (They may increase in value as those same resources become rare on Earth.) missing in line 17 in sat-en-without-passage.jsonl.
+ 3. Option D (They may increase in value as those same resources become rare on Earth.) missing in line 17 in sat-en.jsonl.
+ 4. Option D (No, because the data do not indicate whether the honeybees had been infected with mites.) missing in line 57 in sat-en-without-passage.jsonl.
+ 5. Option D (No, because the data do not indicate whether the honeybees had been infected with mites.) missing in line 57 in sat-en.jsonl.
+ 6. Option D (Published theories of scientists who developed earlier models of the Venus flytrap) missing in line 98 in sat-en-without-passage.jsonl.
+ 7. Option D (Published theories of scientists who developed earlier models of the Venus flytrap) missing in line 98 in sat-en.jsonl.
+ 8. Label is empty in line 212 in jec-qa-kd.jsonl. Content is also dirty.
+    9. Actually, gaokao-mathqa.jsonl is also a multi-choice dataset. See lines 149, 286 and 287.
+ """
+
+ @staticmethod
+ def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+ dataset = {"test": {}}
+
+ files = glob.glob(os.path.join(path, "*.jsonl"))
+ files.sort()
+
+ if few_shot:
+ prompt_path = os.path.join(path, "few_shot_prompts.csv")
+
+ for file in files:
+ dataset_name = os.path.basename(file)[0 : -len(".jsonl")]
+
+ few_shot_data = []
+ if few_shot:
+ # process demo once if it is few-shot-CoT
+ few_shot_data = combine_prompt(prompt_path, dataset_name, load_explanation=False, chat_mode=False)
+
+ dataset["test"][dataset_name] = {"data": []}
+
+ file_dir = os.path.join(path, file)
+
+ loaded_jsonl = get_json_list(file_dir)
+
+            # It's been tested that all data samples in one subcategory have the same inference arguments.
+ _, all_classes = get_prompt(loaded_jsonl[0], dataset_name, logger)
+ inference_kwargs = deepcopy(default_inference_kwargs)
+ if all_classes is not None and dataset_name not in multi_choice_datasets:
+ inference_kwargs["all_classes"] = all_classes
+
+ if dataset_name in english_qa_datasets:
+ inference_kwargs["language"] = "English"
+ if dataset_name in chinese_qa_datasets:
+ inference_kwargs["language"] = "Chinese"
+ inference_kwargs["few_shot_data"] = few_shot_data
+
+ dataset["test"][dataset_name]["inference_kwargs"] = inference_kwargs
+
+ for line in loaded_jsonl:
+ info, all_classes = get_prompt(line, dataset_name, logger)
+
+ # Convert multi-choice answers to a single string.
+ # We will convert it back when evaluating.
+ # We do this because if target is a list, it should be only used for multiple target answers.
+ if dataset_name in multi_choice_datasets:
+ if isinstance(info["target"], str) and len(info["target"]) > 1:
+ # "gaokao-mathqa" actually contain multi-choice questions.
+ # This if clause is specially used for it.
+ info["target"] = "".join(info["target"].split())
+ else:
+ info["target"] = "".join(info["target"])
+
+ if isinstance(info["target"], list) and len(info["target"]) == 1:
+ info["target"] = info["target"][0]
+
+ data_sample = {
+ "dataset": "agieval",
+ "split": "test",
+ "category": dataset_name,
+ "instruction": info["instruction"],
+ "input": "",
+ "output": "",
+ "target": info["target"],
+ }
+
+ dataset["test"][dataset_name]["data"].append(data_sample)
+
+ return dataset
diff --git a/applications/ColossalEval/colossal_eval/dataset/base.py b/applications/ColossalEval/colossal_eval/dataset/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..45b0151b849f038590521130493207eaaff47543
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/dataset/base.py
@@ -0,0 +1,24 @@
+from abc import abstractmethod
+
+from colossal_eval.utils import jdump
+
+
+class BaseDataset:
+ """
+ Base class for dataset wrapper.
+
+ Args:
+ path: The path to the original dataset.
+        logger: Logger for the dataset.
+        few_shot: Whether to load few-shot examples for the dataset.
+ """
+
+ def __init__(self, path, logger, few_shot):
+ self.dataset = self.load(path, logger, few_shot)
+
+ def save(self, save_path):
+ """Save the converted dataset"""
+ jdump(self.dataset, save_path)
+
+    @staticmethod
+    @abstractmethod
+    def load(path, logger, few_shot):
+        """Load the original dataset and convert it into the inference dataset"""
diff --git a/applications/ColossalEval/colossal_eval/dataset/ceval.py b/applications/ColossalEval/colossal_eval/dataset/ceval.py
new file mode 100644
index 0000000000000000000000000000000000000000..32ec52087bd3f2986f2bcb80e2a590e459fb84f8
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/dataset/ceval.py
@@ -0,0 +1,132 @@
+import copy
+import csv
+import os
+from typing import Dict, List
+
+from colossalai.logging import DistributedLogger
+
+from .base import BaseDataset
+
+ceval_subject_mapping = {
+ "computer_network": ["Computer Network", "计算机网络", "STEM"],
+ "operating_system": ["Operating System", "操作系统", "STEM"],
+ "computer_architecture": ["Computer Architecture", "计算机组成", "STEM"],
+ "college_programming": ["College Programming", "大学编程", "STEM"],
+ "college_physics": ["College Physics", "大学物理", "STEM"],
+ "college_chemistry": ["College Chemistry", "大学化学", "STEM"],
+ "advanced_mathematics": ["Advanced Mathematics", "高等数学", "STEM"],
+ "probability_and_statistics": ["Probability and Statistics", "概率统计", "STEM"],
+ "discrete_mathematics": ["Discrete Mathematics", "离散数学", "STEM"],
+ "electrical_engineer": ["Electrical Engineer", "注册电气工程师", "STEM"],
+ "metrology_engineer": ["Metrology Engineer", "注册计量师", "STEM"],
+ "high_school_mathematics": ["High School Mathematics", "高中数学", "STEM"],
+ "high_school_physics": ["High School Physics", "高中物理", "STEM"],
+ "high_school_chemistry": ["High School Chemistry", "高中化学", "STEM"],
+ "high_school_biology": ["High School Biology", "高中生物", "STEM"],
+ "middle_school_mathematics": ["Middle School Mathematics", "初中数学", "STEM"],
+ "middle_school_biology": ["Middle School Biology", "初中生物", "STEM"],
+ "middle_school_physics": ["Middle School Physics", "初中物理", "STEM"],
+ "middle_school_chemistry": ["Middle School Chemistry", "初中化学", "STEM"],
+ "veterinary_medicine": ["Veterinary Medicine", "兽医学", "STEM"],
+ "college_economics": ["College Economics", "大学经济学", "Social Science"],
+ "business_administration": ["Business Administration", "工商管理", "Social Science"],
+ "marxism": ["Marxism", "马克思主义基本原理", "Social Science"],
+ "mao_zedong_thought": ["Mao Zedong Thought", "毛泽东思想和中国特色社会主义理论体系概论", "Social Science"],
+ "education_science": ["Education Science", "教育学", "Social Science"],
+ "teacher_qualification": ["Teacher Qualification", "教师资格", "Social Science"],
+ "high_school_politics": ["High School Politics", "高中政治", "Social Science"],
+ "high_school_geography": ["High School Geography", "高中地理", "Social Science"],
+ "middle_school_politics": ["Middle School Politics", "初中政治", "Social Science"],
+ "middle_school_geography": ["Middle School Geography", "初中地理", "Social Science"],
+ "modern_chinese_history": ["Modern Chinese History", "近代史纲要", "Humanities"],
+ "ideological_and_moral_cultivation": ["Ideological and Moral Cultivation", "思想道德修养与法律基础", "Humanities"],
+ "logic": ["Logic", "逻辑学", "Humanities"],
+ "law": ["Law", "法学", "Humanities"],
+ "chinese_language_and_literature": ["Chinese Language and Literature", "中国语言文学", "Humanities"],
+ "art_studies": ["Art Studies", "艺术学", "Humanities"],
+ "professional_tour_guide": ["Professional Tour Guide", "导游资格", "Humanities"],
+ "legal_professional": ["Legal Professional", "法律职业资格", "Humanities"],
+ "high_school_chinese": ["High School Chinese", "高中语文", "Humanities"],
+ "high_school_history": ["High School History", "高中历史", "Humanities"],
+ "middle_school_history": ["Middle School History", "初中历史", "Humanities"],
+ "civil_servant": ["Civil Servant", "公务员", "Other"],
+ "sports_science": ["Sports Science", "体育学", "Other"],
+ "plant_protection": ["Plant Protection", "植物保护", "Other"],
+ "basic_medicine": ["Basic Medicine", "基础医学", "Other"],
+ "clinical_medicine": ["Clinical Medicine", "临床医学", "Other"],
+ "urban_and_rural_planner": ["Urban and Rural Planner", "注册城乡规划师", "Other"],
+ "accountant": ["Accountant", "注册会计师", "Other"],
+ "fire_engineer": ["Fire Engineer", "注册消防工程师", "Other"],
+ "environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "环境影响评价工程师", "Other"],
+ "tax_accountant": ["Tax Accountant", "税务师", "Other"],
+ "physician": ["Physician", "医师资格", "Other"],
+}
+
+default_inference_kwargs = {
+ "calculate_loss": False,
+ "all_classes": ["A", "B", "C", "D"],
+ "language": "Chinese",
+ "pretrain": False,
+ "max_new_tokens": 32,
+}
+
+
+def get_few_shot_data(data: List[Dict]):
+ few_shot_data = []
+ for i in data:
+ few_shot_data.append(i["input"] + i["target"])
+ return few_shot_data
+
+
+class CEvalDataset(BaseDataset):
+ """
+ Dataset class for CEval dataset.
+ Data source: https://huggingface.co/datasets/ceval/ceval-exam
+ This dataset class will convert the original dataset into the inference dataset.
+ """
+
+ @staticmethod
+ def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+ dataset = {"dev": {}, "test": {}}
+ for split in ["dev", "test"]:
+ files = os.listdir(os.path.join(path, split))
+ files.sort()
+
+ for file in files:
+ subject = file[0 : -len(f"_{split}.csv")]
+ subject = ceval_subject_mapping[subject][1]
+
+ file_dir = os.path.join(path, split, file)
+
+ dataset[split][subject] = {"data": []}
+
+                # It's been tested that all data samples in one subcategory have the same inference arguments.
+ dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs)
+
+ if split == "test" and few_shot:
+ dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data(
+ dataset["dev"][subject]["data"]
+ )
+
+ with open(file_dir, encoding="utf-8") as f:
+ reader = csv.reader(f)
+ _ = next(reader)
+ for row in reader:
+                        # The dev split has answer and explanation, so len(row) is 8.
+                        # The test split doesn't contain them, so len(row) is 6.
+ assert len(row) >= 6
+ choices = f"A. {row[2]}\nB. {row[3]}\nC. {row[4]}\nD. {row[5]}"
+ data_sample = {
+ "dataset": "ceval",
+ "split": split,
+ "category": subject,
+ "instruction": f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。",
+ "input": f"题目:{row[1]}\n{choices}\n答案:",
+ "output": "",
+ "target": row[6] if split == "dev" else "",
+ "id": int(row[0]),
+ }
+
+ dataset[split][subject]["data"].append(data_sample)
+
+ return dataset
diff --git a/applications/ColossalEval/colossal_eval/dataset/cmmlu.py b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..51f8ca14e0c87b52bdc58a413b64add5749ec460
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py
@@ -0,0 +1,144 @@
+import copy
+import csv
+import os
+from typing import Dict, List
+
+from colossalai.logging import DistributedLogger
+
+from .base import BaseDataset
+
+cmmlu_subject_mapping = {
+ "agronomy": "农学",
+ "anatomy": "解剖学",
+ "ancient_chinese": "古汉语",
+ "arts": "艺术学",
+ "astronomy": "天文学",
+ "business_ethics": "商业伦理",
+ "chinese_civil_service_exam": "中国公务员考试",
+ "chinese_driving_rule": "中国驾驶规则",
+ "chinese_food_culture": "中国饮食文化",
+ "chinese_foreign_policy": "中国外交政策",
+ "chinese_history": "中国历史",
+ "chinese_literature": "中国文学",
+ "chinese_teacher_qualification": "中国教师资格",
+ "clinical_knowledge": "临床知识",
+ "college_actuarial_science": "大学精算学",
+ "college_education": "大学教育学",
+ "college_engineering_hydrology": "大学工程水文学",
+ "college_law": "大学法律",
+ "college_mathematics": "大学数学",
+ "college_medical_statistics": "大学医学统计",
+ "college_medicine": "大学医学",
+ "computer_science": "计算机科学",
+ "computer_security": "计算机安全",
+ "conceptual_physics": "概念物理学",
+ "construction_project_management": "建设工程管理",
+ "economics": "经济学",
+ "education": "教育学",
+ "electrical_engineering": "电气工程",
+ "elementary_chinese": "小学语文",
+ "elementary_commonsense": "小学常识",
+ "elementary_information_and_technology": "小学信息技术",
+ "elementary_mathematics": "初等数学",
+ "ethnology": "民族学",
+ "food_science": "食品科学",
+ "genetics": "遗传学",
+ "global_facts": "全球事实",
+ "high_school_biology": "高中生物",
+ "high_school_chemistry": "高中化学",
+ "high_school_geography": "高中地理",
+ "high_school_mathematics": "高中数学",
+ "high_school_physics": "高中物理学",
+ "high_school_politics": "高中政治",
+ "human_sexuality": "人类性行为",
+ "international_law": "国际法学",
+ "journalism": "新闻学",
+ "jurisprudence": "法理学",
+ "legal_and_moral_basis": "法律与道德基础",
+ "logical": "逻辑学",
+ "machine_learning": "机器学习",
+ "management": "管理学",
+ "marketing": "市场营销",
+ "marxist_theory": "马克思主义理论",
+ "modern_chinese": "现代汉语",
+ "nutrition": "营养学",
+ "philosophy": "哲学",
+ "professional_accounting": "专业会计",
+ "professional_law": "专业法学",
+ "professional_medicine": "专业医学",
+ "professional_psychology": "专业心理学",
+ "public_relations": "公共关系",
+ "security_study": "安全研究",
+ "sociology": "社会学",
+ "sports_science": "体育学",
+ "traditional_chinese_medicine": "中医中药",
+ "virology": "病毒学",
+ "world_history": "世界历史",
+ "world_religions": "世界宗教",
+}
+
+default_inference_kwargs = {
+ "calculate_loss": True,
+ "all_classes": ["A", "B", "C", "D"],
+ "language": "Chinese",
+ "pretrain": False,
+ "max_new_tokens": 32,
+}
+
+
+def get_few_shot_data(data: List[Dict]):
+ few_shot_data = []
+ for i in data:
+ few_shot_data.append(i["input"] + i["target"])
+ return few_shot_data
+
+
+class CMMLUDataset(BaseDataset):
+ """
+ Dataset class for CMMLU dataset.
+ Data source: https://github.com/haonan-li/CMMLU/tree/master/data
+ This dataset class will convert the original dataset into the inference dataset.
+ """
+
+ @staticmethod
+ def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+ dataset = {"dev": {}, "test": {}}
+ for split in ["dev", "test"]:
+ files = os.listdir(os.path.join(path, split))
+ files.sort()
+
+ for file in files:
+ subject = file[0 : -len(".csv")]
+ subject = cmmlu_subject_mapping[subject]
+
+ file_dir = os.path.join(path, split, file)
+
+ dataset[split][subject] = {"data": []}
+
+                # It's been tested that all data samples in one subcategory have the same inference arguments.
+ dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs)
+
+ if split == "test" and few_shot:
+ dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data(
+ dataset["dev"][subject]["data"]
+ )
+
+ with open(file_dir, encoding="utf-8") as f:
+ reader = csv.reader(f)
+ _ = next(reader)
+ for row in reader:
+ assert len(row) == 7
+ choices = f"A. {row[2]}\nB. {row[3]}\nC. {row[4]}\nD. {row[5]}"
+ data_sample = {
+ "dataset": "cmmlu",
+ "split": split,
+ "category": subject,
+ "instruction": f"以下是关于{subject}的单项选择题,请直接给出正确答案的选项。",
+ "input": f"题目:{row[1]}\n{choices}\n答案:",
+ "output": "",
+ "target": row[6],
+ }
+
+ dataset[split][subject]["data"].append(data_sample)
+
+ return dataset
diff --git a/applications/ColossalEval/colossal_eval/dataset/colossalai.py b/applications/ColossalEval/colossal_eval/dataset/colossalai.py
new file mode 100644
index 0000000000000000000000000000000000000000..54ea478ae5d65ea9de0c3f84008d31539d25b62c
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/dataset/colossalai.py
@@ -0,0 +1,70 @@
+from collections import defaultdict
+from copy import deepcopy
+from typing import Dict, List
+
+from colossal_eval.utils import jload
+
+from colossalai.logging import DistributedLogger
+
+from .base import BaseDataset
+
+default_inference_kwargs = {
+ "calculate_loss": False,
+ "all_classes": None,
+ "language": "Chinese",
+ "pretrain": False,
+ "max_new_tokens": 256,
+}
+
+# You can add your own subcategory questions here and specify whether each is a single-choice question and whether it has target answers that require loss calculation.
+single_choice_question = set()
+calculate_loss = set()
+
+
+def get_data_per_category(data):
+ data_per_category = defaultdict(list)
+ for item in data:
+ category = item["category"]
+ data_per_category[category].append(item)
+
+ return data_per_category
+
+
+class ColossalDataset(BaseDataset):
+ """
+ Dataset class for Colossal dataset.
+ This dataset class will convert the original dataset into the inference dataset.
+ """
+
+ @staticmethod
+ def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+ dataset = {"test": {}}
+ data = jload(path)
+ data_per_category = get_data_per_category(data)
+ categories = list(data_per_category.keys())
+
+ for category in categories:
+ dataset["test"][category] = {"data": []}
+ category_data = data_per_category[category]
+
+ dataset["test"][category]["inference_kwargs"] = deepcopy(default_inference_kwargs)
+
+ if category in calculate_loss:
+ dataset["test"][category]["inference_kwargs"]["calculate_loss"] = True
+ if category in single_choice_question:
+ dataset["test"][category]["inference_kwargs"]["all_classes"] = ["A", "B", "C", "D"]
+
+ for item in category_data:
+ data_sample = {
+ "dataset": "colossal",
+ "split": "test",
+ "category": category,
+ "instruction": item["instruction"],
+ "input": item["input"],
+ "output": "",
+ "target": item["target"],
+ "id": item["id"],
+ }
+ dataset["test"][category]["data"].append(data_sample)
+
+ return dataset
diff --git a/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bf0639e48829b230813e0fc7387e6165803abc4
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py
@@ -0,0 +1,122 @@
+import json
+import os
+import re
+from copy import deepcopy
+from typing import Dict, List
+
+from colossalai.logging import DistributedLogger
+
+from .base import BaseDataset
+
+multi_choice_datasets = [
+ "Chinese Lang and Usage MCQs",
+ "Chinese Modern Lit",
+ "English Fill in Blanks",
+ "English Reading Comp",
+ "Geography MCQs",
+ "Physics MCQs",
+ "English Cloze Test",
+]
+
+chinese_qa_datasets = [
+ "Biology MCQs",
+ "Chemistry MCQs",
+ "Chinese Lang and Usage MCQs",
+ "Chinese Modern Lit",
+ "Geography MCQs",
+ "History MCQs",
+ "Math I MCQs",
+ "Math II MCQs",
+ "Physics MCQs",
+ "Political Science MCQs",
+]
+english_qa_datasets = ["English MCQs", "English Fill in Blanks", "English Reading Comp", "English Cloze Test"]
+
+default_inference_kwargs = {
+ "calculate_loss": True,
+ "all_classes": None,
+ "language": "Chinese",
+ "pretrain": False,
+ "max_new_tokens": 32,
+}
+
+
+def get_all_classes(instruction: str):
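+    # Collect the option letters (e.g. "A.") that appear in the instruction,
+    # then keep only the consecutive run starting from "A" so stray capital letters are dropped.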
+ letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ pattern = r"([A-Z]\. |[A-Z].|[A-Z]\.)"
+ options = sorted(list(set(re.findall(pattern, instruction))))
+ options = sorted(list(set([string[0] for string in options])))
+
+ for i in range(len(options)):
+ if options[i] == letters[i]:
+ continue
+ else:
+ return options[0:i]
+ return options
+
+
+class GaoKaoBenchDataset(BaseDataset):
+ """
+ Dataset class for GAOKAO-Bench dataset.
+ Data source: https://github.com/OpenLMLab/GAOKAO-Bench/tree/main/data
+ This dataset class will convert the original dataset into the inference dataset.
+
+    A few typos needed to be manually corrected in the original dataset; some of the following have already been fixed.
+ Issue link: https://github.com/OpenLMLab/GAOKAO-Bench/issues/20
+ 1. Option C missing in index 111 in 2010-2022_Chemistry_MCQs.json
+ 2. Option B missing "." after it in index 16 in 2012-2022_English_Cloze_Test.json
+ 3. Option G missing "." after it in index 23 in 2012-2022_English_Cloze_Test.json
+ """
+
+ @staticmethod
+ def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
+ dataset = {"test": {}}
+ for category in ["Fill-in-the-blank_Questions", "Multiple-choice_Questions", "Open-ended_Questions"]:
+ files = os.listdir(os.path.join(path, "data", category))
+ files.sort()
+
+ for file in files:
+ subject = file[10:-5].split("_")
+ subject = " ".join(subject)
+ dataset["test"][subject] = {"data": []}
+
+ file_dir = os.path.join(path, "data", category, file)
+
+ with open(file_dir, encoding="utf-8") as f:
+ data = json.load(f)
+
+                # It's been tested that all data samples in one subcategory have the same inference arguments.
+ inference_kwargs = deepcopy(default_inference_kwargs)
+ if category == "Multiple-choice_Questions" and subject not in multi_choice_datasets:
+ all_classes = get_all_classes(data["example"][0]["question"])
+ inference_kwargs["all_classes"] = all_classes
+ if subject in english_qa_datasets:
+ inference_kwargs["language"] = "English"
+ if subject in chinese_qa_datasets:
+ inference_kwargs["language"] = "Chinese"
+
+ dataset["test"][subject]["inference_kwargs"] = inference_kwargs
+
+ for sample in data["example"]:
+ # Convert multi-choice answers to a single string.
+ # We will convert it back when evaluating.
+ # We do this because if target is a list, it should be only used for multiple target answers.
+ if subject in multi_choice_datasets:
+ sample["answer"] = "".join(sample["answer"])
+
+ if isinstance(sample["answer"], list) and len(sample["answer"]) == 1:
+ sample["answer"] = sample["answer"][0]
+
+ data_sample = {
+ "dataset": "gaokaobench",
+ "split": "test",
+ "category": f"{category[:-10]}-{subject}",
+ "instruction": sample["question"].strip() + "\n答案:",
+ "input": "",
+ "output": "",
+ "target": sample["answer"],
+ }
+
+ dataset["test"][subject]["data"].append(data_sample)
+
+ return dataset
diff --git a/applications/ColossalEval/colossal_eval/dataset/longbench.py b/applications/ColossalEval/colossal_eval/dataset/longbench.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ea5e3c7d77fc9f8ee8fd496b67fefcdf625257a
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/dataset/longbench.py
@@ -0,0 +1,120 @@
+import os
+from copy import deepcopy
+from typing import Dict, List
+
+from colossal_eval.utils import get_json_list
+
+from colossalai.logging import DistributedLogger
+
+from .base import BaseDataset
+
+dataset2prompt = {
+ "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
+ "qasper": 'You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:',
+ "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:",
+ "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
+ "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:",
+ "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:",
+ "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:",
+ "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:",
+ "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:",
+ "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}",
+ "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}",
+ "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}",
+ "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}",
+ "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ",
+ "passage_retrieval_en": 'Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like "Paragraph 1", "Paragraph 2", etc.\n\nThe answer is: ',
+ "passage_retrieval_zh": '以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是"段落1","段落2"等格式\n\n答案是:',
+ "lcc": "Please complete the code given below. \n{context}Next line of code:\n",
+ "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n",
+}
+
+dataset2maxlen = {
+ "narrativeqa": 128,
+ "qasper": 128,
+ "multifieldqa_en": 64,
+ "multifieldqa_zh": 64,
+ "hotpotqa": 32,
+ "2wikimqa": 32,
+ "musique": 32,
+ "dureader": 128,
+ "gov_report": 512,
+ "qmsum": 512,
+ "multi_news": 512,
+ "vcsum": 512,
+ "trec": 64,
+ "triviaqa": 32,
+ "samsum": 128,
+ "lsht": 64,
+ "passage_count": 32,
+ "passage_retrieval_en": 32,
+ "passage_retrieval_zh": 32,
+ "lcc": 64,
+ "repobench-p": 64,
+}
+
+default_inference_kwargs = {
+ "calculate_loss": True,
+ "all_classes": None,
+ "language": "Chinese",
+ "pretrain": False,
+ "max_new_tokens": 32,
+}
+
+
+class LongBenchDataset(BaseDataset):
+ """
+ Dataset class for LongBench dataset.
+ Data source: https://huggingface.co/datasets/THUDM/LongBench
+ This dataset class will convert the original dataset into the inference dataset.
+
+ Issue link: https://github.com/THUDM/LongBench/issues/15 (fixed)
+    There are duplicate target answers in `nq.jsonl`, but this doesn't affect the evaluation results.
+    It doesn't affect the perplexity calculation either (the program only needs to select the minimum loss).
+ """
+
+ @staticmethod
+    def load(path: str, logger: DistributedLogger) -> Dict:
+ dataset = {"test": {}}
+
+ files = os.listdir(path)
+ files.sort()
+
+ for file in files:
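+            # Each file is named "<category>.jsonl"; strip the ".jsonl" suffix to get the category.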
+ category = file[0:-6]
+
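+            # Skip the "_e" (LongBench-E) variants; only the standard subsets are loaded here.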
+ if category.endswith("_e"):
+ continue
+
+ dataset["test"][category] = {"data": []}
+
+ file_dir = os.path.join(path, file)
+
+ loaded_jsonl = get_json_list(file_dir)
+
+            # It has been verified that each data sample in one subcategory has the same inference arguments.
+ inference_kwargs = deepcopy(default_inference_kwargs)
+ if loaded_jsonl[0]["all_classes"] is not None:
+ inference_kwargs["all_classes"] = loaded_jsonl[0]["all_classes"]
+ inference_kwargs["max_new_tokens"] = dataset2maxlen[category]
+ dataset["test"][category]["inference_kwargs"] = inference_kwargs
+
+ for sample in loaded_jsonl:
+ prompt = dataset2prompt[category].format(**sample)
+
+ data_sample = {
+ "dataset": "longbench",
+ "split": "test",
+ "category": category,
+ "instruction": prompt,
+ "input": "",
+ "output": "",
+ "target": sample["answers"],
+ }
+
+ dataset["test"][category]["data"].append(data_sample)
+
+ return dataset
diff --git a/applications/ColossalEval/colossal_eval/dataset/mmlu.py b/applications/ColossalEval/colossal_eval/dataset/mmlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..b89c0a13cff1d82d5435ed52f446fd4092bd5093
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/dataset/mmlu.py
@@ -0,0 +1,73 @@
+import copy
+import csv
+import os
+from typing import Dict, List
+
+from colossalai.logging import DistributedLogger
+
+from .base import BaseDataset
+
+default_inference_kwargs = {
+ "calculate_loss": True,
+ "all_classes": ["A", "B", "C", "D"],
+ "language": "English",
+ "pretrain": False,
+ "max_new_tokens": 32,
+}
+
+
+def get_few_shot_data(data: List[Dict]):
+ few_shot_data = []
+ for i in data:
+ few_shot_data.append(i["input"] + i["target"])
+ return few_shot_data
+
+
+class MMLUDataset(BaseDataset):
+ """
+ Dataset class for MMLU dataset.
+ Data source: https://github.com/hendrycks/test
+ This dataset class will convert the original dataset into the inference dataset.
+ """
+
+ @staticmethod
+    def load(path: str, logger: DistributedLogger, few_shot: bool) -> Dict:
+ dataset = {"dev": {}, "test": {}}
+ for split in ["dev", "test"]:
+ files = os.listdir(os.path.join(path, split))
+ files.sort()
+
+ for file in files:
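+                # File names look like "abstract_algebra_test.csv"; recover a title-cased
+                # subject name such as "Abstract Algebra" (keeping "US" capitalized).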
+ subject = file[0 : -len(f"_{split}.csv")].split("_")
+ subject = " ".join([word.title() if word != "us" else "US" for word in subject])
+
+ file_dir = os.path.join(path, split, file)
+
+ dataset[split][subject] = {"data": [], "inference_kwargs": {}}
+
+                # It has been verified that each data sample in one subcategory has the same inference arguments.
+ dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs)
+
+ if split == "test" and few_shot:
+ dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data(
+ dataset["dev"][subject]["data"]
+ )
+
+ with open(file_dir, encoding="utf-8") as f:
+ reader = csv.reader(f)
+ for row in reader:
+ assert len(row) == 6
+ choices = f"A. {row[1]}\nB. {row[2]}\nC. {row[3]}\nD. {row[4]}"
+ data_sample = {
+ "dataset": "mmlu",
+ "split": split,
+ "category": subject,
+ "instruction": f"The following is a single-choice question on {subject}. Answer the question by replying A, B, C or D.",
+ "input": f"Question: {row[0]}\n{choices}\nAnswer: ",
+ "output": "",
+ "target": row[5],
+ }
+
+ dataset[split][subject]["data"].append(data_sample)
+
+ return dataset
diff --git a/applications/ColossalEval/colossal_eval/evaluate/GPT Evaluation.md b/applications/ColossalEval/colossal_eval/evaluate/GPT Evaluation.md
new file mode 100644
index 0000000000000000000000000000000000000000..37fbda4c8647322c0cd9db0aea7ca7139771b4a4
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/evaluate/GPT Evaluation.md
@@ -0,0 +1,248 @@
+# GPT Evaluation
+## Table of Contents
+- [Overview](#overview)
+- [GPT Evaluation](#gpt-evaluation)
+ - [Evaluation Category](#evaluation-category)
+ - [Evaluation Category Examples](#evaluation-category-examples)
+ - [Evaluation Metrics](#evaluation-metrics)
+- [Evaluation Process](#evaluation-process)
+ - [Data Format](#data-format)
+ - [Prompt](#prompt)
+ - [Battle Prompt](#battle-prompt)
+ - [Evaluation Prompt](#evaluation-prompt)
+ - [Evaluation](#evaluation)
+ - [Configuration](#configuration)
+ - [Evaluate](#evaluate)
+- [FAQ](#faq)
+- [Citations](#citations)
+
+
+## Overview
+
+In this directory, we introduce how you can evaluate your model with GPT models. Evaluation of both Chinese and English capability is supported, and we provide the following functions:
+
+* Compare the performance of two different models (battle).
+* Rate the model according to pre-defined metrics using prompting design.
+* Rate the model according to pre-defined metrics with additional reference answer using prompting design.
+
+## GPT Evaluation
+
+### Evaluation Category
+
+Our evaluation pipeline can examine the model's capability using different categories of questions. The following table includes some example categories. You can add your own questions.
+
+| Evaluation Category | Description |
+| :-----------------: | :----------------------------------------------------------- |
+| Brainstorming | Models are asked to generate a range of creative and diverse ideas according to the question. The capability of creativity is required. |
+| Chat | Models are asked to continue a multi-round dialogue given the roles involved. The capability of understanding, memorizing previous rounds of the dialogue and answering according to the persona provided is required. |
+| Generation | Models are asked to generate an email, letter, article, etc. The capability of generating high-quality, human-like text is required. |
+| Open QA | Models are asked to answer an open QA question (without context provided). The capability of answering questions with the models' own knowledge base is required. |
+| Roleplay | Models are asked to play the role provided. The capability of engaging in the scenario and effectively interacting with the user is required. |
+
+
+### Evaluation Category Examples
+To better understand each evaluation category, here are some example questions. The example questions are in the `configs/gpt_evaluation/data` folder.
+
+
+| Evaluation Category | Chinese Example | English Example |
+| :-----------------: | :----------------------------------------------------------- | :----------------------------------------------------------- |
+| Brainstorming | 列举一些可以促进头发生长的食物。 | How do you properly chop an onion without crying? |
+| Chat | 基于以下角色信息完成一段对话。小张是一名新手爱好者,对养鸡有浓厚的兴趣。老李是一名有丰富经验的养鸡大师。<br>小张:您好,老李,我最近开始对养鸡感兴趣了,想请教您一些问题。<br>老李:你好,小张,我很乐意帮助你。你想问些什么?<br>小张:我想知道如何确定鸡的品种和性别?<br>老李:确切的品种可以通过鸡的外貌特征来确定,而性别一般是通过鸡卵的大小和形状来判断。还有什么问题吗?<br>小张: | Complete a dialogue based on the following character information. Alex: A novice writer who is struggling to find inspiration and develop his writing skills. Emma: A successful author with many published works, providing guidance and advice to Alex.<br>Alex: Hi Emma, I have been writing for a while now but can't seem to make any progress. Can you give me any advice?<br>Emma: Hi Alex, sure. What kind of writing are you doing?<br>Alex: I'm trying to write a novel, but I just can't seem to find any inspiration.<br>Emma: |
+| Generation | 请为一家咖啡店编写一篇简短的广告语,吸引更多的顾客。 | Write a set of guidelines for first-time pet owners on how to properly care for a new puppy. |
+| Open QA | 解释什么是RNA病毒和DNA病毒。 | Explain the process of osmosis in biological systems. |
+| Roleplay | 我要你把我写的句子翻译成表情符号。我会写句子,你会用表情符号表达它。我只是想让你用表情符号来表达它。除了表情符号,我不希望你回复任何内容。当我需要用中文告诉你一些事情时,我会用 {} 这样的大括号括起来。我的第一句话是“{我的职业是消防员。}” | I want you to act as a rapper. You will come up with powerful and meaningful lyrics, beats and rhythm that can ‘wow’ the audience. Your lyrics should have an intriguing meaning and message which people can relate too. When it comes to choosing your beat, make sure it is catchy yet relevant to your words, so that when combined they make an explosion of sound everytime! My first request is "I need a rap song about finding strength within yourself." |
+
+### Evaluation Metrics
+
+GPT evaluation uses GPT models to evaluate the predictions of different models, and different pre-defined evaluation metrics are applied to different categories. The following table shows the 10 pre-defined evaluation metrics, both in Chinese and English:
+
+| Evaluation Metric | Prompt Words | CoT (Chain-of-Thought) |
+| :-------------------: | :----------------------------------------------------------- | :----------------------------------------------------------- |
+| 语言组织<br>(Language organization) | 语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc. | 1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。<br>2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。<br>3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。<br>4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。<br>5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。<br>6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。<br>1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.<br>2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.<br>3. Determine if the answer is relevant to the question or topic and conveys a clear message.<br>4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.<br>5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.<br>6. Evaluate the linguistic organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good linguistic organization and 1 indicates very poor linguistic organization. |
+| 切题<br>(Relevance) | 切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic. | 1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。<br>2. 阅读答案,确认答案是否直接回答了题目所问的问题。<br>3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。<br>4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。<br>1. Read the question to determine what the question asks and what aspects of the question need to be answered.<br>2. Read the answers to make sure that they directly answer the question asked.<br>3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.<br>4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all. |
+| 创意性<br>(Creativity) | 创意性(1-5):某些头脑风暴问题可能需要答案具有创意,提出新的思路。Creativity (1-5): Some brainstorming questions may require answers that are creative and suggest new ideas. | 1. 仔细阅读所提供的头脑风暴问题,确保你理解问题的要点和背景。<br>2. 根据你的知识和经验,判断所提供的答案是否可行。如果答案不可行,则创意性评分可能会受到影响。<br>3. 考虑答案中是否包含新颖的想法或独特的思路。答案可能与已知的解决方案有所重叠,但仍然可以被认为是有创意的,只要它提供了新的角度或方法来解决问题。<br>4. 根据答案的创意性,给出一个1到5的评分。如果答案缺乏创意,则应给出一个较低的评分。如果答案具有创意并提供了新的思路,应给出一个较高的评分。<br>1. Read the provided brainstorming questions carefully to make sure you understand the gist and context of the questions.<br>2. Based on your knowledge and experience, determine if the answers provided are feasible. If the answer is not feasible, the creativity score may be affected.<br>3. Consider whether the answer contains novel ideas or unique thoughts. An answer may overlap with a known solution and still be considered creative, as long as it offers a new perspective or approach to the problem.<br>4. Give a score of 1 to 5 depending on the creativity of the answer. If the answer lacks creativity, a lower score should be given. If the answer is creative and provides a new idea, a higher score should be given. |
+| 实用性<br>(Practicality) | 实用性(1-5):某些头脑风暴问题可能需要答案提出实用的建议或解决方法。Practicality (1-5): Some brainstorming questions may require answers to suggest practical suggestions or solutions. | 1. 仔细阅读所提供的头脑风暴问题,确保你理解问题的要点和背景。<br>2. 根据你的知识和经验,判断所提供的答案是否可行。如果答案不可行,则实用性评分可能会受到影响。<br>3. 考虑答案中提出的建议或解决方法是否实用并可行。答案可能看起来很好,但如果无法实现或应用,则实用性评分可能会受到影响。<br>4. 根据答案的实用性,给出一个1到5的评分。如果答案缺乏实用性,则应给出一个较低的评分。如果答案提出了实用的建议或解决方法,并且可以很好地解决问题,则应给出一个较高的评分。<br>1. Read the provided brainstorming questions carefully to make sure you understand the gist and context of the questions.<br>2. Based on your knowledge and experience, determine if the answers provided are feasible. If the answer is not feasible, the practicality score may be affected.<br>3. Consider whether the suggestions or solutions presented in the answer are practical and workable. The answer may look good, but if it cannot be implemented or applied, the practicality score may be affected.<br>4. Give a score of 1 to 5 depending on the practicality of the answer. If the answer lacks practicality, a lower score should be given. If the answer makes a practical suggestion or solution and solves the problem well, a higher score should be given. |
+| 正确性<br>(Correctness) | 正确性(1-5):答案是否正确。Correctness (1-5): whether the answer is correct or not. | 1. 仔细阅读题目,尝试自己回答该问题。<br>2. 检查答案的准确性。您可以使用已知的事实或研究来验证答案是否正确。如果答案是正确的,则可以将正确性得分为5分。如果答案是部分正确的,则可以给予适当的得分,例如2分、3分或4分。如果答案完全不正确,则只得1分。<br>1. Read the question carefully and try to answer the question yourself.<br>2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be given. If the answer is completely incorrect, only 1 point is awarded. |
+| 自然<br>(Naturalness) | 自然(1-5):答案是否自然,并且符合问题给定的身份。Naturalness (1-5): whether the answer is natural and fits the identity given by the question. | 1. 阅读题目,确定题目提供的身份信息。<br>2. 检查答案内容是否符合题目给定的身份。<br>3. 根据以上因素,对该回答的自然性进行打分,分数从1到5,其中1表示不自然,5表示非常自然,并符合问题给定的身份。<br>1. Read the question and determine the identity information provided in the question.<br>2. Check whether the content of the answer matches the identity given in the question.<br>3. Based on the above factors, score the naturalness of the response on a scale from 1 to 5, where 1 means unnatural and 5 means very natural and in accordance with the identity given in the question. |
+| 参与感<br>(Engagingness) | 参与感(1-5):答案是否对前面的对话内容做出了恰当的反应,是否理解对话的语境和背景。Engagingness (1-5): whether the answer responds appropriately to the content of the preceding conversation and whether it understands the context and background of the conversation. | 1. 阅读题目,确定对话的语境和背景。<br>2. 检查答案是否充分理解对话的语境和背景,能否自然地融入到对话中而不显得突兀。<br>3. 根据以上因素,对该回答的参与感进行打分,分数从1到5,其中1表示没有参与感,5表示非常有参与感,并且恰当地理解了对话的语境和背景。<br>1. Read the questions to determine the context and background of the dialogue.<br>2. Check that the answer fully understands the context and background of the conversation and that it fits naturally into the conversation without seeming abrupt.<br>3. Based on the above factors, rate the response's engagement on a scale from 1 to 5, where 1 means not engaged and 5 means very engaged and appropriately understands the context and background of the conversation. |
+| 合理性<br>(Reasonableness) | 合理性(1-5):答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。Reasonableness (1-5): Whether the answer can form a logical connection with the content of the previous dialogue, whether it is consistent with common sense, and whether it can reasonably exist in this context. | 1. 阅读题目,确定对话的主题以及问题期望的回答方向。<br>2. 判断答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。<br>3. 根据以上因素,对该回答的合理性进行打分,分数从1到5,其中1表示不合理,5表示非常合理,并且能够与前面的对话内容形成逻辑上的衔接,并符合常理。<br>1. Read the question and determine the topic of the conversation and the direction the question expects the answer to go.<br>2. Determine whether the answer can be logically connected to the preceding conversation, whether it makes common sense, and whether it can reasonably exist in this context.<br>3. Based on the above factors, rate the reasonableness of the answer on a scale from 1 to 5, where 1 means unreasonable and 5 means very reasonable and able to form a logical connection with the preceding dialogue content and consistent with common sense. |
+| 多样性<br>(Diversity) | 多样性(1-5):答案使用语言是否优美,具有一定的创造性和想象力。然而,回答也应该保持合理和适度,不要过于夸张或离题。Diversity (1-5): Whether the answers use beautiful language and have some creativity and imagination. However, answers should also be kept reasonable and moderate, not overly exaggerated or off-topic. | 1. 仔细阅读整个回答,确保完全理解回答所表达的内容和主题。<br>2. 在阅读回答的同时,注意语言的质量,例如措辞是否正确,语言是否生动等。<br>3. 检查回答的创造性和想象力,看看回答是否能够吸引人阅读下去。<br>4. 检查回答的合理性和适度,看看回答是否夸张或离题。<br>5. 将多样性的评分打分在1到5之间,5分表示回答的质量很好,能够吸引人阅读,1分表示回答的内容生硬或者有离题的问题。<br>1. Read the entire response carefully to ensure that you fully understand the content and theme expressed in the response.<br>2. While reading the response, pay attention to the quality of the language, such as whether the wording is correct and the language is vivid.<br>3. Check the creativity and imagination of the response to see if the response is engaging to read on.<br>4. Check the reasonableness and appropriateness of the responses to see if the responses are exaggerated or off-topic.<br>5. Rate the diversity on a scale of 1 to 5, with a 5 indicating a good quality response that is engaging to read and a 1 indicating a raw response or a question that is off-topic. |
+| 保真度<br>(Fidelity) | 保真度(1-5):答案是否能够严格遵守角色的设定回答给定的请求。Fidelity (1-5): whether the answer is able to answer the given request in strict compliance with the role setting. | 1. 仔细阅读问题,了解角色在问题中的设定和表现,包括职业、背景、观点、性格等方面。<br>2. 阅读题目的请求,确认回答请求时需要注意的细节。<br>3. 对比提供的回答与该角色的设定,评估回答是否能够严格遵守角色的设定。<br>4. 结合以上评估结果给出保真度的评分,范围从1到5分,其中1分表示回答与角色设定完全不符,5分表示回答完全符合角色设定且满足给定请求。<br>1. Read the question carefully to understand how the character is set up and represented in the question, including aspects such as occupation, background, point of view, and personality.<br>2. Read the question's request and confirm the details that need to be taken into account when answering the request.<br>3. Compare the provided answer with the setting of the role and assess whether the answer can strictly adhere to the setting of the role.<br>4. Combine the results of the above assessment to give a fidelity score ranging from 1 to 5, where a score of 1 means that the response does not match the persona at all, and a score of 5 means that the response fully complies with the persona and satisfies the given request. |
+
+GPT models evaluate the quality of model predictions based on the given prompt words and give a score between 1 and 5.
+
+> **NOTE 1:** You can find all the prompt words and CoT (Chain-of-Thought) in `configs/gpt_evaluation/prompt/evaluation_prompt`.
+
+> **NOTE 2:** To add customized metrics, you can refer to [FAQ](#faq).
+
+## Evaluation Process
+
+### Data Format
+
+A JSON file contains one list. Each element in the list is a target answer / prediction record for one instruction / question.
+An element should have the following fields:
+
+* `category` (str, compulsory): The category of the instruction / question.
+* `instruction` (str, compulsory): The instruction / question for the LLM.
+* `input` (str, optional): The additional context of the instruction / question.
+* `output` (str, optional): The model output for the instruction; models fill in this field during inference.
+* `target` (str, optional): The target answer for the instruction.
+* `id` (int, compulsory): The ID of the instruction / question.
+
+Example:
+
+```json
+[
+ {
+ "category": "brainstorming",
+ "instruction": "请问如何制作一份美味的西红柿炒鸡蛋?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 1
+ },
+ {
+ "category": "chat",
+ "instruction": "基于以下角色信息完成一段对话。小张是一名新手爱好者,对养鸡有浓厚的兴趣。老李是一名有丰富经验的养鸡大师。",
+ "input": "小张:您好,老李,我最近开始对养鸡感兴趣了,想请教您一些问题。 老李:你好,小张,我很乐意帮助你。你想问些什么? 小张:我想知道如何确定鸡的品种和性别? 老李:确切的品种可以通过鸡的外貌特征来确定,而性别一般是通过鸡卵的大小和形状来判断。还有什么问题吗? 小张:",
+ "output": "",
+ "target": "",
+ "id": 2
+ }
+]
+```
+
+### Prompt
+
+#### Battle Prompt
+
+The following is an example Chinese battle prompt. In a battle prompt, the question and the answers from two different models are fed into the prompt template. You can find example battle prompt files for Chinese and English in `configs/gpt_evaluation/prompt/battle_prompt`.
+
+```json
+{
+ "id": 1,
+ "system_prompt": "你是一个检查回答质量的好助手。",
+ "prompt_template": "[问题]\n{question}\n\n[1号AI助手的答案]\n{answer_1}\n\n[1号AI助手答案终止]\n\n[2号AI助手的答 案]\n{answer_2}\n\n[2号AI助手答案终止]\n\n[要求]\n{prompt}\n\n",
+ "prompt": "我们需要你评价这两个AI助手回答的性能。\n请对他们的回答的有用性、相关性、准确性、详细程度进行评分。每个AI助手都会得到一个1到10分的总分,分数越高表示整体表现越好。\n请首先输出一行,该行只包含两个数值,分别表示1号和2号AI助手的分数。这两个分数之间要有一个空格。在随后的一行中,请对你的评价作出全面的解释,避免任何潜在的偏见,并确保AI助手回答的顺序不会影响您的判断。"
+}
+```
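+
+To make the template concrete, here is a minimal sketch of how one battle request is assembled from the fields above (it mirrors the `battle` helper in `colossal_eval/evaluate/gpt_evaluate.py`; the question and answers are placeholders):
+
+```python
+# Fill the battle prompt template for one question.
+sys_prompt = prompt_dict["system_prompt"]
+user_prompt = prompt_dict["prompt_template"].format(
+    question="<question text>",  # instruction, plus input if present
+    answer_1="<output of model 1>",
+    answer_2="<output of model 2>",
+    prompt=prompt_dict["prompt"],  # the grading request shown above
+)
+# sys_prompt and user_prompt are then sent to GPT-4 as the system and user messages.
+```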
+
+#### Evaluation Prompt
+
+The following is an example of a Chinese GPT evaluation prompt. In an evaluation prompt, you should define your metrics in `metrics` and provide CoT (Chain-of-Thought) in `CoT`. You can find example evaluation prompt files for Chinese and English in `configs/gpt_evaluation/prompt/evaluation_prompt`.
+
+```json
+{
+ "brainstorming": {
+ "id": 1,
+ "category": "brainstorming",
+ "metrics": {
+ "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。"
+ },
+ "CoT": {
+ "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:"
+ },
+ "prompt": "你是一个好助手。请你为下面“头脑风暴”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
+ }
+}
+```
+
+`"metrics"`: the metrics that can be used in GPT evaluation. This field determines which metrics can be added to your config file.
+
+`"CoT"`: evaluation steps you prompt to GPT models for each metric defined in `"metrics"`.
+
+### Evaluation
+
+#### Configuration
+
+The following is an example of a Chinese config file. The configuration file can control how the pipeline evaluates the model. You need to specify GPT evaluation metrics in key `GPT`. You can find an example English config file in `configs/gpt_evaluation/config/config_en.json`.
+
+```json
+{
+ "language": "cn",
+ "category": {
+ "brainstorming": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "creativity",
+ "practicality",
+ "reasonableness"
+ ]
+ }
+ }
+}
+```
+
+`"language"`: the language used to evaluate the model capability. We only support Chinese `"cn"` for now.
+
+`"category"`: the category/categories needed to evaluate the model capability.
+
+`"GPT"`: the metrics you want to use for GPT evaluation.
+
+
+#### Evaluate
+
+After setting the configuration file, you can evaluate the model using `examples/gpt_evaluation/eval.py`. If you want to make comparisons between answers of two different models, you should specify two answer files in the argument `answer_file_list` and two model names in the argument `model_name_list`. If you want to evaluate one answer file, the length of both `answer_file_list` and `model_name_list` should be 1 and the program will perform evaluation using automatic metrics and GPT models.
+
+An example script is provided as follows:
+
+```shell
+python eval.py \
+ --config_file "path to the config file" \
+ --battle_prompt_file "path to the prompt file for battle" \
+ --gpt_evaluation_prompt_file "path to the prompt file for gpt evaluation" \
+ --target_file "path to the target answer file" \
+ --answer_file_list "path to the answer files of at most 2 models" \
+ --model_name_list "the names of at most 2 models" \
+ --gpt_model "which GPT model to use for evaluation" \
+ --save_path "path to save results" \
+ --openai_key "your openai key" \
+    --openai_key "your openai key"
+
+If you want GPT evaluation with reference, you can add the argument `--gpt_with_reference`, but make sure the reference file has target answers.
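+
+For example, reusing the arguments above:
+
+```shell
+python eval.py \
+    --config_file "path to the config file" \
+    --battle_prompt_file "path to the prompt file for battle" \
+    --gpt_evaluation_prompt_file "path to the prompt file for gpt evaluation" \
+    --target_file "path to the target answer file" \
+    --answer_file_list "path to the answer files of at most 2 models" \
+    --model_name_list "the names of at most 2 models" \
+    --gpt_model "which GPT model to use for evaluation" \
+    --save_path "path to save results" \
+    --openai_key "your openai key" \
+    --gpt_with_reference
+```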
+
+## FAQ
+
+How can I add a new GPT evaluation metric?
+
+For example, if you want to add a new metric `persuasiveness` into category `brainstorming`, you should add the metric definition and its corresponding CoT (Chain-of-Thought) to the evaluation prompt file in `configs/gpt_evaluation/prompt/evaluation_prompt`. The CoT can be generated using ChatGPT: you can prompt ChatGPT to generate evaluation steps for the new metric.
+
+```json
+{
+ "brainstorming": {
+ "id": 1,
+ "category": "brainstorming",
+ "metrics": {
+ "persuasiveness": "persuasiveness(1-5):a short description for persuasiveness"
+ },
+ "CoT": {
+ "persuasiveness": "CoT for persuasiveness\n\npersuasiveness:"
+ },
+ "prompt": "You are a good assistant. Please rate the given answer to the \"brainstorming\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
+ }
+}
+```
+
+
+
+## Citations
+
+```bibtex
+@misc{vicuna2023,
+ title = {Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90\%* ChatGPT Quality},
+ url = {https://vicuna.lmsys.org},
+ author = {Chiang, Wei-Lin and Li, Zhuohan and Lin, Zi and Sheng, Ying and Wu, Zhanghao and Zhang, Hao and Zheng, Lianmin and Zhuang, Siyuan and Zhuang, Yonghao and Gonzalez, Joseph E. and Stoica, Ion and Xing, Eric P.},
+ month = {March},
+ year = {2023}
+}
+
+@misc{liu2023geval,
+ title={G-Eval: NLG Evaluation using GPT-4 with Better Human Alignment},
+ author={Yang Liu and Dan Iter and Yichong Xu and Shuohang Wang and Ruochen Xu and Chenguang Zhu},
+ year={2023},
+ eprint={2303.16634},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL}
+}
+```
diff --git a/tests/test_layers/test_1d/checks_1d/__init__.py b/applications/ColossalEval/colossal_eval/evaluate/__init__.py
similarity index 100%
rename from tests/test_layers/test_1d/checks_1d/__init__.py
rename to applications/ColossalEval/colossal_eval/evaluate/__init__.py
diff --git a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/__init__.py b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c5df09a6909b0b5fdc6c726e2543012b94a4410
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/__init__.py
@@ -0,0 +1,3 @@
+from .dataset_evaluator import DatasetEvaluator
+
+__all__ = ["DatasetEvaluator"]
diff --git a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..c70988707a154cc50126797c5b13949acef3b97e
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py
@@ -0,0 +1,269 @@
+from typing import Dict, List
+
+import colossal_eval.evaluate.dataset_evaluator.metrics as metric_helper
+import numpy as np
+import tqdm
+
+LabelBasedMetrics = ["first_token_accuracy", "matthews_correlation"]
+LossBasedMetrics = ["perplexity", "ppl_score", "ppl_score_over_choices", "per_byte_perplexity", "per_byte_ppl_score"]
+CombinedMetrics = ["combined_single_choice_accuracy"]
+OtherMetrics = [
+ "f1_score",
+ "f1_zh_score",
+ "rouge_score",
+ "rouge_zh_score",
+ "retrieval_score",
+ "retrieval_zh_score",
+ "classification_score",
+ "code_sim_score",
+ "count_score",
+ "multi_choice_accuracy",
+ "math_equivalence",
+ "single_choice_accuracy",
+]
+
+
+class DatasetEvaluator(object):
+ """
+ Dataset evaluator.
+
+ """
+
+ def __init__(self):
+ pass
+
+ def _calculate_label_metrics(self, metric: str, category: str):
+ """Calculate label-based metrics."""
+ weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
+
+ str_label_map = {
+ choice: idx for idx, choice in enumerate(self.data[category]["inference_kwargs"]["all_classes"])
+ }
+
+ references = [str_label_map[sample["target"]] for sample in self.data[category]["data"]]
+
+ flag = False
+ softmaxs = []
+ for i, sample in enumerate(self.data[category]["data"]):
+ if np.any(np.isnan(np.array(list(sample["softmax_over_choices"].values())))):
+ if not flag:
+ print(
+ f"NaN in the softmax, switch to exact match for category {category} in dataset {self.dataset_name} in model {self.model_name}."
+ )
+ flag = True
+ score = 0
+ for ref in sample["target"]:
+ score = max(
+ score,
+ metric_helper.single_choice_accuracy(
+ sample["output"], ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"]
+ ),
+ )
+ softmaxs.append(references[i] if score == 1 else -1)
+ else:
+ softmaxs.append(np.argmax(np.array(list(sample["softmax_over_choices"].values()))))
+
+ references = np.array(references)
+ softmaxs = np.array(softmaxs)
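+        # Category accuracy: percentage of samples whose chosen label matches the reference.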
+ scores = np.sum(references == softmaxs) / len(self.data[category]["data"]) * 100
+
+ self.evaluation_results[metric][category] = (scores, len(self.data[category]["data"]))
+ self.evaluation_results[metric]["ALL"] += scores * weight
+
+ def _calculate_combined_metrics(self, metric: str, category: str):
+ """Calculate combined metrics."""
+ weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
+
+ references = [sample["target"] for sample in self.data[category]["data"]]
+ predictions = [sample["output"] for sample in self.data[category]["data"]]
+
+ str_label_map = {
+ choice: idx for idx, choice in enumerate(self.data[category]["inference_kwargs"]["all_classes"])
+ }
+
+ references_labels = [str_label_map[sample["target"][0]] for sample in self.data[category]["data"]]
+
+ flag = False
+ softmaxs = []
+ for i, sample in enumerate(self.data[category]["data"]):
+ if np.any(np.isnan(np.array(list(sample["softmax_over_choices"].values())))):
+ if not flag:
+ print(
+ f"NaN in the softmax, switch to exact match for category {category} in dataset {self.dataset_name} in model {self.model_name}."
+ )
+ flag = True
+ score = 0
+ for ref in sample["target"]:
+ score = max(
+ score,
+ metric_helper.single_choice_accuracy(
+ sample["output"], ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"]
+ ),
+ )
+ softmaxs.append(references[i] if score == 1 else -1)
+ else:
+ softmaxs.append(np.argmax(np.array(list(sample["softmax_over_choices"].values()))))
+
+        metric_method = getattr(metric_helper, metric)
+
+ total_score = 0.0
+ for prediction, reference, references_label, softmax in zip(
+ predictions, references, references_labels, softmaxs
+ ):
+ score = 0.0
+
+ for ref in reference:
+ score = max(
+ score,
+ metric_method(prediction, ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"]),
+ )
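+            # Grant full credit if the argmax over the choice softmax already matches the label.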
+ if references_label == softmax:
+ score = 1
+
+ total_score += score
+ total_score = total_score * 100 / len(self.data[category]["data"])
+
+ self.evaluation_results[metric][category] = (total_score, len(self.data[category]["data"]))
+ self.evaluation_results[metric]["ALL"] += total_score * weight
+
+ def _calculate_other_metrics(self, metric: str, category: str):
+ """Calculate other metrics."""
+ weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
+
+ references = [sample["target"] for sample in self.data[category]["data"]]
+ predictions = [sample["output"] for sample in self.data[category]["data"]]
+
+        metric_method = getattr(metric_helper, metric)
+
+ total_score = 0.0
+ for prediction, reference in zip(predictions, references):
+ score = 0.0
+ for ref in reference:
+ score = max(
+ score,
+ metric_method(prediction, ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"]),
+ )
+ total_score += score
+ total_score = total_score * 100 / len(predictions)
+
+ self.evaluation_results[metric][category] = (total_score, len(self.data[category]["data"]))
+ self.evaluation_results[metric]["ALL"] += total_score * weight
+
+ def _calculate_loss_metrics(self, metric: str, category: str):
+ """Calculate perplexity."""
+ if metric == "perplexity":
+ weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
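+            # Each sample may carry several candidate targets; keep the minimum loss.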
+ losses = [min(sample["loss"]) for sample in self.data[category]["data"]]
+ perplexity = np.mean(np.exp(np.array(losses)))
+
+ self.evaluation_results["perplexity"][category] = (perplexity, len(self.data[category]["data"]))
+ self.evaluation_results["perplexity"]["ALL"] += perplexity * weight
+ elif metric == "ppl_score":
+ weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
+ losses = [min(sample["loss"]) for sample in self.data[category]["data"]]
+ perplexity_score = np.mean(np.exp(-np.array(losses))) * 100
+
+ self.evaluation_results["ppl_score"][category] = (perplexity_score, len(self.data[category]["data"]))
+ self.evaluation_results["ppl_score"]["ALL"] += perplexity_score * weight
+ elif metric == "ppl_score_over_choices" and self.data[category]["inference_kwargs"]["all_classes"] is not None:
+ weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
+ loss_over_choices = [sample["loss_over_choices"] for sample in self.data[category]["data"]]
+ perplexity_score_over_choices = np.mean(np.exp(-np.array(loss_over_choices))) * 100
+
+ self.evaluation_results["ppl_score_over_choices"][category] = (
+ perplexity_score_over_choices,
+ len(self.data[category]["data"]),
+ )
+ self.evaluation_results["ppl_score_over_choices"]["ALL"] += perplexity_score_over_choices * weight
+ elif metric == "per_byte_perplexity":
+ weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
+ losses = [min(sample["loss_sum"]) for sample in self.data[category]["data"]]
+ perplexity = np.mean(np.exp(np.array(losses) / np.array(self.N_bytes[category])))
+
+ self.evaluation_results["per_byte_perplexity"][category] = perplexity
+ self.evaluation_results["per_byte_perplexity"]["ALL"] += perplexity * weight
+ elif metric == "per_byte_ppl_score":
+ weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
+ losses = [min(sample["loss_sum"]) for sample in self.data[category]["data"]]
+ perplexity_score = np.mean(np.exp(-np.array(losses) / np.array(self.N_bytes[category]))) * 100
+
+ self.evaluation_results["per_byte_ppl_score"][category] = perplexity_score
+ self.evaluation_results["per_byte_ppl_score"]["ALL"] += perplexity_score * weight
+
+ def _evaluate(self):
+ """Calculate and return evaluation results"""
+
+ for metric in self.metrics:
+ pbar = tqdm.tqdm(
+ desc=f"{self.dataset_name}-{metric}-{self.model_name}", total=len(self.suggested_categories[metric])
+ )
+
+ if metric in LabelBasedMetrics:
+ for category in self.suggested_categories[metric]:
+ self._calculate_label_metrics(metric, category)
+ pbar.update(1)
+ elif metric in LossBasedMetrics:
+ for category in self.suggested_categories[metric]:
+ self._calculate_loss_metrics(metric, category)
+ pbar.update(1)
+ elif metric in CombinedMetrics:
+ for category in self.suggested_categories[metric]:
+ self._calculate_combined_metrics(metric, category)
+ pbar.update(1)
+ elif metric in OtherMetrics:
+ for category in self.suggested_categories[metric]:
+ self._calculate_other_metrics(metric, category)
+ pbar.update(1)
+
+ return self.evaluation_results
+
+ def get_evaluation_results(self, data: List[Dict], dataset_name: str, model_name: str, metrics: List[str]):
+ """
+ Evaluate inference data on the given metrics.
+
+ Args:
+ data: Data to be evaluated.
+ dataset_name: Name of the dataset
+ model_name: Name of the model
+ metrics: Metrics used to evaluate.
+
+ """
+ self.data = data
+ self.dataset_name = dataset_name
+ self.model_name = model_name
+ self.categories = list(data.keys())
+ self.metrics = metrics
+
+ self.evaluation_results = {
+ metric: {category: 0 for category in (["ALL"] + self.categories)} for metric in self.metrics
+ }
+
+ self.total_length = 0
+ self.total_single_choices = 0
+ for value in self.data.values():
+ self.total_length += len(value["data"])
+ if value["inference_kwargs"]["all_classes"] is not None:
+ self.total_single_choices += len(value["data"])
+
+ self.metric_total_length = {metric: 0 for metric in self.metrics}
+ self.suggested_categories = {metric: [] for metric in self.metrics}
+
+ for metric in self.metrics:
+ self.suggested_categories[metric] = metric_helper.metrics4subcategory[self.dataset_name][metric]
+ if "ALL" in self.suggested_categories[metric]:
+ self.suggested_categories[metric] = self.categories
+ self.metric_total_length[metric] = self.total_length
+ continue
+ for category in self.suggested_categories[metric]:
+ self.metric_total_length[metric] += len(self.data[category]["data"])
+
+ if "per_byte_perplexity" in self.metrics or "per_byte_ppl_score" in self.metrics:
+ self.N_bytes = {category: [] for category in self.categories}
+ for category in self.categories:
+ samples = self.data[category]["data"]
+ for sample in samples:
+ self.N_bytes[category].append(sample["byte_num"][0])
+
+ return self._evaluate()
diff --git a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..914465478dec92be62117621989650eb8eed3904
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py
@@ -0,0 +1,623 @@
+# Code adapted from https://github.com/THUDM/LongBench/blob/main/metrics.py
+# Code adapted from https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py
+# Code adapted from https://github.com/ruixiangcui/AGIEval/blob/main/src/evaluation.py
+
+import difflib
+import re
+import string
+from collections import Counter
+
+import jieba
+from fuzzywuzzy import fuzz
+from rouge import Rouge
+
+metrics4subcategory = {
+ "pretrain": {
+ "perplexity": ["ALL"],
+ "ppl_score": ["ALL"],
+ "per_byte_perplexity": ["ALL"],
+ "per_byte_ppl_score": ["ALL"],
+ },
+    # The commented-out subcategories are not 4-choice questions.
+ "agieval": {
+ "combined_single_choice_accuracy": [
+ # "lsat-ar",
+ # "lsat-lr",
+ # "lsat-rc",
+ "logiqa-en",
+ "sat-math",
+ "sat-en",
+ # "aqua-rat",
+ "sat-en-without-passage",
+ "gaokao-english",
+ "logiqa-zh",
+ "gaokao-chinese",
+ "gaokao-geography",
+ "gaokao-history",
+ "gaokao-biology",
+ "gaokao-chemistry",
+ ],
+ "first_token_accuracy": [
+ # "lsat-ar",
+ # "lsat-lr",
+ # "lsat-rc",
+ "logiqa-en",
+ "sat-math",
+ "sat-en",
+ # "aqua-rat",
+ "sat-en-without-passage",
+ "gaokao-english",
+ "logiqa-zh",
+ "gaokao-chinese",
+ "gaokao-geography",
+ "gaokao-history",
+ "gaokao-biology",
+ "gaokao-chemistry",
+ ],
+ "single_choice_accuracy": [
+ # "lsat-ar",
+ # "lsat-lr",
+ # "lsat-rc",
+ "logiqa-en",
+ "sat-math",
+ "sat-en",
+ # "aqua-rat",
+ "sat-en-without-passage",
+ "gaokao-english",
+ "logiqa-zh",
+ "gaokao-chinese",
+ "gaokao-geography",
+ "gaokao-history",
+ "gaokao-biology",
+ "gaokao-chemistry",
+ ],
+ "multi_choice_accuracy": ["jec-qa-kd", "jec-qa-ca", "gaokao-physics", "gaokao-mathqa"],
+ "math_equivalence": ["gaokao-mathcloze", "math"],
+ "perplexity": ["ALL"],
+ "ppl_score_over_choices": [
+ "lsat-ar",
+ "lsat-lr",
+ "lsat-rc",
+ "logiqa-en",
+ "sat-math",
+ "sat-en",
+ "aqua-rat",
+ "sat-en-without-passage",
+ "gaokao-english",
+ "logiqa-zh",
+ "jec-qa-kd",
+ "jec-qa-ca",
+ "gaokao-chinese",
+ "gaokao-geography",
+ "gaokao-history",
+ "gaokao-biology",
+ "gaokao-chemistry",
+ "gaokao-physics",
+ "gaokao-mathqa",
+ ],
+ "ppl_score": ["ALL"],
+ },
+ "cmmlu": {
+ "first_token_accuracy": ["ALL"],
+ "single_choice_accuracy": ["ALL"],
+ "perplexity": ["ALL"],
+ "ppl_score_over_choices": ["ALL"],
+ "ppl_score": ["ALL"],
+ },
+ "gaokaobench": {
+ "combined_single_choice_accuracy": [
+ "English MCQs",
+ "Biology MCQs",
+ "Chemistry MCQs",
+ "History MCQs",
+ "Math I MCQs",
+ "Math II MCQs",
+ "Political Science MCQs",
+ ],
+ "first_token_accuracy": [
+ "English MCQs",
+ "Biology MCQs",
+ "Chemistry MCQs",
+ "History MCQs",
+ "Math I MCQs",
+ "Math II MCQs",
+ "Political Science MCQs",
+ ],
+ "single_choice_accuracy": [
+ "English MCQs",
+ "Biology MCQs",
+ "Chemistry MCQs",
+ "History MCQs",
+ "Math I MCQs",
+ "Math II MCQs",
+ "Political Science MCQs",
+ ],
+ "multi_choice_accuracy": [
+ "Chinese Lang and Usage MCQs",
+ "Chinese Modern Lit",
+ "English Fill in Blanks",
+ "English Reading Comp",
+ "Geography MCQs",
+ "Physics MCQs",
+ "English Cloze Test",
+ ],
+ "math_equivalence": ["Math I Fill-in-the-Blank", "Math II Fill-in-the-Blank"],
+ "rouge_score": ["English Language Cloze Passage"],
+ "rouge_zh_score": [
+ "Chinese Language Famous Passages and Sentences Dictation",
+ "Chemistry Open-ended Questions",
+ "History Open-ended Questions",
+ "Biology Open-ended Questions",
+ "Political Science Open-ended Questions",
+ "English Language Error Correction",
+ "Chinese Language Language and Writing Skills Open-ended Questions",
+ "Math II Open-ended Questions",
+ "Chinese Language Literary Text Reading",
+ "Chinese Language Ancient Poetry Reading",
+ "Chinese Language Classical Chinese Reading",
+ "Physics Open-ended Questions",
+ "Math I Open-ended Questions",
+ "Geography Open-ended Questions",
+ "Chinese Language Practical Text Reading",
+ ],
+ "perplexity": ["ALL"],
+ "ppl_score_over_choices": ["ALL"],
+ "ppl_score": ["ALL"],
+ },
+ "longbench": {
+ "f1_score": ["hotpotqa", "2wikimqa", "musique", "narrativeqa", "qasper", "multifieldqa_en", "triviaqa"],
+ "f1_zh_score": ["multifieldqa_zh"],
+ "rouge_score": ["gov_report", "qmsum", "multi_news", "samsum"],
+ "rouge_zh_score": ["dureader", "vcsum"],
+ "retrieval_score": ["passage_retrieval_en"],
+ "retrieval_zh_score": ["passage_retrieval_zh"],
+ "classification_score": ["trec", "lsht"],
+ "code_sim_score": ["lcc", "repobench-p"],
+ "count_score": ["passage_count"],
+ "perplexity": ["ALL"],
+ "ppl_score": ["ALL"],
+ },
+ "mmlu": {
+ "first_token_accuracy": ["ALL"],
+ "single_choice_accuracy": ["ALL"],
+ "accuracy": ["ALL"],
+ "perplexity": ["ALL"],
+ "ppl_score_over_choices": ["ALL"],
+ "ppl_score": ["ALL"],
+ },
+}
+
+
+def _fix_fracs(string):
+ substrs = string.split("\\frac")
+ new_str = substrs[0]
+ if len(substrs) > 1:
+ substrs = substrs[1:]
+ for substr in substrs:
+ new_str += "\\frac"
+ if substr[0] == "{":
+ new_str += substr
+ else:
+ try:
+ assert len(substr) >= 2
+ except:
+ return string
+ a = substr[0]
+ b = substr[1]
+ if b != "{":
+ if len(substr) > 2:
+ post_substr = substr[2:]
+ new_str += "{" + a + "}{" + b + "}" + post_substr
+ else:
+ new_str += "{" + a + "}{" + b + "}"
+ else:
+ if len(substr) > 2:
+ post_substr = substr[2:]
+ new_str += "{" + a + "}" + b + post_substr
+ else:
+ new_str += "{" + a + "}" + b
+ string = new_str
+ return string
+
+
+def _fix_a_slash_b(string):
+ if len(string.split("/")) != 2:
+ return string
+ a = string.split("/")[0]
+ b = string.split("/")[1]
+ try:
+ a = int(a)
+ b = int(b)
+ assert string == "{}/{}".format(a, b)
+ new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
+ return new_string
+ except:
+ return string
+
+
+def _remove_right_units(string):
+ # "\\text{ " only ever occurs (at least in the val set) when describing units
+ if "\\text{ " in string:
+ splits = string.split("\\text{ ")
+ assert len(splits) == 2
+ return splits[0]
+ else:
+ return string
+
+
+def _fix_sqrt(string):
+ if "\\sqrt" not in string:
+ return string
+ splits = string.split("\\sqrt")
+ new_string = splits[0]
+ for split in splits[1:]:
+ if split[0] != "{":
+ a = split[0]
+ new_substr = "\\sqrt{" + a + "}" + split[1:]
+ else:
+ new_substr = "\\sqrt" + split
+ new_string += new_substr
+ return new_string
+
+
+def _strip_string(string):
+ # linebreaks
+ string = string.replace("\n", "")
+ # print(string)
+
+ # remove inverse spaces
+ string = string.replace("\\!", "")
+ # print(string)
+
+ # replace \\ with \
+ string = string.replace("\\\\", "\\")
+ # print(string)
+
+ # replace tfrac and dfrac with frac
+ string = string.replace("tfrac", "frac")
+ string = string.replace("dfrac", "frac")
+ # print(string)
+
+ # remove \left and \right
+ string = string.replace("\\left", "")
+ string = string.replace("\\right", "")
+ # print(string)
+
+ # Remove circ (degrees)
+ string = string.replace("^{\\circ}", "")
+ string = string.replace("^\\circ", "")
+
+ # remove dollar signs
+ string = string.replace("\\$", "")
+
+ # remove units (on the right)
+ string = _remove_right_units(string)
+
+ # remove percentage
+ string = string.replace("\\%", "")
+ string = string.replace("\%", "")
+
+ # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
+ string = string.replace(" .", " 0.")
+ string = string.replace("{.", "{0.")
+ # if empty, return empty string
+ if len(string) == 0:
+ return string
+ if string[0] == ".":
+ string = "0" + string
+
+ # to consider: get rid of e.g. "k = " or "q = " at beginning
+ if len(string.split("=")) == 2:
+ if len(string.split("=")[0]) <= 2:
+ string = string.split("=")[1]
+
+ # fix sqrt3 --> sqrt{3}
+ string = _fix_sqrt(string)
+
+ # remove spaces
+ string = string.replace(" ", "")
+
+ # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
+ string = _fix_fracs(string)
+
+ # manually change 0.5 --> \frac{1}{2}
+ if string == "0.5":
+ string = "\\frac{1}{2}"
+
+ # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
+ string = _fix_a_slash_b(string)
+
+ return string
+
+
+def parse_math_answer(raw_string):
+ def remove_boxed(s):
+ left = "\\boxed{"
+ try:
+ assert s[: len(left)] == left
+ assert s[-1] == "}"
+ answer = s[len(left) : -1]
+ if "=" in answer:
+ answer = answer.split("=")[-1].lstrip(" ")
+ return answer
+ except:
+ return None
+
+ def last_boxed_only_string(string):
+ idx = string.rfind("\\boxed")
+ if idx < 0:
+ idx = string.rfind("\\fbox")
+ if idx < 0:
+ return None
+ i = idx
+ right_brace_idx = None
+ num_left_braces_open = 0
+ while i < len(string):
+ if string[i] == "{":
+ num_left_braces_open += 1
+ if string[i] == "}":
+ num_left_braces_open -= 1
+ if num_left_braces_open == 0:
+ right_brace_idx = i
+ break
+ i += 1
+
+        if right_brace_idx is None:
+ retval = None
+ else:
+ retval = string[idx : right_brace_idx + 1]
+
+ return retval
+
+ def get_answer_with_dollar_sign(s):
+        first_pattern = r"\$(.*)\$"
+ last_match = None
+ matches = re.findall(first_pattern, s)
+ if matches:
+ last_match = matches[-1]
+ if "=" in last_match:
+ last_match = last_match.split("=")[-1].lstrip(" ")
+ return last_match
+
+ def get_answer_without_dollar_sign(s):
+ last_match = None
+ if "=" in s:
+ last_match = s.split("=")[-1].lstrip(" ").rstrip(".")
+ if "\\n" in last_match:
+ last_match = last_match.split("\\n")[0]
+ else:
+            pattern = r"(?:\$)?\d+(?:\.\d+)?(?![\w\d])"
+ matches = re.findall(pattern, s)
+ if matches:
+ last_match = matches[-1]
+ return last_match
+
+ if "\\boxed" in raw_string:
+ answer = remove_boxed(last_boxed_only_string(raw_string))
+ else:
+ answer = get_answer_with_dollar_sign(raw_string)
+ if not answer:
+ answer = get_answer_without_dollar_sign(raw_string)
+ return answer
+
+
+def math_equivalence(prediction, reference, **kwargs):
+ prediction = parse_math_answer(prediction)
+
+ if prediction is None and reference is None:
+ print("WARNING: Both None")
+ return False
+
+ if prediction is None or reference is None:
+ return False
+
+ try:
+ ss1 = _strip_string(prediction)
+ ss2 = _strip_string(reference)
+ return ss1 == ss2
+ except:
+ return prediction == reference
+
+
+def multi_choice_accuracy(prediction, reference, **kwargs):
+    # Only find uppercase letters not surrounded by lowercase letters
+    all_classes = kwargs.get("all_classes", None)
+    if all_classes:
+        pattern = f"(?<![a-z])[{all_classes[0]}-{all_classes[-1]}](?![a-z])"
+    else:
+        pattern = "(?<![a-z])[A-F](?![a-z])"
+
+    predicted_choices = set(re.findall(pattern, prediction))
+    reference_choices = set(re.findall(pattern, reference))
+
+    # Partial credit for each correct choice; any wrong choice zeroes the score.
+    score = 0.0
+    for choice in predicted_choices:
+        if choice not in reference_choices:
+            return 0.0
+        score += 1.0 / len(reference_choices)
+    return score
+
+
+def classification_score(prediction, reference, **kwargs):
+    # Map the free-form prediction to the most similar class label, then
+    # check the matched label against the reference.
+    all_classes = kwargs.get("all_classes", None)
+    best_match = ""
+    highest_similarity = 0.0
+    for string in all_classes:
+        similarity = difflib.SequenceMatcher(None, string, prediction).ratio()
+        if similarity > highest_similarity:
+            highest_similarity = similarity
+            best_match = string
+    score = float(best_match == reference)
+    return score
+
+
+def rouge_score(prediction, reference, **kwargs):
+ rouge = Rouge()
+ try:
+ scores = rouge.get_scores([prediction], [reference], avg=True)
+ except:
+ return 0.0
+ return scores["rouge-l"]["f"]
+
+
+def rouge_zh_score(prediction, reference, **kwargs):
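+    # Segment Chinese text with jieba so ROUGE can operate on space-separated tokens.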
+ prediction = " ".join(list(jieba.cut(prediction, cut_all=False)))
+ reference = " ".join(list(jieba.cut(reference, cut_all=False)))
+ score = rouge_score(prediction, reference)
+ return score
+
+
+def _f1_score(prediction, reference, **kwargs):
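+    # Token-level F1 computed from bag-of-token overlap between prediction and reference.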
+ common = Counter(prediction) & Counter(reference)
+ num_same = sum(common.values())
+ if num_same == 0:
+ return 0
+ precision = 1.0 * num_same / len(prediction)
+ recall = 1.0 * num_same / len(reference)
+ f1 = (2 * precision * recall) / (precision + recall)
+ return f1
+
+
+def f1_score(prediction, reference, **kwargs):
+ normalized_prediction = normalize_answer(prediction)
+ normalized_ground_truth = normalize_answer(reference)
+
+ prediction_tokens = normalized_prediction.split()
+ ground_truth_tokens = normalized_ground_truth.split()
+ return _f1_score(prediction_tokens, ground_truth_tokens)
+
+
+def f1_zh_score(prediction, reference, **kwargs):
+ prediction_tokens = list(jieba.cut(prediction, cut_all=False))
+ ground_truth_tokens = list(jieba.cut(reference, cut_all=False))
+ prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens]
+ ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens]
+ prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
+ ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
+ return _f1_score(prediction_tokens, ground_truth_tokens)
diff --git a/applications/ColossalEval/colossal_eval/evaluate/evaluator.py b/applications/ColossalEval/colossal_eval/evaluate/evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..11e204b504c5be370285bef9c5ea2a089e403376
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/evaluate/evaluator.py
@@ -0,0 +1,110 @@
+import os
+from typing import Any, Dict, List
+
+import colossal_eval.evaluate.gpt_evaluate as gpt_evaluate
+
+from .utils import get_data_per_category
+
+
+class Evaluator(object):
+ """
+    Evaluator class that wraps GPT-3.5/GPT-4 based evaluation, covering single-model rating and two-model battles.
+
+ """
+
+ def __init__(
+ self,
+ params: Dict[str, Any],
+ battle_prompt: Dict[str, Any],
+ gpt_evaluation_prompt: Dict[str, Any],
+ gpt_model: str,
+ language: str,
+ gpt_with_reference: bool,
+ ) -> None:
+ self.params = params
+ self.battle_prompt = battle_prompt
+ self.gpt_evaluation_prompt = gpt_evaluation_prompt
+ self.gpt_model = gpt_model
+ self.language = language
+ self.gpt_with_reference = gpt_with_reference
+ self.gpt_evaluation_results = dict()
+ self.battle_results = []
+
+ def battle(self, answers1: List[Dict], answers2: List[Dict]) -> None:
+ """
+ Comparison between two models using GPT-4 as the reviewer.
+ """
+
+ self.battle_results = gpt_evaluate.battle(answers1, answers2, self.battle_prompt)
+
+ def evaluate(self, answers: List[Dict], targets: List[Dict], save_path: str, model_name: str) -> None:
+ """
+ A comprehensive evaluation of the answers from the model.
+ The function evaluates the model's performance from different perspectives
+ using GPT-3.5, GPT-4, and off-the-shelf evaluation metrics.
+
+ The metrics will be decided by the config file.
+
+ """
+
+ answers_per_category = get_data_per_category(answers, list(self.params.keys()))
+ targets_per_category = get_data_per_category(targets, list(self.params.keys()))
+
+ # gpt evaluation
+ for category in self.params:
+ if len(answers_per_category[category]) == 0:
+ print(f"Category {category} specified in your config doesn't have corresponding answers!")
+ continue
+
+ if self.params[category].get("GPT", None) is None:
+ continue
+
+ category_metrics = self.params[category]["GPT"]
+
+ prompt = self.gpt_evaluation_prompt.get(category, None)
+ if prompt is None:
+                print(f"No prompt for category {category}! Using the prompt for category 'general' instead.")
+ prompt = self.gpt_evaluation_prompt["general"]
+
+ self.gpt_evaluation_results[category] = gpt_evaluate.evaluate(
+ answers_per_category[category],
+ prompt,
+ category_metrics,
+ category,
+ save_path,
+ model_name,
+ self.gpt_model,
+ self.language,
+ references=targets_per_category[category] if self.gpt_with_reference else None,
+ )
+
+ def save(self, path: str, model_name_list: List[str]) -> None:
+ """
+ Save evaluation results of GPT-3.5, GPT-4, and off-the-shelf evaluation metrics.
+
+ """
+
+ if len(model_name_list) == 2:
+ save_path = os.path.join(path, "gpt_evaluate", "battle_results")
+ gpt_evaluate.save_battle_results(self.battle_results, model_name_list[0], model_name_list[1], save_path)
+ else:
+ if self.gpt_evaluation_results:
+ # Save evaluation results for GPT evaluation metrics.
+ gpt_base_save_path = os.path.join(path, "gpt_evaluate", "gpt_evaluate_results")
+ gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results")
+
+ all_evaluations = gpt_evaluate.save_gpt_evaluation_results(
+ model_name_list[0], self.gpt_evaluation_results, gpt_evaluation_results_save_path
+ )
+
+ # Start to calculate scores and save statistics.
+ gpt_evaluation_statistics_save_path = os.path.join(gpt_base_save_path, "evaluation_statistics")
+ gpt_evaluate.save_gpt_evaluation_statistics(
+ model_name_list[0], all_evaluations, gpt_evaluation_statistics_save_path
+ )
+
+ # Save charts and csv.
+ gpt_evaluation_analyses_save_path = os.path.join(gpt_base_save_path, "evaluation_analyses")
+ gpt_evaluate.analyze_gpt_evaluation_statistics(
+ gpt_evaluation_statistics_save_path, gpt_evaluation_analyses_save_path
+ )
diff --git a/applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py b/applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b1ed1143f037f22bd6a3434a9985c76e930f79
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/evaluate/gpt_evaluate.py
@@ -0,0 +1,852 @@
+import concurrent.futures
+import os
+import re
+import time
+from copy import deepcopy
+from typing import Any, Dict, List
+
+import matplotlib.pyplot as plt
+import numpy as np
+import openai
+import pandas as pd
+import seaborn as sns
+import tqdm
+from colossal_eval.utils import jdump, jload
+
+ref_step_template = {
+ "en": "Now please compare the answer with the {adjective} answer, determine whether the answer is able to achieve the same level of {metric}.\n\n",
+ "cn": "请比较答案与上面的{adjective}答案,确定答案是否可以达到与该{adjective}答案同样水平的{metric}。\n\n",
+}
+
+ref_answer_template_general = {
+ "en": "\nAn example answer with good quality is as follows:\n\n{answer}\n\n",
+ "cn": "\n一个优质的示例答案如下:\n\n{answer}\n\n",
+}
+
+ref_answer_template_correctness = {
+ "en": "\nA correct answer is as follows:\n\n{answer}\n\n",
+ "cn": "\n标准答案如下:\n\n{answer}\n\n",
+}
+
+
+def get_battle_result(sys_prompt: str, user_prompt: str, id: int, max_tokens: int = 2048) -> Dict[str, Any]:
+ """
+ Get battle evaluation from GPT-4.
+
+ Args:
+ sys_prompt: prompt for the system.
+ user_prompt: prompt for the user.
+ id: id of the answers for comparison.
+ max_tokens: the maximum number of tokens to generate in the chat completion.
+
+ Returns:
+ An evaluation of one comparison.
+ """
+
+ MAX_API_RETRY = 3
+ for _ in range(MAX_API_RETRY):
+ try:
+ response = openai.ChatCompletion.create(
+ model="gpt-4",
+ messages=[
+ {"role": "system", "content": sys_prompt},
+ {
+ "role": "user",
+ "content": user_prompt,
+ },
+ ],
+ temperature=0.2,
+ max_tokens=max_tokens,
+ )
+ evaluation = response["choices"][0]["message"]["content"]
+ return {"evaluation": evaluation, "id": id}
+ except Exception as e:
+ print(e)
+ time.sleep(1)
+ print(f"Evaluation {id} failed after {MAX_API_RETRY} retries.")
+ return {"evaluation": "", "id": id}
+
+
+def parse_battle_score(evaluation: str) -> List[float]:
+ """
+ Parse evaluation from GPT-4 and get the scores of model 1 and 2.
+
+ Args:
+ evaluation: evaluation from GPT-4.
+
+ Returns:
+ A score pair of two different model answers.
+ """
+
+ try:
+        pattern = re.compile("([0-9]|10) out of 10")
+        sp = re.findall(pattern, evaluation)
+        if len(sp) == 2:
+            return [float(sp[0]), float(sp[1])]
+
+        pattern = re.compile("a score of ([0-9]|10)")
+        sp = re.findall(pattern, evaluation)
+        if len(sp) == 2:
+            return [float(sp[0]), float(sp[1])]
+
+        pattern = re.compile("([0-9]|10)/10")
+        sp = re.findall(pattern, evaluation)
+        if len(sp) == 2:
+            return [float(sp[0]), float(sp[1])]
+
+ score_pair = evaluation.split("\n")[0]
+ score_pair = score_pair.replace(",", " ")
+ sp = score_pair.split(" ")
+ if len(sp) == 2:
+ return [float(sp[0]), float(sp[1])]
+ else:
+ raise Exception(f"Invalid score pair. Got {evaluation}.")
+ except Exception:
+ return [-1, -1]
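+
+
+# A quick illustration of the fallback patterns above, on hypothetical
+# evaluation strings (not real GPT-4 output):
+#
+#   parse_battle_score("Assistant 1: 8 out of 10. Assistant 2: 6 out of 10.")  # -> [8.0, 6.0]
+#   parse_battle_score("8 6\nAssistant 2 misses some key details.")            # -> [8.0, 6.0]
+#   parse_battle_score("Both answers are excellent.")                          # -> [-1, -1]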
+
+
+def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any]) -> List[Dict]:
+ """
+ Use GPT-4 to compare answers of two different models.
+
+ Args:
+ answer1: answers of model 1.
+ answer2: answers of model 2.
+ prompt_dict: prompt for battle.
+
+ Returns:
+ Evaluations of all comparison pairs.
+ """
+
+ assert len(answer1) == len(answer2)
+
+ total_len = len(answer1)
+ question_idx_list = list(range(total_len))
+
+ print(f" Total number of answers: {len(answer1)}.")
+
+ evaluations = []
+ with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+ futures = []
+ for i in question_idx_list:
+ assert answer1[i]["id"] == answer2[i]["id"]
+ answer_id = answer1[i]["id"]
+
+ ques = (
+ answer1[i]["instruction"]
+ if answer1[i]["input"] == ""
+ else answer1[i]["instruction"] + " " + answer1[i]["input"]
+ )
+ ans1 = answer1[i]["output"]
+ ans2 = answer2[i]["output"]
+
+ sys_prompt = prompt_dict["system_prompt"]
+ prompt_template = prompt_dict["prompt_template"]
+ prompt = prompt_template.format(
+ question=ques,
+ answer_1=ans1,
+ answer_2=ans2,
+ prompt=prompt_dict["prompt"],
+ )
+
+ future = executor.submit(get_battle_result, sys_prompt, prompt, answer_id, 2048)
+ futures.append(future)
+
+ for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
+ evaluations.append(future.result())
+
+ evaluations.sort(key=lambda x: x["id"])
+
+ return evaluations
+
+
+def save_battle_results(evaluations: List[Dict], name1: str, name2: str, save_path: str) -> None:
+ """
+ Save evaluation results (model 1 vs model 2) from GPT-4.
+
+ Args:
+ evaluations: evaluation results from GPT-4.
+        name1: name of model 1.
+        name2: name of model 2.
+ save_path: path to save battle results.
+ """
+
+ evaluation_file = deepcopy(evaluations)
+
+ ans1_score = 0
+ ans2_score = 0
+ better_count = 0
+ worse_count = 0
+ tie_count = 0
+ invalid_count = 0
+
+ better_file = []
+ worse_file = []
+ tie_file = []
+ invalid_file = []
+
+ for idx, evaluation in enumerate(evaluations):
+ scores = parse_battle_score(evaluation["evaluation"])
+ evaluation_file[idx]["score"] = scores
+
+ if scores[0] == -1 and scores[1] == -1:
+ invalid_count += 1
+ invalid_file.append(evaluation_file[idx])
+ print(f'Invalid score pair: {evaluation_file[idx]["id"]}.')
+ else:
+ if scores[0] > scores[1]:
+ worse_count += 1
+ worse_file.append(evaluation_file[idx])
+ elif scores[0] < scores[1]:
+ better_count += 1
+ better_file.append(evaluation_file[idx])
+ else:
+ tie_count += 1
+ tie_file.append(evaluation_file[idx])
+ ans1_score += scores[0]
+ ans2_score += scores[1]
+
+ prefix = f"{name1}_vs_{name2}"
+
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+
+ jdump(better_file, os.path.join(save_path, prefix, f"{name2}_better.json"))
+ jdump(worse_file, os.path.join(save_path, prefix, f"{name2}_worse.json"))
+ jdump(tie_file, os.path.join(save_path, prefix, f"{prefix}_tie.json"))
+ jdump(invalid_file, os.path.join(save_path, prefix, f"{prefix}_invalid.json"))
+ jdump(evaluation_file, os.path.join(save_path, prefix, f"{prefix}_evaluations.json"))
+
+ if os.path.exists(os.path.join(save_path, "battle_results.json")):
+ results = jload(os.path.join(save_path, "battle_results.json"))
+ else:
+ results = {}
+
+ results[prefix] = {
+ "model": [name1, name2],
+ "better": better_count,
+ "worse": worse_count,
+ "tie": tie_count,
+ "win_rate": better_count / (len(evaluations) - invalid_count),
+ "score": [
+ ans1_score / (len(evaluations) - invalid_count),
+ ans2_score / (len(evaluations) - invalid_count),
+ ],
+ }
+ jdump(results, os.path.join(save_path, "battle_results.json"))
+
+ print(f"Total {invalid_count} invalid score pair(s).")
+ print(f"Model {name2} has {better_count} better answer(s).")
+ print(f"Model {name2} has {worse_count} worse answer(s).")
+ print(f"{tie_count} answer(s) play(s) to a tie.")
+ print(f"Win rate of model {name2}: {better_count/(len(evaluations)-invalid_count):.2f}")
+ print(f"Model {name1} average score: {ans1_score/(len(evaluations)-invalid_count):.2f}")
+ print(f"Model {name2} average score: {ans2_score/(len(evaluations)-invalid_count):.2f}")
+
+
+def reference_template(metric: str, language: str, reference: Dict[str, Any]) -> str:
+ """
+ Get prompt template for GPT evaluation with reference.
+
+ Different languages have different prompt templates.
+
+ Args:
+ metric: metric used in GPT evaluation with reference.
+ language: language for the template.
+        reference: the instruction that contains the target answer.
+
+ Returns:
+ Prompt template for GPT evaluation with reference.
+ """
+
+ step_to_add = ref_step_template[language]
+
+ for_the_given_answer = (
+ "{metric} (1-5) (directly give the score for the given answer):"
+ if language == "en"
+ else "{metric} (1-5) (直接对给定答案打分)"
+ )
+
+ # adjective is used to describe the word "answer" in the prompt.
+ adjective = "example" if language == "en" else "示例"
+ answer_to_add = ref_answer_template_general[language]
+
+    # Only for correctness do we provide a correct answer, so the adjective for "answer" becomes "correct" and the prompt reads "a correct answer".
+    # In other cases, the prompt words are "an example answer with good quality" by default.
+ if metric.lower() == "correctness":
+ adjective = "correct" if language == "en" else "标准"
+ answer_to_add = ref_answer_template_correctness[language]
+
+ answer_to_add = answer_to_add.format(answer=reference["target"] if reference["target"] else reference["output"])
+ step_to_add = step_to_add.format(metric=metric.lower(), adjective=adjective) + for_the_given_answer.format(
+ metric=metric
+ )
+
+ return answer_to_add + step_to_add
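+
+
+# For example, assuming metric="correctness", language="en" and a reference with
+# a non-empty "target", the returned second-turn message is roughly:
+#
+#   "\nA correct answer is as follows:\n\n{target}\n\n"
+#   "Now please compare the answer with the correct answer, determine whether "
+#   "the answer is able to achieve the same level of correctness.\n\n"
+#   "correctness (1-5) (directly give the score for the given answer):"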
+
+
+def fill_in_message(role: str, content: str) -> Dict[str, str]:
+ """
+ Generate one formatted message to send through chat completion.
+
+ Args:
+ role: the role of the author of this message.
+ content: the contents of the message.
+
+ Returns:
+ One message to send through chat completion.
+ """
+
+ return {"role": role, "content": content}
+
+
+def multiturn_chat_completion(user_messages: List[str], model: str, max_tokens: int = 1, turns: int = 2) -> Dict[str, Any]:
+ """
+ Do multi-turn chat completion.
+
+ When turns == 1, it is a one-turn conversation for normal GPT evaluation.
+ When turns == 2, it is a two-turn conversation which is used for GPT evaluation with reference answers.
+
+ Args:
+ user_messages: messages user wants to send.
+ model: the model used to evaluate answers.
+ max_tokens: the maximum number of tokens to generate in the chat completion.
+ turns: the number of turns for conversation.
+
+ Returns:
+ Last turn's response.
+ """
+
+ if len(user_messages) != turns:
+ raise Exception("The length of user messages should be equal to the turn number!")
+
+ assistant_responses = []
+
+ for i in range(turns):
+ messages_to_send = []
+
+ for j in range(i):
+ messages_to_send.append(fill_in_message("user", user_messages[j]))
+ messages_to_send.append(
+ fill_in_message("assistant", assistant_responses[j]["choices"][0]["message"]["content"])
+ )
+
+        # Length of user messages == length of assistant messages + 1,
+        # because we always expect the API to respond.
+ messages_to_send.append(fill_in_message("user", user_messages[i]))
+
+ response = openai.ChatCompletion.create(
+ model=model,
+ messages=messages_to_send,
+ temperature=0,
+ max_tokens=max_tokens,
+ )
+
+ # Avoid exceeding rate limits.
+ # You can comment this line if your request doesn't contain many tokens.
+ time.sleep(1)
+
+ assistant_responses.append(response)
+
+ return assistant_responses[-1]
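+
+
+# Usage sketch (hypothetical prompt variables): for reference-based evaluation,
+# the first turn sends the rating prompt and the second turn sends the
+# reference answer plus a re-rating request; only the final response is
+# returned.
+#
+#   response = multiturn_chat_completion(
+#       [prompt_1st_round, prompt_reference], "gpt-3.5-turbo", max_tokens=1, turns=2
+#   )
+#   content = response["choices"][0]["message"]["content"]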
+
+
+def get_gpt_evaluation_without_logprobs(
+ prompt: Dict[str, Any],
+ inst: Dict[str, Any],
+ metrics: List[str],
+ language: str,
+ reference: Dict[str, Any] = None,
+ model: str = "gpt-3.5-turbo",
+ max_tokens: int = 2048,
+) -> Dict[str, Any]:
+ """
+    Use chat models (gpt-3.5-turbo or gpt-4) to evaluate one model answer.
+
+    Temperature is set to 0 to make the model more deterministic.
+
+ Args:
+ prompt: a dictionary including prompt template, CoT and metrics.
+        inst: the instruction to be evaluated.
+        metrics: the metrics for evaluation.
+        language: language used to adapt the CoT (adds one more step comparing the given answer with the reference) if reference is not None.
+ reference: the reference answer.
+ model: the model used to evaluate answers.
+ max_tokens: the maximum number of tokens to generate in the chat completion.
+
+ Returns:
+ An evaluation of one answer.
+ """
+
+ MAX_API_RETRY = 3
+
+ question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]
+ answer = inst["output"]
+ inst["evaluation"] = {}
+
+ for metric in metrics:
+ if prompt["metrics"].get(metric, None) is None:
+ raise Exception(
+ f"Unsupported metric {metric} for category {inst['category']}! You should add this metric in the prompt file!"
+ )
+ for i in range(MAX_API_RETRY):
+ try:
+ prompt_reference = "" if reference is None else reference_template(metric, language, reference)
+
+ prompt_1st_round = prompt["prompt"].format(
+ question=question,
+ answer=answer,
+ metric=prompt["metrics"][metric],
+ steps=prompt["CoT"][metric],
+ )
+
+ if prompt_reference and (reference["target"] or reference["output"]):
+ # Do a 2-round conversation
+ response = multiturn_chat_completion(
+ [prompt_1st_round, prompt_reference], model, max_tokens=max_tokens, turns=2
+ )
+ else:
+ response = multiturn_chat_completion([prompt_1st_round], model, max_tokens=max_tokens, turns=1)
+
+ inst["evaluation"][metric] = {
+ "response": response["choices"][0]["message"]["content"],
+ "logprobs": None,
+ }
+
+ # Prevent exceeding rate limits because we have multiple workers.
+ # But this will slow down the evaluation process.
+ # You can comment this line if your request doesn't contain many tokens.
+ time.sleep(len(metrics) * 0.5)
+
+ break
+ except Exception as e:
+ print(e)
+ time.sleep(1)
+ if metric not in inst["evaluation"]:
+ print(f"Evaluation {inst['id']} for metric {metric} failed after {MAX_API_RETRY} retries.")
+ inst["evaluation"][metric] = {}
+ return inst
+
+
+def get_gpt_evaluation_with_logprobs(
+ prompt: Dict[str, Any], inst: Dict[str, Any], metrics: List[str], max_tokens: int = 2048
+) -> Dict[str, Any]:
+ """
+    Use a completion model (text-davinci-003) to evaluate one model answer.
+    Only completion models can return log probabilities.
+
+    Temperature is set to 0 to make the model more deterministic.
+
+ Args:
+ prompt: a dictionary including prompt template, CoT and metrics.
+        inst: the instruction to be evaluated.
+ metrics: the metrics for evaluation.
+ max_tokens: the maximum number of tokens to generate in the completion.
+
+ Returns:
+ An evaluation of one answer.
+ """
+
+ MAX_API_RETRY = 3
+
+ question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]
+ answer = inst["output"]
+ inst["evaluation"] = {}
+
+ for metric in metrics:
+ if prompt["metrics"].get(metric, None) is None:
+ raise Exception(
+ f"Unsupported metric {metric} for category {inst['category']}! You should add this metric in the prompt file!"
+ )
+ for i in range(MAX_API_RETRY):
+ try:
+ response = openai.Completion.create(
+ model="text-davinci-003",
+ prompt=prompt["prompt"].format(
+ question=question,
+ answer=answer,
+ metric=prompt["metrics"][metric],
+ steps=prompt["CoT"][metric],
+ ),
+ logprobs=5,
+ temperature=0,
+ max_tokens=max_tokens,
+ )
+ inst["evaluation"][metric] = {
+ "response": response["choices"][0]["text"],
+ "logprobs": response["choices"][0]["logprobs"]["top_logprobs"],
+ }
+
+ # Prevent exceeding rate limits because we have multiple workers.
+ # But this will slow down the evaluation process.
+ # You can comment this line if your request doesn't contain many tokens.
+ time.sleep(len(metrics) * 0.5)
+
+ break
+ except Exception as e:
+ print(e)
+ time.sleep(1)
+ if metric not in inst["evaluation"]:
+ print(f"Evaluation {inst['id']} for metric {metric} failed after {MAX_API_RETRY} retries.")
+ inst["evaluation"][metric] = {}
+ return inst
+
+
+def evaluate(
+ answers: List[Dict],
+ prompt: Dict[str, Any],
+ metrics: List[str],
+ category: str,
+ save_path: str,
+ model_name: str,
+ model: str,
+ language: str,
+ references: List[Dict] = None,
+) -> List[Dict]:
+ """
+ Use GPT models to evaluate model answers and save evaluation results.
+
+ Args:
+        answers: model answers.
+        prompt: prompt for GPT evaluation.
+        metrics: metrics for GPT evaluation.
+        category: the category of the model answers for evaluation.
+        save_path: path to save evaluation results.
+        model_name: name of the model that generated the answers.
+        model: the specific GPT model used to evaluate answers.
+        language: language used in GPT evaluation.
+        references: references for GPT evaluation.
+
+ Returns:
+ Evaluations of the given answers.
+ """
+
+ print(f"The number of instances of category {category}'s is {len(answers)}.")
+
+ evaluations = []
+
+    metrics_str = ", ".join(metrics)
+ print(f"Category {category}'s metrics are {metrics_str}.")
+
+ gpt_base_save_path = os.path.join(save_path, "gpt_evaluate", "gpt_evaluate_results")
+ gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results")
+ category_file = os.path.join(gpt_evaluation_results_save_path, model_name, f"{category}_evaluation_results.json")
+
+ if os.path.exists(category_file):
+ print(f"Evaluation results for category {category}, model {model_name} already exists.")
+ print("Skip evaluating.")
+
+ evaluations = jload(category_file)
+
+ retry = []
+ evaluations_copy = deepcopy(evaluations)
+
+ success = []
+        for e in evaluations_copy:
+ keys = list(e["evaluation"].keys())
+ for key in keys:
+ if e["evaluation"][key] == {}:
+ retry.append(e["id"])
+ print(f"Re-evaluate id {e['id']} now.")
+ break
+ if e["id"] not in retry:
+ success.append(e)
+
+ if len(retry) == 0:
+ evaluations.sort(key=lambda x: x["id"])
+ print(f"{category} done.")
+ return evaluations
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+ futures = []
+ for idx, inst in enumerate(answers):
+                if inst["id"] not in retry:
+ continue
+ # Completion models can return log probabilities.
+ if model == "text-davinci-003":
+ future = executor.submit(get_gpt_evaluation_with_logprobs, prompt, inst, metrics, 1)
+ else:
+ future = executor.submit(
+ get_gpt_evaluation_without_logprobs,
+ prompt,
+ inst,
+ metrics,
+ language,
+ reference=None if references is None else references[idx],
+ model=model,
+ max_tokens=1,
+ )
+
+ futures.append(future)
+
+ for future in tqdm.tqdm(
+ concurrent.futures.as_completed(futures),
+ desc=f"{category}: ",
+ total=len(futures),
+ ):
+ success.append(future.result())
+
+ success.sort(key=lambda x: x["id"])
+
+ print(f"Saving evaluation results for category {category}, model {model_name}.")
+
+ jdump(success, category_file)
+
+ return success
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+ futures = []
+ for idx, inst in enumerate(answers):
+ # Completion models can return log probabilities.
+ if model == "text-davinci-003":
+ future = executor.submit(get_gpt_evaluation_with_logprobs, prompt, inst, metrics, 1)
+ else:
+ future = executor.submit(
+ get_gpt_evaluation_without_logprobs,
+ prompt,
+ inst,
+ metrics,
+ language,
+ reference=None if references is None else references[idx],
+ model=model,
+ max_tokens=1,
+ )
+
+ futures.append(future)
+
+ for future in tqdm.tqdm(
+ concurrent.futures.as_completed(futures),
+ desc=f"{category}: ",
+ total=len(futures),
+ ):
+ evaluations.append(future.result())
+
+ evaluations.sort(key=lambda x: x["id"])
+
+ print(f"{category} done.")
+
+ print(f"Saving evaluation results for category {category}, model {model_name}.")
+
+ jdump(evaluations, category_file)
+
+ return evaluations
+
+
+def calculate_scores_from_logprobs(logprobs: Dict[str, Any]) -> float:
+    """
+    Calculate the score according to the log probabilities returned by text-davinci-003.
+
+ Calculation formula:
+        score = sum(score_i * exp(value)) where score_i is the score that corresponds to the key (predicted token) and value is its log probability.
+
+    Ref: https://arxiv.org/abs/2303.16634
+    This paper proposes NLG evaluation methods using text-davinci-003 (log probabilities returned by completion models) and GPT-4 (probabilities obtained by sampling).
+
+ Args:
+ logprobs: logprobs returned by openai.Completion.
+
+ Returns:
+ The score of one answer.
+ """
+
+    # The evaluation prompt asks for a score from 1 to 5, so we track five probabilities.
+ prob = np.zeros(5)
+
+ for key, value in logprobs.items():
+ # Sometimes the key will be one byte of a unicode character which takes the form of "bytes:\\xe7".
+ # It is meaningless and thus we don't calculate probability.
+ if "bytes" in key:
+ continue
+        # results[0] is the score that corresponds to the key (predicted token).
+        # For example, key "5" corresponds to score 5.
+        results = re.findall(r"\d", key)
+        if len(results) == 1:
+            prob[int(results[0]) - 1] += np.exp(value)
+
+ score = np.dot(np.arange(1, 6), prob)
+
+ return score
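+
+
+# Worked example with hypothetical top_logprobs for the first predicted token:
+# {"5": -0.105, "1": -2.303} gives prob ≈ [0.1, 0, 0, 0, 0.9], so
+# score = 1 * 0.1 + 5 * 0.9 ≈ 4.6. The weights are unnormalized probabilities,
+# and keys that are not the digits 1-5 contribute nothing.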
+
+
+def calculate_scores_from_response(response: str, evaluation: Dict[str, Any]) -> int:
+    """
+    Calculate the score from the response returned by gpt-3.5-turbo or gpt-4.
+    Different from text-davinci-003, this function directly parses the score from the plain response returned by gpt-3.5-turbo or gpt-4.
+    Although text-davinci-003 can return log probabilities, it costs ten times as much as gpt-3.5-turbo.
+
+    Args:
+        response: the plain response returned by gpt-3.5-turbo or gpt-4.
+ evaluation: the evaluation corresponds to the question.
+
+ Returns:
+ The score of one answer.
+ """
+
+ try:
+ results = re.findall(r"\d", response)
+ if len(results) == 1:
+ return int(results[0])
+ else:
+ raise Exception(f"Invalid score pair. Got {evaluation}.")
+ except Exception:
+ return 0
+
+
+def save_gpt_evaluation_results(
+ model_name: str, gpt_evaluation_results: Dict[str, Any], save_path: str
+) -> Dict[str, Any]:
+ """
+ Save evaluation results for different categories for one model.
+
+ Args:
+ model_name: name of the model for saving evaluation results.
+        gpt_evaluation_results: evaluation results for all of the model answers.
+        save_path: path to save GPT evaluation results.
+
+    Returns:
+        All evaluation results flattened into a single list.
+    """
+
+ all_evaluations = []
+ for category, evaluations in gpt_evaluation_results.items():
+ jdump(evaluations, os.path.join(save_path, model_name, f"{category}_evaluation_results.json"))
+ all_evaluations.extend(evaluations)
+
+ jdump(all_evaluations, os.path.join(save_path, f"{model_name}_evaluation_results.json"))
+
+ return all_evaluations
+
+
+def save_gpt_evaluation_statistics(model_name: str, evaluations: List[Dict], save_path: str) -> None:
+ """
+ Generate statistics for one model.
+
+ Args:
+ model_name: name of the model for saving statistics.
+ evaluations: evaluations for all of the model answers.
+ save_path: path to save GPT evaluation statistics.
+ """
+
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+
+ data_per_category = {}
+ for evaluation in evaluations:
+ category = evaluation["category"]
+ if evaluation["category"] in data_per_category.keys():
+ data_per_category[category].append(evaluation)
+ else:
+ data_per_category[category] = [evaluation]
+
+ all_statistics = {}
+ for category, data in data_per_category.items():
+ metrics = data[0]["evaluation"].keys()
+ scores = {metric: [] for metric in metrics}
+ for evaluation in data:
+ for metric in metrics:
+ if evaluation["evaluation"][metric] == {}:
+ # This means after 3 retries, the server still returns an error and we set the score to 0.
+ scores[metric].append(0)
+ elif evaluation["evaluation"][metric]["logprobs"] is not None:
+ scores[metric].append(
+                        calculate_scores_from_logprobs(evaluation["evaluation"][metric]["logprobs"][0])
+ )
+ else:
+ scores[metric].append(
+                        calculate_scores_from_response(evaluation["evaluation"][metric]["response"], evaluation)
+ )
+
+ statistics = {}
+ for metric in metrics:
+ arg_sort = np.argsort(scores[metric])
+ statistics[metric] = {}
+ statistics[metric]["avg_score"] = sum(scores[metric]) / len(data)
+ statistics[metric]["best_3"] = {data[i]["id"]: scores[metric][i] for i in arg_sort[-3:][::-1]}
+ statistics[metric]["worst_3"] = {data[i]["id"]: scores[metric][i] for i in arg_sort[:3]}
+
+ all_statistics[category] = statistics
+
+ jdump(
+ all_statistics,
+ os.path.join(save_path, f"{model_name}_evaluation_statistics.json"),
+ )
+
+
+def analyze_gpt_evaluation_statistics(statistics_path: str, save_path: str) -> None:
+ """
+ Analyze and visualize all GPT evaluation statistics in the given directory.
+
+ Args:
+ statistics_path: path to all the models' statistics.
+ save_path: path to save table and visualization results.
+ """
+
+ if not os.path.exists(statistics_path):
+ raise Exception(f'The given directory "{statistics_path}" doesn\'t exist! No statistics found!')
+
+ all_statistics = {}
+
+ for file_name in os.listdir(statistics_path):
+ if file_name.endswith("_evaluation_statistics.json"):
+ model_name = file_name.split("_evaluation_statistics.json")[0]
+ all_statistics[model_name] = jload(os.path.join(statistics_path, file_name))
+
+    if len(all_statistics) == 0:
+ raise Exception(f'There are no statistics in the given directory "{statistics_path}"!')
+
+ frame_all = {
+ "model": [],
+ "category": [],
+ "metric": [],
+ "avg_score": [],
+ "best_3": [],
+ "worst_3": [],
+ }
+ frame_per_category = {}
+ for model_name, model_statistics in all_statistics.items():
+ for category, category_statistics in model_statistics.items():
+ if frame_per_category.get(category) is None:
+ frame_per_category[category] = {
+ "model": [],
+ "metric": [],
+ "avg_score": [],
+ "best_3": [],
+ "worst_3": [],
+ }
+
+ for metric, metric_statistics in category_statistics.items():
+ frame_all["model"].append(model_name)
+ frame_all["category"].append(category)
+ frame_all["metric"].append(metric)
+ frame_all["avg_score"].append(metric_statistics["avg_score"])
+ frame_all["best_3"].append(metric_statistics["best_3"])
+ frame_all["worst_3"].append(metric_statistics["worst_3"])
+
+ frame_per_category[category]["model"].append(model_name)
+ frame_per_category[category]["metric"].append(metric)
+ frame_per_category[category]["avg_score"].append(metric_statistics["avg_score"])
+ frame_per_category[category]["best_3"].append(metric_statistics["best_3"])
+ frame_per_category[category]["worst_3"].append(metric_statistics["worst_3"])
+
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+
+ frame_all = pd.DataFrame(frame_all)
+ frame_all.to_csv(os.path.join(save_path, "gpt_evaluation_statistics.csv"))
+
+ for category in tqdm.tqdm(
+ frame_per_category.keys(),
+ desc=f"GPT evaluation: ",
+ total=len(frame_per_category.keys()),
+ ):
+ data = pd.DataFrame(frame_per_category[category])
+
+ sns.set()
+        plt.figure(figsize=(16, 10))
+        plt.ylim((0, 5))
+
+        ax = sns.barplot(x="metric", y="avg_score", hue="model", data=data, dodge=True)
+        ax.set_title(f"Comparison between Different Models for Category {category.title()}")
+        plt.xlabel("Evaluation Metric")
+        plt.ylabel("Average Score")
+
+        figure = ax.get_figure()
+ figure.savefig(os.path.join(save_path, f"{category}.png"), dpi=400)
+
+ plt.close()
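+
+
+# The resulting layout under save_path is:
+#   gpt_evaluation_statistics.csv   # one row per (model, category, metric)
+#   <category>.png                  # grouped bar chart of avg_score per model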
diff --git a/applications/ColossalEval/colossal_eval/evaluate/utils.py b/applications/ColossalEval/colossal_eval/evaluate/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..69fec46705ab63db75e919b38bdfb334817984b4
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/evaluate/utils.py
@@ -0,0 +1,8 @@
+def get_data_per_category(data, categories):
+ data_per_category = {category: [] for category in categories}
+ for item in data:
+ category = item["category"]
+ if category in categories:
+ data_per_category[category].append(item)
+
+ return data_per_category
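+
+
+# Minimal usage sketch with hypothetical data:
+#   data = [{"category": "brainstorming", "id": 1}, {"category": "chat", "id": 2}]
+#   get_data_per_category(data, ["brainstorming"])
+#   # -> {"brainstorming": [{"category": "brainstorming", "id": 1}]}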
diff --git a/applications/ColossalEval/colossal_eval/models/__init__.py b/applications/ColossalEval/colossal_eval/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f6c9b414145d3d78f1e684219b5de8c4b8f2b16
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/models/__init__.py
@@ -0,0 +1,5 @@
+from .base import BaseModel
+from .chatglm import ChatGLM2Model, ChatGLMModel
+from .huggingface import HuggingFaceCausalLM, HuggingFaceModel
+
+__all__ = ["BaseModel", "HuggingFaceModel", "HuggingFaceCausalLM", "ChatGLMModel", "ChatGLM2Model"]
diff --git a/applications/ColossalEval/colossal_eval/models/base.py b/applications/ColossalEval/colossal_eval/models/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..aae796c1d56e2b006ff3cd389be052484ce597c9
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/models/base.py
@@ -0,0 +1,78 @@
+from abc import abstractmethod
+from typing import Dict, List
+
+from colossal_eval.utils import Conversation, prompt_templates
+
+from colossalai.logging import DistributedLogger
+
+
+class BaseModel:
+ """
+ Base class for model wrapper.
+
+ Args:
+ path: The path to the model.
+ model_max_length: The maximum sequence length of the model.
+ prompt_template: The model's prompt template.
+ batch_size: Batch size for inference.
+ logger: Logger for the model.
+ """
+
+ def __init__(
+ self,
+ path: str,
+ model_max_length: int = 2048,
+ prompt_template: Conversation = None,
+ batch_size: int = 1,
+ logger: DistributedLogger = None,
+ ):
+ self.path = path
+ self.model_max_length = model_max_length
+
+ if prompt_template:
+ self.prompt_template = prompt_template
+ else:
+ self.prompt_template = prompt_templates["plain"]
+
+ self.batch_size = batch_size
+ self.logger = logger
+
+    @abstractmethod
+ def inference(self, data: List[Dict]) -> None:
+ """
+ Infer the given data.
+ This function will call self.generate() to get model outputs and also self.model(input) to get logits.
+
+ Args:
+ data: The data for inference.
+ """
+
+    @abstractmethod
+ def generate(self, inputs: List[str], max_new_tokens: int) -> List[str]:
+ """
+ Generate results given a list of inputs.
+
+ Args:
+ inputs: A list of strings.
+ max_new_tokens: The maximum length of the output.
+
+ Returns:
+ A list of generated strings.
+ """
+
+    @abstractmethod
+ def get_loss(self, batch: List[str], batch_target: List[str]) -> List[float]:
+ """
+ Get loss given batch and batch with target.
+ Use their length difference after tokenization to mask the loss and only compute loss at target tokens.
+
+ Args:
+ batch: batch prompt without target answer.
+ batch_target: batch prompt with target answer.
+
+ Returns:
+ A list of loss.
+ """
+
+ def to(self, device):
+ self.model.to(device)
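+
+
+# Subclasses are expected to implement the three abstract methods above.
+# A minimal (hypothetical) skeleton:
+#
+#   class MyModel(BaseModel):
+#       def inference(self, data: List[Dict]) -> None: ...
+#       def generate(self, inputs: List[str], max_new_tokens: int) -> List[str]: ...
+#       def get_loss(self, batch: List[str], batch_target: List[str]) -> List[float]: ...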
diff --git a/applications/ColossalEval/colossal_eval/models/chatglm.py b/applications/ColossalEval/colossal_eval/models/chatglm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f293c4f699cda91f98dbd3cc9ddd7821cf8c2753
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/models/chatglm.py
@@ -0,0 +1,303 @@
+import copy
+from typing import List, Tuple
+
+import torch
+
+from .huggingface import HuggingFaceModel
+
+IGNORE_INDEX = -100
+
+
+class ChatGLMModel(HuggingFaceModel):
+ def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]:
+ truncated_inputs = copy.deepcopy(inputs)
+ # Adapted from https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py#L187
+ for i, input in enumerate(inputs):
+ a_ids = self.tokenizer.encode(text=input, truncation=False, add_special_tokens=False)
+
+ if len(a_ids) > self.model_max_length - max_new_tokens:
+ half = (self.model_max_length - max_new_tokens) // 2
+ prompt = self.tokenizer.decode(a_ids[:half], skip_special_tokens=True) + self.tokenizer.decode(
+ a_ids[-half:], skip_special_tokens=True
+ )
+ truncated_inputs[i] = prompt
+
+ return truncated_inputs
+
+ @torch.no_grad()
+    def get_loss(
+        self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False
+    ) -> Tuple[List[List[float]], List[List[int]], None]:
+        """
+        Calculate loss only on target tokens.
+
+        Args:
+            batch_prompt: A batch of prompts without target answers.
+            batch_target: A batch of target answers. Sometimes one question can have multiple target answers.
+            pretrain: Whether the data comes from a pretrain dataset.
+
+        Returns:
+            A tuple of (losses per sample, target token numbers per sample, None).
+
+        """
+
+ # We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss.
+ # We don't need to generate new tokens.
+        # Target answer's length is usually << model_max_length, but we still call it just in case.
+ # We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens.
+ batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]
+
+ # Get the number of target answers for different questions
+ batch_target_nums = [len(prompt_target) for prompt_target in batch_target]
+
+ labels_list = []
+ input_ids_list = []
+
+ for input, targets in zip(batch_prompt, batch_target):
+ for target in targets:
+ # Adapted from https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py#L187
+ # If there is no history, the prompt is just the query.
+ # We don't need to override self.generate() in ChatGLM-6B but need to override it in ChatGLM2-6B.
+ # See https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py#L1276
+ target_tokenized = self.tokenizer.encode(text=target, add_special_tokens=False)
+
+ # Get prompt with length model_max_length - len(target_tokenized).
+ # Reserve some space for target answer tokens using max_new_tokens.
+ # This will generate the correct start_idx and end_idx.
+ max_new_tokens = len(target_tokenized)
+
+ # Here 3 tokens are reserved for [gmask_id, bos_token, eos_id]. So we reserve max_new_tokens + 3 tokens.
+ # See https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py#L323
+ prompt_with_correct_length = self._get_truncated_prompts([input], max_new_tokens + 3)[0]
+ input_tokenized = self.tokenizer.encode(prompt_with_correct_length, add_special_tokens=False)
+
+ input_ids = self.tokenizer.build_inputs_with_special_tokens(input_tokenized, target_tokenized)
+
+ target_ids = [IGNORE_INDEX] * len(input_ids)
+
+ # -1 is for eos_token, we don't want to calculate loss on eos token.
+ target_ids[-max_new_tokens - 1 : -1] = input_ids[-max_new_tokens - 1 : -1]
+
+ input_ids_list.append(torch.LongTensor(input_ids))
+ labels_list.append(torch.LongTensor(target_ids))
+
+ # Because of multiple target answers, the final batch size may be greater than self.batch_size.
+ # We will generate new batches.
+ losses = []
+ target_token_nums = []
+
+ batched_input_ids = [
+ input_ids_list[i : i + self.batch_size] for i in range(0, len(input_ids_list), self.batch_size)
+ ]
+ batched_labels = [labels_list[i : i + self.batch_size] for i in range(0, len(labels_list), self.batch_size)]
+
+ for batch_input_ids, batch_labels in zip(batched_input_ids, batched_labels):
+ losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_input_ids, batch_labels)
+ losses.extend(losses_per_batch)
+ target_token_nums.extend(target_token_num_per_batch)
+
+        start_index = 0
+        losses_per_sample = []
+        target_token_nums_per_sample = []
+        for length in batch_target_nums:
+            losses_per_sample.append(losses[start_index : start_index + length])
+            target_token_nums_per_sample.append(target_token_nums[start_index : start_index + length])
+            start_index += length
+
+ return losses_per_sample, target_token_nums_per_sample, None
+
+    def _calculate_loss(
+        self, input_ids_list: List[torch.LongTensor], labels: List[torch.LongTensor]
+    ) -> Tuple[List[float], List[int]]:
+ """
+ Calculate loss only on target tokens.
+ Hugging Face generate() function can't return per sample loss.
+ It will only return the mean of the loss in a batch.
+ In torch.nn.CrossEntropyLoss(), reduction should be specified as "none" to get per sample loss.
+
+ Args:
+ input_ids_list: A batch of input token ids.
+ labels: A batch of labels.
+
+ Returns:
+            A tuple of (per-sample summed loss, per-sample target token counts).
+
+ """
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
+ ).to(torch.cuda.current_device())
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
+ torch.cuda.current_device()
+ )
+
+ outputs = self.model(input_ids)[0]
+
+ shift_logits = outputs[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+
+ loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=IGNORE_INDEX)
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size())
+
+ lens = (labels != IGNORE_INDEX).sum(-1).cpu().numpy()
+
+ loss_sum = loss.sum(-1).to(torch.float32).cpu().detach().numpy()
+ return loss_sum.tolist(), lens.tolist()
+
+
+class ChatGLM2Model(ChatGLMModel):
+ def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]:
+ truncated_inputs = copy.deepcopy(inputs)
+ # Adapted from https://github.com/THUDM/ChatGLM2-6B/blob/main/ptuning/main.py#L180
+ for i, input in enumerate(inputs):
+ a_ids = self.tokenizer.encode(text=input, add_special_tokens=True, truncation=False)
+
+ if len(a_ids) > self.model_max_length - max_new_tokens:
+ half = (self.model_max_length - max_new_tokens) // 2
+ prompt = self.tokenizer.decode(a_ids[:half], skip_special_tokens=True) + self.tokenizer.decode(
+ a_ids[-half:], skip_special_tokens=True
+ )
+ truncated_inputs[i] = prompt
+
+ return truncated_inputs
+
+ @torch.no_grad()
+ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str]:
+ """Generate results given a list of inputs and get logits of the first new token over choices.
+
+ Args:
+ inputs: A list of strings.
+ max_new_tokens: Max new tokens for generation.
+            kwargs: Keyword arguments for generation.
+
+ Returns:
+ A list of generated strings and logits over choices.
+
+ Note:
+ Currently the function only returns the logits of the first new token.
+            It is used for single-choice questions.
+            For multiple-choice questions, please avoid using the loss over choices.
+            You should set the argument choices to None in self.inference().
+
+ """
+ # Follow the process of model.chat() method in modeling_chatglm2.py
+ # See https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1020
+ # See https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1001
+
+ query = []
+ for input in inputs:
+ prompt = self.tokenizer.build_prompt(input, None)
+ query.append(prompt)
+
+ truncated_query = self._get_truncated_prompts(query, max_new_tokens)
+
+ encoded_inputs = self.tokenizer(
+ truncated_query,
+ padding=True,
+ truncation=True,
+ return_tensors="pt",
+ max_length=self.model_max_length - max_new_tokens,
+ ).to(torch.cuda.current_device())
+
+ # Set output_scores=True to get prediction scores.
+ outputs = self.model.generate(
+ **encoded_inputs, max_new_tokens=max_new_tokens, return_dict_in_generate=True, output_scores=True, **kwargs
+ )
+
+ # We only need to decode predicted tokens.
+ sequences = outputs.sequences[:, encoded_inputs["input_ids"].shape[1] :]
+
+ scores = []
+ if self.indices_for_choices:
+            # If the question is a single-choice question, we return the scores at specific indices of the first predicted token.
+            # The indices are the tokenization results of the options for the single-choice question.
+            # For example, if the options of the question are A, B, C and D, we only return the scores at the indices of A, B, C and D.
+ for option_indices in self.indices_for_choices:
+ scores.append(outputs.scores[0][:, option_indices].detach().cpu())
+
+ scores = torch.max(torch.stack(scores), dim=0)[0]
+
+ decoded_sequences = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
+
+ return decoded_sequences, scores
+
+ @torch.no_grad()
+    def get_loss(
+        self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False
+    ) -> Tuple[List[List[float]], List[List[int]], None]:
+        """
+        Calculate loss only on target tokens.
+
+        Args:
+            batch_prompt: A batch of prompts without target answers.
+            batch_target: A batch of target answers. Sometimes one question can have multiple target answers.
+            pretrain: Whether the data comes from a pretrain dataset.
+
+        Returns:
+            A tuple of (losses per sample, target token numbers per sample, None).
+
+        """
+
+ # We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss.
+ # We don't need to generate new tokens.
+        # Target answer's length is usually << model_max_length, but we still call it just in case.
+ # We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens.
+ batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]
+
+ # Get the number of target answers for different questions
+ batch_target_nums = [len(prompt_target) for prompt_target in batch_target]
+
+ labels_list = []
+ input_ids_list = []
+
+ for input, targets in zip(batch_prompt, batch_target):
+ for target in targets:
+ # Adapted from https://github.com/THUDM/ChatGLM2-6B/blob/main/ptuning/main.py#L180
+ prompt = self.tokenizer.build_prompt(input, None)
+
+ target_tokenized = self.tokenizer.encode(
+ text=target, add_special_tokens=False, truncation=True, max_length=self.model_max_length
+ )
+
+ max_new_tokens = len(target_tokenized)
+ prompt_with_correct_length = self._get_truncated_prompts([prompt], max_new_tokens)[0]
+ input_tokenized = self.tokenizer.encode(
+ prompt_with_correct_length,
+ add_special_tokens=True,
+ truncation=True,
+ max_length=self.model_max_length,
+ )
+
+ input_ids = input_tokenized + target_tokenized + [self.tokenizer.eos_token_id]
+ target_ids = [IGNORE_INDEX] * len(input_ids)
+
+ # -1 is for "eos"
+ target_ids[-max_new_tokens - 1 : -1] = input_ids[-max_new_tokens - 1 : -1]
+
+ input_ids_list.append(torch.LongTensor(input_ids))
+ labels_list.append(torch.LongTensor(target_ids))
+
+ # Because of multiple target answers, the final batch size may be greater than self.batch_size.
+ # We will generate new batches.
+ losses = []
+ target_token_nums = []
+
+ batched_input_ids = [
+ input_ids_list[i : i + self.batch_size] for i in range(0, len(input_ids_list), self.batch_size)
+ ]
+ batched_labels = [labels_list[i : i + self.batch_size] for i in range(0, len(labels_list), self.batch_size)]
+
+ for batch_input_ids, batch_labels in zip(batched_input_ids, batched_labels):
+ losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_input_ids, batch_labels)
+ losses.extend(losses_per_batch)
+ target_token_nums.extend(target_token_num_per_batch)
+
+        start_index = 0
+        losses_per_sample = []
+        target_token_nums_per_sample = []
+        for length in batch_target_nums:
+            losses_per_sample.append(losses[start_index : start_index + length])
+            target_token_nums_per_sample.append(target_token_nums[start_index : start_index + length])
+            start_index += length
+
+ return losses_per_sample, target_token_nums_per_sample, None
diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f785a6aa9d1d1eb9b7986780df1ac4ca1ae7f4c
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/models/huggingface.py
@@ -0,0 +1,561 @@
+import copy
+import math
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0
+from peft import PeftModel
+from tqdm import tqdm
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
+
+from colossalai.logging import DistributedLogger
+
+from .base import BaseModel
+
+IGNORE_INDEX = -100
+
+
+class HuggingFaceModel(BaseModel):
+ """
+ Model wrapper around HuggingFace AutoModel models.
+
+ Args:
+ path: The path to a HuggingFace model.
+ model_max_length: The maximum sequence length of the model.
+ tokenizer_path: The path to the tokenizer.
+ tokenizer_kwargs: Keyword arguments for the tokenizer.
+        peft_path: The name or path of a HuggingFace PEFT model.
+ model_kwargs: Keyword arguments for the model.
+ prompt_template: The model's prompt template.
+ batch_size: Batch size for inference.
+ logger: Logger for the model.
+
+ """
+
+ def __init__(
+ self,
+ path: str,
+ model_max_length: int = 2048,
+ tokenizer_path: Optional[str] = None,
+ tokenizer_kwargs: dict = dict(),
+ peft_path: Optional[str] = None,
+ model_kwargs: Dict = None,
+ prompt_template: Conversation = None,
+ batch_size: int = 1,
+ logger: DistributedLogger = None,
+ ):
+ super().__init__(
+ path=path,
+ model_max_length=model_max_length,
+ prompt_template=prompt_template,
+ batch_size=batch_size,
+ logger=logger,
+ )
+ self._load_tokenizer(path=path, tokenizer_path=tokenizer_path, tokenizer_kwargs=tokenizer_kwargs)
+
+ self._load_model(path=path, model_kwargs=model_kwargs, peft_path=peft_path)
+
+ def _get_choices_indices(self, language: str):
+ """
+ Get indices for each choice
+
+ Some tokenizer will insert BOS if you don't specify add_special_tokens=False such as Llama-2.
+ The indices for choices may be different given the context. For example, for Llama-2 tokenizer, for Chinese context like "答案:{choice}", indices for choices A, B, C and D are 29909, 29933, 29907 and 29928, for English context like "Answer: {choice}", indices for choices A, B, C and D are 319, 350, 315 and 360.
+ print(self.tokenizer("答案:A")) to see
+ print(self.tokenizer("Answer: A")) to see
+
+ """
+
+        # A trick to get all token ids related to the given choices.
+ self.indices_for_choices = [[] for _ in range(2)]
+ for choice in self.choices:
+ self.indices_for_choices[0].append(
+ self.tokenizer(f"Answer: {choice}", add_special_tokens=False).input_ids[-1]
+ )
+ self.indices_for_choices[1].append(self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1])
+
+ def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kwargs: dict):
+ """
+ Load tokenizer.
+
+ Args:
+ path: The path to the model. Usually it also serves as the path to the tokenizer.
+            tokenizer_path: The path to the tokenizer.
+ tokenizer_kwargs: Keyword arguments for the tokenizer.
+
+ """
+
+ if self.batch_size > 1:
+ tokenizer_kwargs.update({"padding_side": "left"})
+ tokenizer_kwargs.update({"truncation_side": "left"})
+
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path if tokenizer_path else path, **tokenizer_kwargs)
+
+ if self.tokenizer.pad_token_id is None:
+ self.logger.warning("pad_token_id is not set for the tokenizer. " "Using eos_token_id as pad_token_id.")
+ if self.tokenizer.eos_token:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+            elif getattr(self.tokenizer, "eod_id", None):
+ # Qwen has an eod token "<|endoftext|>".
+ self.tokenizer.pad_token_id = self.tokenizer.eod_id
+
+ def _load_model(self, path: str, model_kwargs: dict, peft_path: Optional[str] = None):
+ """
+ Load model.
+
+ Args:
+ path: The path to the model.
+ model_kwargs: Keyword arguments for the model.
+ peft_path: The path to the peft model.
+
+ """
+
+ if "torch_dtype" in model_kwargs:
+ model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"])
+
+ model_kwargs.setdefault("torch_dtype", torch.float16)
+
+ self.model = AutoModel.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
+ if peft_path is not None:
+ self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
+ self.model.eval()
+
+    def _calculate_loss(
+        self, input_ids_list: List[torch.LongTensor], labels: List[torch.LongTensor]
+    ) -> Tuple[List[float], List[int]]:
+ """
+ Calculate loss only on target tokens.
+ Hugging Face generate() function can't return per sample loss.
+ It will only return the mean of the loss in a batch.
+ In torch.nn.CrossEntropyLoss(), reduction should be specified as "none" to get per sample loss.
+
+ Args:
+ input_ids_list: A batch of input token ids.
+ labels: A batch of labels.
+
+ Returns:
+            A tuple of (per-sample summed loss, per-sample target token counts).
+
+ """
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
+ ).to(torch.cuda.current_device())
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
+ torch.cuda.current_device()
+ )
+ attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(torch.cuda.current_device())
+
+ outputs = self.model(input_ids, attention_mask=attention_mask)[0]
+
+ shift_logits = outputs[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+
+ loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=IGNORE_INDEX)
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size())
+
+ lens = (labels != IGNORE_INDEX).sum(-1).cpu().numpy()
+
+ loss_sum = loss.sum(-1).to(torch.float32).cpu().detach().numpy()
+ return loss_sum.tolist(), lens.tolist()
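+
+    # Shape note for _calculate_loss: for a padded batch of size B and max
+    # length T, `loss` has shape (B, T - 1). With reduction="none", positions
+    # where shift_labels == IGNORE_INDEX contribute 0, so loss.sum(-1) is each
+    # sample's total NLL over its target tokens and `lens` counts those tokens.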
+
+ def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]:
+ """
+        Truncate the input sequence so that it fits model_max_length (we suggest truncating in the middle, since the left and right sides may both contain crucial instructions).
+        Adapted from https://github.com/THUDM/LongBench/blob/main/pred.py#L16
+
+ Args:
+ inputs: A batch of input prompts.
+ max_new_tokens: Max new tokens for model to generate.
+
+ Returns:
+ Truncated prompts.
+
+ """
+
+ truncated_inputs = copy.deepcopy(inputs)
+ for i, input in enumerate(inputs):
+ tokenized_prompt = self.tokenizer(input, truncation=False, return_tensors="pt").input_ids[0]
+ if len(tokenized_prompt) > self.model_max_length - max_new_tokens:
+ half = (self.model_max_length - max_new_tokens) // 2
+ prompt = self.tokenizer.decode(
+ tokenized_prompt[:half], skip_special_tokens=True
+ ) + self.tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
+ truncated_inputs[i] = prompt
+
+ return truncated_inputs
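+
+    # For example, with model_max_length=8 and max_new_tokens=0, a 12-token
+    # prompt becomes decode(tokens[:4]) + decode(tokens[-4:]): the middle four
+    # tokens are dropped while the instructions at both ends survive.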
+
+ def _get_input_ids_and_labels_pretrain(self, batch_prompt: List[str]) -> Tuple[List[torch.LongTensor]]:
+ """
+ Get input_ids and labels for pretrain data.
+        We only need batch_prompt because for pretrain datasets we don't need to predict new tokens.
+
+ Args:
+ batch_prompt: A batch of prompt.
+
+ Returns:
+ Input_ids and labels for the given batch.
+
+ """
+ input_ids_list = []
+ labels_list = []
+ bytes_list = []
+
+ for input in batch_prompt:
+            # Pretrain data tends to be very long, often much longer than model_max_length, so we first tokenize only 1/ratio of the data to speed up tokenization.
+            # Once the length of the result is greater than or equal to model_max_length, we stop iterating over the ratios and use the result as input_ids and labels.
+            # After all, the rest of the original string doesn't need to be tokenized in the first place.
+ ratio = [16, 8, 4, 2, 1]
+ tokenized = None
+ for r in ratio:
+ tokenized = self.tokenizer(
+ [input[0 : len(input) // r]], truncation=True, max_length=self.model_max_length, return_tensors="pt"
+ )
+ if tokenized.input_ids.size(1) >= self.model_max_length:
+ break
+
+ input_ids = copy.deepcopy(tokenized["input_ids"])[0]
+ target_ids = copy.deepcopy(input_ids)
+
+ string = self.tokenizer.decode(tokenized.input_ids[0], skip_special_tokens=True)
+
+ bytes_list.append(len(string.encode("utf-8")))
+
+ input_ids_list.append(input_ids)
+ labels_list.append(target_ids)
+
+ return input_ids_list, labels_list, bytes_list
+
+ def _get_input_ids_and_labels(
+ self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool
+ ) -> Tuple[List[torch.LongTensor]]:
+ """
+ Get input_ids and labels for the given data.
+
+ Args:
+ batch_prompt: A batch of prompt.
+ batch_target: A batch of target.
+
+ Returns:
+ Input_ids and labels for the given batch.
+
+ """
+ if pretrain:
+ return self._get_input_ids_and_labels_pretrain(batch_prompt)
+
+ input_ids_list = []
+ labels_list = []
+
+ for input, targets in zip(batch_prompt, batch_target):
+ for target in targets:
+ # TODO: Improve the labeling process. Should annotate the border by adding special tokens.
+ target_tokenized = self.tokenizer(
+ [target], truncation=True, max_length=self.model_max_length, return_tensors="pt"
+ )
+
+ # Get prompt with length model_max_length - len(target_tokenized).
+ # Reserve some space for target answer tokens using max_new_tokens.
+ # This will generate the correct start_idx and end_idx.
+ max_new_tokens = target_tokenized["input_ids"][0].size(0)
+ prompt_with_correct_length = self._get_truncated_prompts([input], max_new_tokens)[0]
+ input_tokenized = self.tokenizer(
+ [prompt_with_correct_length],
+ truncation=True,
+ max_length=self.model_max_length - max_new_tokens,
+ return_tensors="pt",
+ )
+
+ target_tokenized = self.tokenizer(
+ [prompt_with_correct_length + target],
+ truncation=True,
+ max_length=self.model_max_length,
+ return_tensors="pt",
+ )
+
+ start_idx = input_tokenized["input_ids"][0].size(0)
+ end_idx = target_tokenized["input_ids"][0].size(0)
+
+ # Sometimes if the target is only an option such as A, B, C and D, the length of input_tokenized is equal to the length of target_tokenized, so we need -1.
+ # This is caused by the different behavior of tokenizers.
+ # For example, the tokenizer for Baichuan and Llama will cause such problem in a plain prompt setting.
+ # The length of the tokenized sequences for prompt "Answer: " and "Answer: A" is the same.
+ # Baichuan: [29394, 31143, 31106] [29394, 31143, 703]
+ # Llama: [673, 29901, 29871] [673, 29901, 319]
+ # The length for sequence "prompt" and "prompt + A" is equal.
+ # For ChatGLM, the length of the tokenized sequences is different.
+ # ChatGLM: [16583, 12] [16583, 12, 167]
+
+ if start_idx == end_idx:
+ start_idx -= 1
+
+ input_ids = copy.deepcopy(target_tokenized["input_ids"])[0]
+ target_ids = copy.deepcopy(input_ids)
+
+ mask = torch.zeros_like(target_ids, dtype=torch.bool)
+ mask[start_idx:end_idx] = True
+
+ target_ids[~mask] = IGNORE_INDEX
+
+ input_ids_list.append(input_ids)
+ labels_list.append(target_ids)
+
+ return input_ids_list, labels_list, None
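+
+    # Masking sketch, reusing the Llama example above: "Answer: " tokenizes to
+    # [673, 29901, 29871] and "Answer: A" to [673, 29901, 319], so
+    # start_idx == end_idx == 3 and start_idx is decremented to 2, giving
+    #   input_ids  = [673, 29901, 319]
+    #   target_ids = [-100, -100, 319]  # loss is computed on the target token only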
+
+ def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]:
+ """
+ Infer the given data.
+ This function will call self.generate() to get model outputs and also self.model() to get logits.
+
+ Args:
+ data: The data for inference.
+ inference_kwargs: Arguments for inference.
+ debug: Whether to display generated prompt for debugging.
+
+ Returns:
+ Inference results.
+
+ """
+ calculate_loss = inference_kwargs["calculate_loss"]
+ classes = inference_kwargs["all_classes"]
+ language = inference_kwargs["language"]
+ pretrain = inference_kwargs["pretrain"]
+ max_new_tokens = inference_kwargs["max_new_tokens"]
+ few_shot_data = inference_kwargs.get("few_shot_data", None)
+
+        # Some classification questions' options are full texts rather than single letters such as A, B, C or D.
+ # If the text length is greater than 1, we won't calculate loss over choices.
+ if classes is not None and any(len(c) > 1 for c in classes):
+ classes = None
+
+ self.choices = classes
+ self.indices_for_choices = None
+ if self.choices:
+ # Get indices for each choice
+ self._get_choices_indices(language)
+
+ self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)}
+
+ bar = tqdm(
+ range(math.ceil(len(data) / self.batch_size)),
+ desc=f"{data[0]['dataset']}-{data[0]['category']} Inference steps",
+ disable=not is_rank_0(),
+ )
+ loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
+
+ answers = copy.deepcopy(data)
+ for i in range(0, len(data), self.batch_size):
+ batch = data[i : i + self.batch_size]
+ batch_prompt, batch_target = get_batch_prompt(
+ self.prompt_template, batch, few_shot_data, self.tokenizer, language, self.model_max_length
+ )
+
+ if is_rank_0() and debug and i == 0:
+ self.logger.info(
+ f"Inference arguments for dataset {data[0]['dataset']} category {data[0]['category']} is:\n{inference_kwargs}"
+ )
+ self.logger.info("-" * 120)
+ self.logger.info("An example prompt and prompt with target is:")
+ self.logger.info("-" * 120)
+ self.logger.info(batch_prompt[0])
+ self.logger.info("-" * 120)
+ self.logger.info(batch_prompt[0] + batch_target[0][0])
+
+ if not pretrain:
+ batch_decodes, scores = self.generate(batch_prompt, max_new_tokens)
+
+ if calculate_loss:
+ batch_losses, batch_target_token_nums, batch_bytes_nums = self.get_loss(
+ batch_prompt, batch_target, pretrain
+ )
+
+ probs = []
+ if self.indices_for_choices:
+ scores = scores.to(torch.float32)
+                # If we have indices_for_choices (it must be a single-choice question), there will be only one target answer for one data sample.
+                # Otherwise this would violate the single-choice setting.
+
+ if calculate_loss:
+ labels = [self.str_label_map[answers[i + j]["target"]] for j in range(len(batch_decodes))]
+
+ loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist()
+
+ probs = torch.nn.functional.softmax(scores, dim=-1).numpy().tolist()
+ probs = [
+ {choice: probs[i][self.str_label_map[choice]] for choice in self.choices} for i in range(len(probs))
+ ]
+
+ for j in range(len(batch_prompt)):
+ if not pretrain:
+ answers[i + j]["output"] = batch_decodes[j].strip()
+
+ if isinstance(scores, torch.Tensor):
+ answers[i + j]["softmax_over_choices"] = probs[j]
+
+ if calculate_loss:
+ answers[i + j]["loss_over_choices"] = loss_over_choices[j]
+
+ if calculate_loss:
+ answers[i + j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist()
+
+                    # loss_sum is specifically used for pretrain datasets to calculate per-byte perplexity.
+ # However, loss (which is per sample loss) suffices for most cases.
+ answers[i + j]["loss_sum"] = batch_losses[j]
+ answers[i + j]["token_num"] = batch_target_token_nums[j]
+
+ if batch_bytes_nums:
+ answers[i + j]["byte_num"] = batch_bytes_nums[j]
+
+ bar.update()
+
+ return answers
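+
+    # Depending on the flags computed in inference(), each answer gains
+    # "output" (the decoded text), "softmax_over_choices", "loss_over_choices",
+    # "loss", "loss_sum", "token_num" and, for pretrain data, "byte_num".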
+
+ @torch.no_grad()
+ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str]:
+ """Generate results given a list of inputs and get logits of the first new token over choices.
+
+ Args:
+ inputs: A list of strings.
+ max_new_tokens: Max new tokens for generation.
+            kwargs: Keyword arguments for generation.
+
+ Returns:
+ A list of generated strings and logits over choices.
+
+ Note:
+ Currently the function only returns the logits of the first new token.
+            It is used for single-choice questions.
+            For multiple-choice questions, please avoid using the loss over choices.
+            You should set the argument choices to None in self.inference().
+
+ """
+ truncated_inputs = self._get_truncated_prompts(inputs, max_new_tokens)
+
+ encoded_inputs = self.tokenizer(
+ truncated_inputs,
+ padding=True,
+ truncation=True,
+ return_tensors="pt",
+ return_token_type_ids=False,
+ max_length=self.model_max_length - max_new_tokens,
+ ).to(torch.cuda.current_device())
+
+ # Set output_scores=True to get prediction scores.
+ outputs = self.model.generate(
+ **encoded_inputs, max_new_tokens=max_new_tokens, return_dict_in_generate=True, output_scores=True, **kwargs
+ )
+
+ # We only need to decode predicted tokens.
+ sequences = outputs.sequences[:, encoded_inputs["input_ids"].shape[1] :]
+
+ scores = []
+ if self.indices_for_choices:
+            # If the question is a single-choice question, we return the scores at specific indices of the first predicted token.
+            # The indices are the tokenization results of the options for the single-choice question.
+            # For example, if the options of the question are A, B, C and D, we only return the scores at the indices of A, B, C and D.
+ for option_indices in self.indices_for_choices:
+ scores.append(outputs.scores[0][:, option_indices].detach().cpu())
+
+ scores = torch.max(torch.stack(scores), dim=0)[0]
+
+ decoded_sequences = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
+
+ return decoded_sequences, scores
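+
+    # For a single-choice question with options A-D, `scores` above has shape
+    # (batch_size, 4): outputs.scores[0] holds the logits of the first
+    # generated token, option_indices selects the four option token ids, and
+    # the elementwise max is taken over the English and Chinese contexts.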
+
+ @torch.no_grad()
+    def get_loss(
+        self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool
+    ) -> Tuple[List[List[float]], List[List[int]], Optional[List[List[int]]]]:
+        """
+        Calculate loss only on target tokens.
+
+        Args:
+            batch_prompt: A batch of prompts without target answers.
+            batch_target: A batch of target answers. Sometimes one question can have multiple target answers.
+            pretrain: Whether the data comes from a pretrain dataset.
+
+        Returns:
+            A tuple of (losses per sample, target token numbers per sample, byte numbers per sample or None).
+
+        """
+
+ # We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss.
+ # We don't need to generate new tokens.
+        # A target answer is usually much shorter than model_max_length, but we still truncate it just in case.
+        # We don't call self._get_truncated_prompts on batch_prompt because we first need the target answer's length to reserve space for its tokens.
+ if not pretrain:
+ batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]
+
+ # Get the number of target answers for different questions
+ batch_target_nums = [len(prompt_target) for prompt_target in batch_target]
+
+ input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels(batch_prompt, batch_target, pretrain)
+
+ # Because of multiple target answers, the final batch size may be greater than self.batch_size.
+ # We will generate new batches.
+ losses = []
+ target_token_nums = []
+
+ batched_input_ids = [
+ input_ids_list[i : i + self.batch_size] for i in range(0, len(input_ids_list), self.batch_size)
+ ]
+ batched_labels = [labels_list[i : i + self.batch_size] for i in range(0, len(labels_list), self.batch_size)]
+
+ for batch_input_ids, batch_labels in zip(batched_input_ids, batched_labels):
+ losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_input_ids, batch_labels)
+ losses.extend(losses_per_batch)
+ target_token_nums.extend(target_token_num_per_batch)
+
+        start_index = 0
+        losses_per_sample = []
+
+        target_token_nums_per_sample = []
+        bytes_nums_per_sample = []
+        for length in batch_target_nums:
+            losses_per_sample.append(losses[start_index : start_index + length])
+            target_token_nums_per_sample.append(target_token_nums[start_index : start_index + length])
+
+            if bytes_list:
+                bytes_nums_per_sample.append(bytes_list[start_index : start_index + length])
+
+            start_index += length
+
+ if bytes_list:
+ return losses_per_sample, target_token_nums_per_sample, bytes_nums_per_sample
+
+ return losses_per_sample, target_token_nums_per_sample, None
+
+
+class HuggingFaceCausalLM(HuggingFaceModel):
+ """
+ Model wrapper around HuggingFace AutoModelForCausalLM models.
+
+ Args:
+ path: The path to a HuggingFace model.
+ model_max_length: The maximum sequence length of the model.
+ tokenizer_path: The path to the tokenizer.
+ tokenizer_kwargs: Keyword arguments for the tokenizer.
+ peft_path: The name or path to the HuggingFace's PEFT model.
+ model_kwargs: Keyword arguments for the model.
+ prompt_template: The model's prompt template.
+ batch_size: Batch size for inference.
+ logger: Logger for the model.
+
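+    Example:
+        A minimal construction sketch; the values are placeholders, and the keyword
+        names are assumed to mirror the Args above:
+
+        >>> model = HuggingFaceCausalLM(
+        ...     path="huggyllama/llama-7b",
+        ...     model_max_length=2048,
+        ...     tokenizer_path="huggyllama/llama-7b",
+        ...     tokenizer_kwargs={"padding_side": "left"},
+        ...     model_kwargs={"torch_dtype": "torch.float16"},
+        ...     batch_size=4,
+        ... )
+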
+ """
+
+ def _load_model(self, path: str, model_kwargs: dict, peft_path: Optional[str] = None):
+ """
+ Load model.
+
+ Args:
+ path: The path to the model.
+ model_kwargs: Keyword arguments for the model.
+ peft_path: The path to the peft model.
+
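+        Example:
+            A sketch of the kwargs this method accepts (values are placeholders);
+            string-valued ``torch_dtype`` and ``config`` entries are resolved to
+            objects below:
+
+                model_kwargs = {"torch_dtype": "torch.float16", "config": "huggyllama/llama-7b"}
+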
+ """
+
+ if "torch_dtype" in model_kwargs:
+ model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"])
+
+ if "config" in model_kwargs:
+ model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs["config"])
+
+ model_kwargs.setdefault("torch_dtype", torch.float16)
+ self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
+ if peft_path is not None:
+ self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
+ self.model.eval()
diff --git a/applications/ColossalEval/colossal_eval/utils/__init__.py b/applications/ColossalEval/colossal_eval/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5ee6e13b74732d2563715951f29bc134b8563d2
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/utils/__init__.py
@@ -0,0 +1,4 @@
+from .conversation import Conversation, get_batch_prompt, prompt_templates
+from .utilities import get_json_list, is_rank_0, jdump, jload
+
+__all__ = ["Conversation", "prompt_templates", "get_batch_prompt", "is_rank_0", "jload", "jdump", "get_json_list"]
diff --git a/applications/ColossalEval/colossal_eval/utils/conversation.py b/applications/ColossalEval/colossal_eval/utils/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c096a8523c0ec8df5d8953ff8ba6302c290bb41
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/utils/conversation.py
@@ -0,0 +1,231 @@
+import dataclasses
+from enum import Enum, auto
+from typing import Dict, List, Optional, Tuple
+
+from transformers import AutoTokenizer
+
+
+class SeparatorStyle(Enum):
+ ADD_BOS_EOS_TOKEN = auto()
+ ALPACA = auto()
+ PLAIN = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ sep_style: SeparatorStyle = SeparatorStyle.ADD_BOS_EOS_TOKEN
+ sep: str = ""
+
+ def clear(self):
+ self.messages = []
+
+ def get_prompt(self):
+ if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
+ ret = self.system
+ for role, message in self.messages:
+ if message:
+                    ret += role + ": " + message + self.sep
+                else:
+                    ret += role + ": "
+ return ret
+ elif self.sep_style == SeparatorStyle.ALPACA:
+ ret = self.system + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + ":\n" + message + self.sep
+ else:
+ ret += role + ":"
+ return ret
+ elif self.sep_style == SeparatorStyle.PLAIN:
+ ret = self.system
+ for role, message in self.messages:
+                if message:
+                    ret += message
+ return ret
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ def get_prompt_with_target(self, target):
+ prompt = self.get_prompt()
+ prompt_with_target = []
+
+        # Some datasets provide multiple target answers.
+        # This complicates loss calculation.
+        # We first convert target into List[str] if the question has only one target answer.
+        target_answers = [target] if isinstance(target, str) else target
+
+ for target_answer in target_answers:
+            if self.sep_style in (
+                SeparatorStyle.ADD_BOS_EOS_TOKEN,
+                SeparatorStyle.ALPACA,
+                SeparatorStyle.PLAIN,
+            ):
+                prompt_with_target.append(prompt + target_answer)
+            else:
+                raise ValueError(f"Invalid style: {self.sep_style}")
+
+ return prompt_with_target
+
+ def save_prompt(self):
+ if self.sep_style == SeparatorStyle.ADD_BOS_EOS_TOKEN:
+ ret = self.system
+ for role, message in self.messages:
+ if message:
+                    ret += role + ": " + message + "\n"
+                else:
+                    ret += role + ": "
+ return ret
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ )
+
+ def dict(self):
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep_style": self.sep_style,
+ "sep": self.sep,
+ }
+
+
+def get_few_shot_prefix(
+ conv: Conversation, few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], language: str, max_tokens: int
+) -> str:
+ """
+ Get few shot prefix.
+
+ Args:
+        conv: Conversation template.
+        few_shot_data: Few-shot examples used to build the prompt prefix.
+        tokenizer: Tokenizer used to measure the prefix length in tokens.
+        language: Language of the prefix ("English" or "Chinese").
+        max_tokens: Maximum number of tokens the prefix may occupy.
+
+ Returns:
+ Few shot prompt prefix.
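+
+    Example:
+        A minimal sketch, assuming ``tokenizer`` is a loaded AutoTokenizer and
+        ``conv`` is one of the templates defined in this module:
+
+        >>> prefix = get_few_shot_prefix(conv, ["Q: 1+1=? A: 2"], tokenizer, "English", max_tokens=256)
+        >>> prefix.startswith("The following are answers")
+        True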
+ """
+
+    if language == "English":
+        few_shot_prefix = "The following are answers for questions in an exam.\n\n"
+    elif language == "Chinese":
+        few_shot_prefix = "以下是考试中各个问题的答案。\n\n"
+    else:
+        raise ValueError(f"Unsupported language for few-shot prefix: {language}")
+
+    # Greedily append examples while the tokenized prefix still fits within max_tokens.
+    output = None
+    for example in few_shot_data:
+        few_shot_prefix = few_shot_prefix + example + "\n\n"
+
+        if len(tokenizer([few_shot_prefix]).input_ids[0]) <= max_tokens:
+            output = few_shot_prefix
+        else:
+            break
+
+    # If even the first example overflows, fall back to the (over-long) prefix built so far.
+    return output if output is not None else few_shot_prefix
+
+
+def get_batch_prompt(
+ conv: Conversation,
+ batch: List[Dict],
+ few_shot_data: List[str],
+ tokenizer: Optional[AutoTokenizer],
+ language: Optional[str],
+ model_max_length: Optional[int],
+) -> Tuple[List[Dict], List[Dict]]:
+ """
+ Get batch prompt and target.
+
+ Args:
+ conv: Conversation template.
+ batch: Batch data to generate prompt from.
+        few_shot_data: Few-shot data used to generate the few-shot prompt prefix.
+        tokenizer: Tokenizer used to measure prompt length when truncating the few-shot prefix.
+        language: Language of the few-shot prefix ("English" or "Chinese").
+        model_max_length: Maximum sequence length of the model.
+
+    Returns:
+        Tuple containing batch prompt and batch target.
+
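+    Example:
+        A minimal zero-shot sketch (passing ``few_shot_data=None`` skips the few-shot
+        prefix logic, so the tokenizer-related arguments may be None):
+
+        >>> batch = [{"instruction": "What is 2+2?", "input": "", "target": "4"}]
+        >>> prompts, targets = get_batch_prompt(conv_plain, batch, None, None, None, None)
+        >>> targets
+        [['4']]
+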
+ """
+
+ batch_prompt = []
+ batch_target = []
+
+ if isinstance(batch[0], dict):
+ for b in batch:
+ few_shot_prefix = ""
+ if few_shot_data is not None:
+                # For few-shot, we only need the input field; otherwise fall back to the instruction field (as in AGIEval).
+ query_text = b["input"] if b.get("input", "") != "" else b["instruction"]
+
+ if isinstance(b["target"], str):
+ zero_shot_prompt = query_text + b["target"]
+ max_tokens = model_max_length - len(tokenizer([zero_shot_prompt]).input_ids[0])
+ else:
+                    raise ValueError("When using few-shot, the target answer should be a string.")
+
+ few_shot_prefix = get_few_shot_prefix(conv, few_shot_data, tokenizer, language, max_tokens)
+ else:
+ query_text = b["instruction"] + "\n\n" + b["input"] if b.get("input", "") != "" else b["instruction"]
+
+ conv.append_message(conv.roles[0], few_shot_prefix + query_text)
+ conv.append_message(conv.roles[1], None)
+
+ batch_prompt.append(conv.get_prompt())
+
+ target = b["target"]
+ if isinstance(b["target"], str):
+ target = [target]
+
+ batch_target.append(target)
+
+ conv.clear()
+
+ return batch_prompt, batch_target
+
+
+conv_coati = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+ roles=("Human", "Assistant"),
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
+ sep="",
+)
+
+conv_alpaca = Conversation(
+ system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
+ roles=("### Instruction", "### Response"),
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.ALPACA,
+ sep="\n\n",
+)
+
+conv_plain = Conversation(
+ system="",
+ roles=("", ""),
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.PLAIN,
+ sep="",
+)
+
+prompt_templates = {"coati": conv_coati, "alpaca": conv_alpaca, "plain": conv_plain}
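+
+# Example usage (a minimal sketch): build a single-turn Alpaca-style prompt.
+#
+#   conv = prompt_templates["alpaca"].copy()
+#   conv.append_message(conv.roles[0], "Explain what a tokenizer does.")
+#   conv.append_message(conv.roles[1], None)
+#   conv.get_prompt()
+#   # -> "Below is an instruction that describes a task. ...\n\n### Instruction:\nExplain what a tokenizer does.\n\n### Response:"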
diff --git a/applications/ColossalEval/colossal_eval/utils/utilities.py b/applications/ColossalEval/colossal_eval/utils/utilities.py
new file mode 100644
index 0000000000000000000000000000000000000000..4eda079074952da15b9b53068b6b2f43b8a3a8e5
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/utils/utilities.py
@@ -0,0 +1,62 @@
+import io
+import json
+import os
+
+import torch.distributed as dist
+
+
+def is_rank_0() -> bool:
+ return not dist.is_initialized() or dist.get_rank() == 0
+
+
+def _make_w_io_base(f, mode: str):
+ if not isinstance(f, io.IOBase):
+ f_dirname = os.path.dirname(f)
+ if f_dirname != "":
+ os.makedirs(f_dirname, exist_ok=True)
+ f = open(f, mode=mode, encoding="utf-8")
+ return f
+
+
+def _make_r_io_base(f, mode: str):
+ if not isinstance(f, io.IOBase):
+ f = open(f, mode=mode, encoding="utf-8")
+ return f
+
+
+def jdump(obj, f, mode="w", indent=4, default=str):
+ """
+ Dump a str or dictionary to a file in json format.
+
+ Args:
+ obj: An object to be written.
+ f: A string path to the location on disk.
+ mode: Mode for opening the file.
+ indent: Indent for storing json dictionaries.
+ default: A function to handle non-serializable entries; defaults to `str`.
+
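+    Example:
+        A minimal sketch; the path is a placeholder, and missing parent directories
+        are created automatically:
+
+        >>> jdump({"score": 5}, "outputs/example.json")
+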
+ """
+ f = _make_w_io_base(f, mode)
+ if isinstance(obj, (dict, list)):
+ json.dump(obj, f, indent=indent, default=default, ensure_ascii=False)
+ elif isinstance(obj, str):
+ f.write(obj)
+ else:
+ raise ValueError(f"Unexpected type: {type(obj)}")
+ f.close()
+
+
+def jload(f, mode="r"):
+ """Load a .json file into a dictionary."""
+ f = _make_r_io_base(f, mode)
+ jdict = json.load(f)
+ f.close()
+ return jdict
+
+
+def get_json_list(file_path):
+    """Load a JSONL file into a list of objects, skipping empty and literal "null" lines."""
+    with open(file_path, "r", encoding="utf-8") as f:
+        json_list = []
+        for line in f:
+            line = line.strip()
+            if line and line != "null":
+                json_list.append(json.loads(line))
+        return json_list
diff --git a/applications/ColossalEval/configs/gpt_evaluation/config/config_cn.json b/applications/ColossalEval/configs/gpt_evaluation/config/config_cn.json
new file mode 100644
index 0000000000000000000000000000000000000000..d7c8648810084cdbe3aec95b8045a7d14ac5af6e
--- /dev/null
+++ b/applications/ColossalEval/configs/gpt_evaluation/config/config_cn.json
@@ -0,0 +1,44 @@
+{
+ "language": "cn",
+ "category": {
+ "brainstorming": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "creativity",
+ "practicality",
+ "reasonableness"
+ ]
+ },
+ "chat": {
+ "GPT": [
+ "language organization",
+ "naturalness",
+ "engagingness",
+ "fidelity"
+ ]
+ },
+ "generation": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "diversity"
+ ]
+ },
+ "open_qa": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "correctness"
+ ]
+ },
+ "roleplay": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "fidelity",
+ "creativity"
+ ]
+ }
+ }
+}
diff --git a/applications/ColossalEval/configs/gpt_evaluation/config/config_en.json b/applications/ColossalEval/configs/gpt_evaluation/config/config_en.json
new file mode 100644
index 0000000000000000000000000000000000000000..6ebe3996b1cf72fe75f67c570fad2c857e583158
--- /dev/null
+++ b/applications/ColossalEval/configs/gpt_evaluation/config/config_en.json
@@ -0,0 +1,44 @@
+{
+ "language": "en",
+ "category": {
+ "brainstorming": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "creativity",
+ "practicality",
+ "reasonableness"
+ ]
+ },
+ "chat": {
+ "GPT": [
+ "language organization",
+ "naturalness",
+ "engagingness",
+ "fidelity"
+ ]
+ },
+ "generation": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "diversity"
+ ]
+ },
+ "open_qa": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "correctness"
+ ]
+ },
+ "roleplay": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "fidelity",
+ "creativity"
+ ]
+ }
+ }
+}
diff --git a/applications/ColossalEval/configs/gpt_evaluation/data/eval_cn_examples.json b/applications/ColossalEval/configs/gpt_evaluation/data/eval_cn_examples.json
new file mode 100644
index 0000000000000000000000000000000000000000..f869830555b4b59d22195187f53ee2c25e7881c8
--- /dev/null
+++ b/applications/ColossalEval/configs/gpt_evaluation/data/eval_cn_examples.json
@@ -0,0 +1,202 @@
+[
+ {
+ "category": "brainstorming",
+ "instruction": "列举一些可以促进头发生长的食物。",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 1
+ },
+ {
+ "category": "brainstorming",
+ "instruction": "中年夫妻如何提升夫妻感情,请给出三个实用的的方法,并举例说明。",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 2
+ },
+ {
+ "category": "brainstorming",
+ "instruction": "请列举4种日常的环保行为。",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 3
+ },
+ {
+ "category": "brainstorming",
+ "instruction": "请给出5个可以随时随地锻炼身体的小动作。",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 4
+ },
+ {
+ "category": "brainstorming",
+ "instruction": "请问如何制作一份美味的西红柿炒鸡蛋?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 5
+ },
+ {
+ "category": "chat",
+ "instruction": "基于以下角色信息完成一段对话。小张是一名新手爱好者,对养鸡有浓厚的兴趣。老李是一名有丰富经验的养鸡大师。",
+ "input": "小张:您好,老李,我最近开始对养鸡感兴趣了,想请教您一些问题。 老李:你好,小张,我很乐意帮助你。你想问些什么? 小张:我想知道如何确定鸡的品种和性别? 老李:确切的品种可以通过鸡的外貌特征来确定,而性别一般是通过鸡卵的大小和形状来判断。还有什么问题吗? 小张:",
+ "output": "",
+ "target": "",
+ "id": 6
+ },
+ {
+ "category": "chat",
+ "instruction": "基于以下角色信息完成一段对话。李华是一名参加了期末考试的学生,他已经很担心自己的考试成绩。老师Lucy正在帮助他度过这个紧张的时刻。",
+ "input": "李华:Lucy老师,我很担心自己的考试成绩,我不知道我是否能够通过这次考试。 Lucy:放松,李华,你已经做好了充分的准备。相信你自己,你会做得很好的。 李华:我很怕考试时会忘记自己所学的知识。 Lucy:你可以预留一些时间,过一遍自己所学的知识点或笔记,这样你会更有信心和准确地回答考题。 李华:如果我还是失败了,该怎么办? Lucy:",
+ "output": "",
+ "target": "",
+ "id": 7
+ },
+ {
+ "category": "chat",
+ "instruction": "基于以下角色信息完成一段对话。张先生是一名企业家,正在考虑是否开拓海外市场;李女士是一名跨境电商专家,擅长国际商务和电子商务。",
+ "input": "张先生:你好,李女士,我正在考虑将我们的产品销售扩大至海外市场,您有什么建议吗? 李女士:您好,张先生,我们需要考虑到海外市场对于产品的需求是否与国内市场一致,需要进行市场调研和定位。然后再进行各种软性、硬性的创新。 张先生:听起来很专业,您能具体解释一下吗? 李女士:",
+ "output": "",
+ "target": "",
+ "id": 8
+ },
+ {
+ "category": "chat",
+ "instruction": "基于以下角色信息完成一段对话。小明是一名医生。一名病患想要提前停药。小王是病患的儿子,希望父亲能够听取医生的建议。",
+ "input": "小明:你好,小王,我了解你想要让你父亲停药。小王:是的,我父亲已经吃了那么久的药,我担心药物对他的身体会有副作用。小明:",
+ "output": "",
+ "target": "",
+ "id": 9
+ },
+ {
+ "category": "chat",
+ "instruction": "基于以下角色信息完成一段对话。张三是一位语文老师,对学生认真负责;李四是张三的学生,对语文兴趣不是很高。",
+ "input": "张三:同学们,今天要讲的是一篇古文《岳阳楼记》。这篇文章非常精彩,希望同学们能够认真听课,理解其中的含义。 李四:怎么又是古文? 张三:",
+ "output": "",
+ "target": "",
+ "id": 10
+ },
+ {
+ "category": "generation",
+ "instruction": "根据主题写一封邮件。",
+ "input": "主题: \"加入我们,共创未来\"",
+ "output": "",
+ "target": "",
+ "id": 11
+ },
+ {
+ "category": "generation",
+ "instruction": "为公司编写一份职场行为准则,包括明确的行为规范和道德准则。",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 12
+ },
+ {
+ "category": "generation",
+ "instruction": "请撰写一篇文章,介绍如何通过改善生活习惯来预防疾病和延长寿命。",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 13
+ },
+ {
+ "category": "generation",
+ "instruction": "请为一家咖啡店编写一篇简短的广告语,吸引更多的顾客。",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 14
+ },
+ {
+ "category": "generation",
+ "instruction": "根据以下故事提示写一篇故事:",
+ "input": "故事提示:```在一个废弃的古堡中,一个小女孩遇到了一只会说话的黑猫,他们一起揭开了一个古老的谜题。```",
+ "output": "",
+ "target": "",
+ "id": 15
+ },
+ {
+ "category": "open_qa",
+ "instruction": "请介绍一下《红楼梦》这部经典小说的故事情节。",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 16
+ },
+ {
+ "category": "open_qa",
+ "instruction": "解释什么是RNA病毒和DNA病毒。",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 17
+ },
+ {
+ "category": "open_qa",
+ "instruction": "什么是比特币?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 18
+ },
+ {
+ "category": "open_qa",
+ "instruction": "在计算机中,什么是RAM?与ROM有什么区别?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 19
+ },
+ {
+ "category": "open_qa",
+ "instruction": "请简单介绍一下世界上最长的河流途经的国家。",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 20
+ },
+ {
+ "category": "roleplay",
+ "instruction": "我要你把我写的句子翻译成表情符号。我会写句子,你会用表情符号表达它。我只是想让你用表情符号来表达它。除了表情符号,我不希望你回复任何内容。当我需要用中文告诉你一些事情时,我会用 {} 这样的大括号括起来。我的第一句话是“{我的职业是消防员。}”\n",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 21
+ },
+ {
+ "category": "roleplay",
+ "instruction": "我希望你假定自己是雅思写作考官,根据雅思评判标准,按我给你的雅思考题和对应答案给我评分,并且按照雅思写作评分细则给出打分依据。此外,请给我详细的修改意见并写出满分范文。第一个问题是:It is sometimes argued that too many students go to university, while others claim that a university education should be a universal right. Discuss both sides of the argument and give your own opinion.对于这个问题,我的答案是:In some advanced countries, it is not unusual for more than 50% of young adults to attend college or university. Critics, however, claim that many university courses are worthless and young people would be better off gaining skills in the workplace. In this essay, I will examine both sides of this argument and try to reach a conclusion.There are several reasons why young people today believe they have the right to a university education. First, growing prosperity in many parts of the world has increased the number of families with money to invest in their children’s future. At the same time, falling birthrates mean that one- or two-child families have become common, increasing the level of investment in each child. It is hardly surprising, therefore, that young people are willing to let their families support them until the age of 21 or 22. Furthermore, millions of new jobs have been created in knowledge industries, and these jobs are typically open only to university graduates.However, it often appears that graduates end up in occupations unrelated to their university studies. It is not uncommon for an English literature major to end up working in sales, or an engineering graduate to retrain as a teacher, for example. Some critics have suggested that young people are just delaying their entry into the workplace, rather than developing professional skills.请依次给到我以下内容:具体分数及其评分依据、文章修改意见、满分范文。\n",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 22
+ },
+ {
+ "category": "roleplay",
+ "instruction": "我想让你充当 Linux 终端。我将输入命令,您将回复终端应显示的内容。我希望您只在一个唯一的代码块内回复终端输出,而不是其他任何内容。不要写解释。除非我指示您这样做,否则不要键入命令。当我需要用英语告诉你一些事情时,我会把文字放在中括号内[就像这样]。我的第一个命令是 pwd\n",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 23
+ },
+ {
+ "category": "roleplay",
+ "instruction": "我希望你充当宠物行为主义者。我将为您提供一只宠物和它们的主人,您的目标是帮助主人了解为什么他们的宠物表现出某些行为,并提出帮助宠物做出相应调整的策略。您应该利用您的动物心理学知识和行为矫正技术来制定一个有效的计划,双方的主人都可以遵循,以取得积极的成果。我的第一个请求是“我有一只好斗的德国牧羊犬,它需要帮助来控制它的攻击性。”\n",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 24
+ },
+ {
+ "category": "roleplay",
+ "instruction": "我希望你充当正则表达式生成器。您的角色是生成匹配文本中特定模式的正则表达式。您应该以一种可以轻松复制并粘贴到支持正则表达式的文本编辑器或编程语言中的格式提供正则表达式。不要写正则表达式如何工作的解释或例子;只需提供正则表达式本身。我的第一个提示是生成一个匹配电子邮件地址的正则表达式。\n",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 25
+ }
+]
diff --git a/applications/ColossalEval/configs/gpt_evaluation/data/eval_en_examples.json b/applications/ColossalEval/configs/gpt_evaluation/data/eval_en_examples.json
new file mode 100644
index 0000000000000000000000000000000000000000..27b8af8bc4c6b77628ef72ba57f0869c6abf26a5
--- /dev/null
+++ b/applications/ColossalEval/configs/gpt_evaluation/data/eval_en_examples.json
@@ -0,0 +1,202 @@
+[
+ {
+ "category": "brainstorming",
+ "instruction": "Which are some popular fiction books that I should read?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 1
+ },
+ {
+ "category": "brainstorming",
+ "instruction": "How do I properly store fruits and vegetables to keep them fresh for longer?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 2
+ },
+ {
+ "category": "brainstorming",
+ "instruction": "How do you properly chop an onion without crying?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 3
+ },
+ {
+ "category": "brainstorming",
+ "instruction": "How to make an international transfer? Please provide 3 techniques.",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 4
+ },
+ {
+ "category": "brainstorming",
+ "instruction": "Name five leadership qualities that you consider most important.",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 5
+ },
+ {
+ "category": "chat",
+ "instruction": "Complete a dialogue based on the following character information. Alex: A novice writer who is struggling to find inspiration and develop his writing skills. Emma: A successful author with many published works, providing guidance and advice to Alex.",
+ "input": "Alex: Hi Emma, I have been writing for a while now but can't seem to make any progress. Can you give me any advice? Emma: Hi Alex, sure. What kind of writing are you doing? Alex: I'm trying to write a novel, but I just can't seem to find any inspiration. Emma: ",
+ "output": "",
+ "target": "",
+ "id": 6
+ },
+ {
+ "category": "chat",
+ "instruction": "Complete a dialogue based on the following character information. John: An experienced software engineer with a passion for coding. Karen: A recent college graduate who is interested in learning more about software development.",
+ "input": "Karen: Hi John, I noticed that you have a lot of experience in the software industry. Can you tell me what you think is the most important skill for a software engineer? John: ",
+ "output": "",
+ "target": "",
+ "id": 7
+ },
+ {
+ "category": "chat",
+ "instruction": "Complete a dialogue based on the following character information. Sarah is a new employee who is nervous about her first presentation; Tom is her boss who has given her coaching and preparation materials.",
+ "input": "Sarah: Tom, I'm feeling really nervous about my presentation tomorrow. Tom: I know how you feel, Sarah. However, I believe in you and your abilities. Just stick to the preparation materials that I have given you, and you'll do great. Sarah: Thank you, Tom. What if I forget something important during the presentation? Tom: ",
+ "output": "",
+ "target": "",
+ "id": 8
+ },
+ {
+ "category": "chat",
+ "instruction": "Complete a dialogue based on the following character information. Sarah: a young artist who is full of creative ideas and always eager to try new things. Jack: a seasoned artist who has achieved great success in the art world and is more traditional in his approach to art.",
+ "input": "Sarah: Hi Jack, I'm really excited to meet you. I'm a big fan of your work. Jack: Hi Sarah, nice to meet you too. So, what kind of art do you do? Sarah: I am passionate about abstract art, especially combining different materials and colors. I think it can really give people a new perspective on things. Jack: That's interesting, but I am more focused on realistic paintings. I believe the most important thing is to master the basic skills first. Sarah: ",
+ "output": "",
+ "target": "",
+ "id": 9
+ },
+ {
+ "category": "chat",
+ "instruction": "Complete a conversation based on the following persona information. Sarah is a college student who is interested in joining a volunteer organization. John is the leader of the volunteer organization and is eager to welcome new members.",
+ "input": "Sarah: Hi, I'm Sarah, and I'm interested in joining your volunteer organization. John: Hi Sarah, welcome! We're always looking for new members who are passionate about volunteering. What areas would you like to focus on? Sarah: I'm interested in community outreach and working with children. John: ",
+ "output": "",
+ "target": "",
+ "id": 10
+ },
+ {
+ "category": "generation",
+ "instruction": "Write an email based on the subject:",
+ "input": "Subject: \"Invitation to an Exclusive Webinar\"",
+ "output": "",
+ "target": "",
+ "id": 11
+ },
+ {
+ "category": "generation",
+ "instruction": "Write a set of guidelines for first-time pet owners on how to properly care for a new puppy.",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 12
+ },
+ {
+ "category": "generation",
+ "instruction": "Can you help me write a persuasive speech on why we should recycle more and take better care of the environment?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 13
+ },
+ {
+ "category": "generation",
+ "instruction": "Write a pitch for a brand-new mobile app that helps people organize their daily tasks more efficiently.",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 14
+ },
+ {
+ "category": "generation",
+ "instruction": "Write a social media post promoting an upcoming concert featuring a local band.",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 15
+ },
+ {
+ "category": "open_qa",
+ "instruction": "Describe the significance of the Renaissance period in European history.",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 16
+ },
+ {
+ "category": "open_qa",
+ "instruction": "What is the term for the surgical removal of the appendix?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 17
+ },
+ {
+ "category": "open_qa",
+ "instruction": "Explain the process of osmosis in biological systems.",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 18
+ },
+ {
+ "category": "open_qa",
+ "instruction": "Who were the members of the Beatles band?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 19
+ },
+ {
+ "category": "open_qa",
+ "instruction": "Who painted the The Scream?",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 20
+ },
+ {
+ "category": "roleplay",
+ "instruction": "I want you to act as a linux terminal. I will type commands and you will reply with what the terminal should show. I want you to only reply with the terminal output inside one unique code block, and nothing else. do not write explanations. do not type commands unless I instruct you to do so. when i need to tell you something in english, i will do so by putting text inside curly brackets {like this}. my first command is pwd",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 21
+ },
+ {
+ "category": "roleplay",
+ "instruction": "I want you to act as a travel guide. I will write you my location and you will suggest a place to visit near my location. In some cases, I will also give you the type of places I will visit. You will also suggest me places of similar type that are close to my first location. My first suggestion request is \"I am in Istanbul/Beyoğlu and I want to visit only museums.\"",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 22
+ },
+ {
+ "category": "roleplay",
+ "instruction": "I want you to act as an advertiser. You will create a campaign to promote a product or service of your choice. You will choose a target audience, develop key messages and slogans, select the media channels for promotion, and decide on any additional activities needed to reach your goals. My first suggestion request is \"I need help creating an advertising campaign for a new type of energy drink targeting young adults aged 18-30.\"",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 23
+ },
+ {
+ "category": "roleplay",
+ "instruction": "I want you to act as a storyteller. You will come up with entertaining stories that are engaging, imaginative and captivating for the audience. It can be fairy tales, educational stories or any other type of stories which has the potential to capture people's attention and imagination. Depending on the target audience, you may choose specific themes or topics for your storytelling session e.g., if it’s children then you can talk about animals; If it’s adults then history-based tales might engage them better etc. My first request is \"I need an interesting story on perseverance.\"",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 24
+ },
+ {
+ "category": "roleplay",
+ "instruction": "I want you to act as a rapper. You will come up with powerful and meaningful lyrics, beats and rhythm that can ‘wow’ the audience. Your lyrics should have an intriguing meaning and message which people can relate too. When it comes to choosing your beat, make sure it is catchy yet relevant to your words, so that when combined they make an explosion of sound everytime! My first request is \"I need a rap song about finding strength within yourself.\"",
+ "input": "",
+ "output": "",
+ "target": "",
+ "id": 25
+ }
+]
diff --git a/applications/ColossalEval/configs/gpt_evaluation/prompt/battle_prompt/battle_prompt_cn.json b/applications/ColossalEval/configs/gpt_evaluation/prompt/battle_prompt/battle_prompt_cn.json
new file mode 100644
index 0000000000000000000000000000000000000000..ca66afd7e4644a2854640305feb3dabadc4818b1
--- /dev/null
+++ b/applications/ColossalEval/configs/gpt_evaluation/prompt/battle_prompt/battle_prompt_cn.json
@@ -0,0 +1,6 @@
+{
+ "id": 1,
+ "system_prompt": "你是一个检查回答质量的好助手。",
+ "prompt_template": "[问题]\n{question}\n\n[1号AI助手的答案]\n{answer_1}\n\n[1号AI助手答案终止]\n\n[2号AI助手的答案]\n{answer_2}\n\n[2号AI助手答案终止]\n\n[要求]\n{prompt}\n\n",
+ "prompt": "我们需要你评价这两个AI助手回答的性能。\n请对他们的回答的有用性、相关性、准确性、详细程度进行评分。每个AI助手都会得到一个1到10分的总分,分数越高表示整体表现越好。\n请首先输出一行,该行只包含两个数值,分别表示1号和2号AI助手的分数。这两个分数之间要有一个空格。在随后的一行中,请对你的评价作出全面的解释,避免任何潜在的偏见,并确保AI助手回答的顺序不会影响您的判断。"
+}
diff --git a/applications/ColossalEval/configs/gpt_evaluation/prompt/battle_prompt/battle_prompt_en.json b/applications/ColossalEval/configs/gpt_evaluation/prompt/battle_prompt/battle_prompt_en.json
new file mode 100644
index 0000000000000000000000000000000000000000..2b35d1958ac5e7ef3d00d3a0fbcfdb88e8f06b11
--- /dev/null
+++ b/applications/ColossalEval/configs/gpt_evaluation/prompt/battle_prompt/battle_prompt_en.json
@@ -0,0 +1,6 @@
+{
+ "id": 1,
+ "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer. You will be given two different answers to the same question",
+ "prompt_template": "[Question]\n{question}\n\n[The Start of AI Assistant 1's Answer]\n{answer_1}\n\n[The End of AI Assistant 1's Answer]\n\n[The Start of AI Assistant 2's Answer]\n{answer_2}\n\n[The End of AI Assistant 2's Answer]\n\n[Requirements]\n{prompt}\n\n",
+ "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space. In the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."
+}
diff --git a/applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_cn.json b/applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_cn.json
new file mode 100644
index 0000000000000000000000000000000000000000..70f6c3ebc31632bb8d2a4ddab1965d693053a60a
--- /dev/null
+++ b/applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_cn.json
@@ -0,0 +1,102 @@
+{
+ "brainstorming": {
+ "id": 1,
+ "category": "brainstorming",
+ "metrics": {
+ "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。",
+ "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。",
+ "creativity": "创意性(1-5):某些头脑风暴问题可能需要答案具有创意,提出新的思路。",
+ "practicality": "实用性(1-5):某些头脑风暴问题可能需要答案提出实用的建议或解决方法。",
+ "reasonableness": "合理性(1-5):答案应该符合常识、生活实际等等。"
+ },
+ "CoT": {
+ "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:",
+ "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:",
+ "creativity": "1. 仔细阅读所提供的头脑风暴问题,确保你理解问题的要点和背景。\n2. 根据你的知识和经验,判断所提供的答案是否可行。如果答案不可行,则创意性评分可能会受到影响。\n3. 考虑答案中是否包含新颖的想法或独特的思路。答案可能与已知的解决方案有所重叠,但仍然可以被认为是有创意的,只要它提供了新的角度或方法来解决问题。\n4. 根据答案的创意性,给出一个1到5的评分。如果答案缺乏创意,则应给出一个较低的评分。如果答案具有创意并提供了新的思路,应给出一个较高的评分。\n\n创意性:",
+ "practicality": "1. 仔细阅读所提供的头脑风暴问题,确保你理解问题的要点和背景。\n2. 根据你的知识和经验,判断所提供的答案是否可行。如果答案不可行,则实用性评分可能会受到影响。\n3. 考虑答案中提出的建议或解决方法是否实用并可行。答案可能看起来很好,但如果无法实现或应用,则实用性评分可能会受到影响。\n4. 根据答案的实用性,给出一个1到5的评分。如果答案缺乏实用性,则应给出一个较低的评分。如果答案提出了实用的建议或解决方法,并且可以很好地解决问题,则应给出一个较高的评分。\n\n实用性:",
+ "reasonableness": "1. 仔细阅读所提供的头脑风暴问题,确保你理解问题的要点和背景。\n2. 根据你的知识和经验,判断所提供的答案是否可行。如果答案不可行,则合理性评分可能会受到影响。\n3. 考虑答案中所提供的信息是否合理、符合常识、生活实际等等。如果答案中存在明显的不合理之处,则合理性评分可能会受到影响。\n4. 根据答案的合理性,给出一个1到5的评分。如果答案存在明显的不合理之处,则应给出一个较低的评分。如果答案合理、符合常识、生活实际等等,则应给出一个较高的评分。\n\n合理性:"
+ },
+ "prompt": "你是一个好助手。请你为下面“头脑风暴”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
+ },
+ "chat": {
+ "id": 2,
+ "category": "chat",
+ "metrics": {
+ "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。",
+ "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。",
+ "naturalness": "自然(1-5):答案是否自然,并且符合问题给定的身份。",
+ "engagingness": "参与感(1-5):答案是否对前面的对话内容做出了恰当的反应,是否理解对话的语境和背景。",
+ "reasonableness": "合理性(1-5):答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。",
+ "fidelity": "保真度(1-5):答案是否能够严格遵守角色的设定回答给定的请求。"
+ },
+ "CoT": {
+ "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:",
+ "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:",
+ "naturalness": "1. 阅读题目,确定题目提供的身份信息。\n2. 检查答案内容是否符合题目给定的身份。\n3. 根据以上因素,对该回答的自然性进行打分,分数从1到5,其中1表示不自然,5表示非常自然,并符合问题给定的身份。\n\n自然:",
+ "engagingness": "1. 阅读题目,确定对话的语境和背景。\n2. 检查答案是否充分理解对话的语境和背景,能否自然地融入到对话中而不显得突兀。\n3. 根据以上因素,对该回答的参与感进行打分,分数从1到5,其中1表示没有参与感,5表示非常有参与感,并且恰当地理解了对话的语境和背景。\n\n参与感:",
+ "reasonableness": "1. 阅读题目,确定对话的主题以及问题期望的回答方向。\n2. 判断答案是否能够与前面的对话内容形成逻辑上的衔接,是否符合常理,能否在这个上下文中合理存在。\n3. 根据以上因素,对该回答的合理性进行打分,分数从1到5,其中1表示不合理,5表示非常合理,并且能够与前面的对话内容形成逻辑上的衔接,并符合常理。\n\n合理性:",
+ "fidelity": "1. 仔细阅读问题,了解角色在问题中的设定和表现,包括职业、背景、观点、性格等方面。\n阅读题目的请求,确认回答请求时需要注意的细节。\n3. 对比提供的回答与该角色的设定,评估回答是否能够严格遵守角色的设定。\n4. 结合以上评估结果给出保真度的评分,范围从1到5分,其中1分表示回答与角色设定完全不符,5分表示回答完全符合角色设定且满足给定请求。\n\n保真度:"
+ },
+ "prompt": "你是一个好助手。请你为下面的“补全对话”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
+ },
+ "generation": {
+ "id": 3,
+ "category": "generation",
+ "metrics": {
+ "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。",
+ "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。",
+ "diversity": "多样性(1-5):答案使用语言是否优美,具有有一定的创造性和想象力。然而,回答也应该保持合理和适度,不要过于夸张或离题。"
+ },
+ "CoT": {
+ "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:",
+ "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:",
+ "diversity": "1. 仔细阅读整个回答,确保完全理解回答所表达的内容和主题。\n2. 在阅读回答的同时,注意语言的质量,例如措辞是否正确,语言是否生动等。\n3. 检查回答的创造性和想象力,看看回答是否能够吸引人阅读下去。\n4. 检查回答的合理性和适度,看看回答是否夸张或离题。\n5. 将多样性的评分打分在1到5之间,5分表示回答的质量很好,能够吸引人阅读,1分表示回答的内容生硬或者有离题的问题。\n\n多样性:"
+ },
+ "prompt": "你是一个好助手。请你为下面的“生成”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
+ },
+ "open_qa": {
+ "id": 4,
+ "category": "open_qa",
+ "metrics": {
+ "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。",
+ "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。",
+ "correctness": "正确性(1-5):答案是否正确。"
+ },
+ "CoT": {
+ "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:",
+ "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:",
+ "correctness": "1. 仔细阅读题目,尝试自己回答该问题。\n2. 检查答案的准确性。您可以使用已知的事实或研究来验证答案是否正确。如果答案是正确的,则可以将正确性得分为5分。如果答案是部分正确的,则可以给予适当的得分,例如2分、3分或4分。如果答案完全不正确,则只得1分。\n\n正确性:"
+ },
+ "prompt": "你是一个好助手。请你为下面的问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
+ },
+ "roleplay": {
+ "id": 5,
+ "category": "roleplay",
+ "metrics": {
+ "language organization": "语言组织(1-5):答案语言是否流畅、连贯,使用正确的语法,具有一定逻辑性,使用恰当的连接词、过渡词等等。",
+ "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。",
+ "fidelity": "保真度(1-5):答案是否能够严格遵守角色的设定回答给定的请求。",
+ "creativity": "创意性(1-5):角色扮演问题的回答需要具有一定创意,但同时需要遵守角色的设定。"
+ },
+ "CoT": {
+ "language organization": "1. 阅读答案,并检查是否有语法错误、用词不当或其他显著的错误。\n2. 检查答案是否具有逻辑性,能够按照合理的顺序传达信息并且能够自圆其说。\n3. 确定答案是否与问题或主题相关,并且能够传达清晰的信息。\n4. 检查答案是否连贯,是否使用适当的转换和过渡来保持句子和段落之间的连贯性。\n5. 检查答案是否具有明确的结构和组织方式,使得读者可以轻松理解信息的层次和结构。\n6. 根据以上因素综合评估答案的语言组织,并给出一个1到5的分数,其中5表示语言组织非常好,而1表示语言组织非常差。\n\n语言组织:",
+ "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:",
+ "fidelity": "1. 仔细阅读问题,了解角色在问题中的设定和表现,包括职业、背景、观点、性格等方面。\n2. 阅读题目的请求,确认回答请求时需要注意的细节。\n3. 对比提供的回答与该角色的设定,评估回答是否能够严格遵守角色的设定。\n4. 结合以上评估结果给出保真度的评分,范围从1到5分,其中1分表示回答与角色设定完全不符,5分表示回答完全符合角色设定且满足给定请求。\n\n保真度:",
+ "creativity": "1. 仔细阅读问题,了解角色在问题中的设定和表现,包括职业、背景、观点、性格等方面。\n2. 评估回答是否具有独特的思路和建议,是否能够给提问者带来新的想法和启示。\n3. 对比回答中的创意和该角色的设定,评估回答是否遵守了该角色的设定和基本特征。\n4. 对回答的质量进行总体评估,并结合以上评估结果给出创意性的评分,范围从1到5分,其中1分表示回答缺乏创意,5分表示回答具有独特的思路和建议,并且能够遵守该角色的设定。\n\n创意性:"
+ },
+ "prompt": "你是一个好助手。请你为下面的“角色扮演”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
+ },
+ "Other": {
+ "id": 6,
+ "category": "Other",
+ "metrics": {
+ "relevance": "切题(1-5):答案内容是否切题,不答非所问,并且严格遵照题目要求。",
+ "correctness": "正确性(1-5):答案是否正确。"
+ },
+ "CoT": {
+ "relevance": "1. 阅读题目,确定题目所问的问题是什么,以及需要回答哪些方面的问题。\n2. 阅读答案,确认答案是否直接回答了题目所问的问题。\n3. 检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。\n4. 根据以上因素综合评估答案的切题程度,并给出一个1到5的分数,其中5表示答案非常切题,而1表示答案完全没有切题。\n\n切题:",
+ "correctness": "1. 仔细阅读题目,尝试自己回答该问题。\n2. 检查答案的准确性。您可以使用已知的事实或研究来验证答案是否正确。如果答案是正确的,则可以将正确性得分为5分。如果答案是部分正确的,则可以给予适当的得分,例如2分、3分或4分。如果答案完全不正确,则只得1分。\n\n正确性:"
+ },
+ "prompt": "你是一个好助手。请你为下面问题的答案打分。\n\n问题如下:\n\n{question}\n\n需要你评分的答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
+ }
+}
diff --git a/applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_en.json b/applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_en.json
new file mode 100644
index 0000000000000000000000000000000000000000..3d04387d98c5d158abb2abbdb2baaad9e59ca89a
--- /dev/null
+++ b/applications/ColossalEval/configs/gpt_evaluation/prompt/evaluation_prompt/evaluation_prompt_en.json
@@ -0,0 +1,103 @@
+{
+ "brainstorming": {
+ "id": 1,
+ "category": "brainstorming",
+ "metrics": {
+ "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.",
+ "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.",
+ "creativity": "Creativity (1-5): Some brainstorming questions may require answers that are creative and suggest new ideas.",
+ "practicality": "Practicality (1-5): Some brainstorming questions may require answers to suggest practical suggestions or solutions.",
+ "reasonableness": "Reasonableness (1-5): The answer should be in line with common sense, life experience, etc."
+ },
+ "CoT": {
+ "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:",
+ "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:",
+ "creativity": "1. Read the provided brainstorming questions carefully to make sure you understand the gist and context of the questions.\n2. Based on your knowledge and experience, determine if the answers provided are feasible. If the answer is not feasible, the creativity score may be affected.\n3. Consider whether the answer contains novel ideas or unique thoughts. An answer may overlap with a known solution and still be considered creative, as long as it offers a new perspective or approach to the problem.\n4. Give a score of 1 to 5 depending on the creativity of the answer. If the answer lacks creativity, a lower score should be given. If the answer is creative and provides a new idea, a higher score should be given.\n\nCreativity:",
+ "practicality": "1. Read the provided brainstorming questions carefully to make sure you understand the gist and context of the questions.\n2. Based on your knowledge and experience, determine if the answers provided are feasible. If the answer is not feasible, the practicality score may be affected.\n3. Consider whether the suggestions or solutions presented in the answer are practical and workable. The answer may look good, but if it cannot be implemented or applied, the practicality score may be affected.\n4. Give a score of 1 to 5 depending on the practicality of the answer. If the answer lacks practicality, a lower score should be given. If the answer makes a practical suggestion or solution and solves the problem well, a higher score should be given.\n\nPracticality:",
+ "reasonableness": "1. Read the provided brainstorming questions carefully to make sure you understand the gist and context of the questions.\n2. Based on your knowledge and experience, determine if the answers provided are feasible. If the answer is not feasible, the reasonableness score may be affected.\n3. Consider whether the information provided in the answer is reasonable, consistent with common sense, real life, etc. If there are obvious errors or implausibilities in the answer, the reasonableness score may be affected.\n4. Give a score of 1 to 5 depending on the reasonableness of the answer. If the answer contains obvious errors or unreasonable points, a lower score should be given. A higher score should be given if the answer is reasonable, consistent with common sense, real life, etc.\n\nReasonableness:"
+ },
+ "prompt": "You are a good assistant. Please rate the given answer to the \"brainstorming\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
+ },
+ "chat": {
+ "id": 2,
+ "category": "chat",
+ "metrics": {
+ "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.",
+ "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.",
+ "naturalness": "Naturalness (1-5): whether the answer is natural and fits the identity given by the question.",
+ "engagingness": "Engagingness (1-5): whether the answer responds appropriately to the content of the preceding conversation and whether it understands the context and background of the conversation.",
+ "reasonableness": "Reasonableness (1-5): Whether the answer can form a logical connection with the content of the previous dialogue, whether it is consistent with common sense, and whether it can reasonably exist in this context.",
+ "fidelity": "Fidelity (1-5): whether the answer is able to answer the given request in strict compliance with the role setting."
+ },
+ "CoT": {
+ "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:",
+ "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:",
+ "naturalness": "1. Read the question and determine the identity information provided in the question.\n2. Check whether the content of the answer matches the identity given in the question.\n3. Based on the above factors, score the naturalness of the response on a scale from 1 to 5, where 1 means unnatural and 5 means very natural and in accordance with the identity given in the question.\n\nNaturalness:",
+ "engagingness": "1. Read the questions to determine the context and background of the dialogue.\n2. Check that the answer fully understands the context and background of the conversation and that it fits naturally into the conversation without seeming abrupt.\n3. Based on the above factors, rate the response's engagement on a scale from 1 to 5, where 1 means not engaged and 5 means very engaged and appropriately understands the context and background of the conversation.\n\nEngagingness:",
+ "reasonableness": "1. Read the question and determine the topic of the conversation and the direction the question expects the answer to go.\n2. Determine whether the answer can be logically connected to the preceding conversation, whether it makes common sense, and whether it can reasonably exist in this context.\n3. Based on the above factors, rate the reasonableness of the answer on a scale from 1 to 5, where 1 means unreasonable and 5 means very reasonable and able to form a logical connection with the preceding dialogue content and consistent with common sense.\n\nReasonableness:",
+ "fidelity": "1. Read the question carefully to understand how the character is set up and represented in the question, including aspects such as occupation, background, point of view, and personality.\n2. Read the question's request and confirm the details that need to be taken into account when answering the request.\n3. Compare the provided answer with the setting of the role and assess whether the answer can strictly adhere to the setting of the role.\n4. Combine the results of the above assessment to give a fidelity score ranging from 1 to 5, where a score of 1 means that the response does not match the persona at all, and a score of 5 means that the response fully complies with the persona and satisfies the given request.\n\nFidelity:"
+ },
+ "prompt": "You are a good assistant. Please rate the given answer to the \"chat\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
+ },
+ "generation": {
+ "id": 3,
+ "category": "generation",
+ "metrics": {
+ "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.",
+ "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.",
+ "diversity": "Diversity (1-5): Whether the answers use beautiful language and have some creativity and imagination. However, answers should also be kept reasonable and moderate, not overly exaggerated or off-topic."
+ },
+ "CoT": {
+ "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:",
+ "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:",
+ "diversity": "1. Read the entire response carefully to ensure that you fully understand the content and theme expressed in the response.\n2. While reading the response, pay attention to the quality of the language, such as whether the wording is correct and the language is vivid.\n3. Check the creativity and imagination of the response to see if the response is engaging to read on.\n4. Check the reasonableness and appropriateness of the responses to see if the responses are exaggerated or off-topic.\n5. Rate the diversity on a scale of 1 to 5, with a 5 indicating a good quality response that is engaging to read and a 1 indicating a raw response or a question that is off-topic.\n\nDiversity:"
+ },
+ "prompt": "You are a good assistant. Please rate the given answer to the \"generation\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
+ },
+ "open_qa": {
+ "id": 4,
+ "category": "open_qa",
+ "metrics": {
+ "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.",
+ "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.",
+ "correctness": "Correctness (1-5): whether the answer is correct or not."
+ },
+ "CoT": {
+ "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:",
+ "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:",
+ "correctness": "1. Read the question carefully and try to answer the question yourself.\n2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be given. If the answer is completely incorrect, only 1 point is awarded.\n\nCorrectness:"
+ },
+    "prompt": "You are a good assistant. Please rate the given answer to the \"open qa\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
+ },
+ "roleplay": {
+ "id": 5,
+ "category": "roleplay",
+ "metrics": {
+ "language organization": "Language organization (1-5): whether the answer language is fluent and coherent, uses correct grammar, has a certain logic, uses appropriate connecting words, transition words, etc.",
+ "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.",
+ "fidelity": "Fidelity (1-5): whether the answer is able to answer the given request in strict compliance with the role setting.",
+ "creativity": "Creativity (1-5): The answers to the role-play questions need to be somewhat creative, but at the same time they need to adhere to the setting of the role."
+ },
+ "CoT": {
+ "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:",
+ "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:",
+ "fidelity": "1. Read the question carefully to understand how the character is set up and represented in the question, including aspects such as occupation, background, point of view, and personality.\n2. Read the question's request and confirm the details that need to be taken into account when answering the request.\n3. Compare the provided answer with the setting of the role and assess whether the answer can strictly adhere to the setting of the role.\n4. Combine the results of the above assessment to give a fidelity score ranging from 1 to 5, where a score of 1 means that the response does not match the persona at all, and a score of 5 means that the response fully complies with the persona and satisfies the given request.\n\nFidelity:",
+ "creativity": "1. Read the question carefully to understand how the character is set up and represented in the question, including career, background, perspective, and personality.\n2. Evaluate whether the answer has unique ideas and suggestions that bring new ideas and insights to the questioner.\n3. Compare the creativity in the response to the setting of the persona and assess whether the response adheres to the setting and essential characteristics of the persona.\n4. Evaluate the quality of the responses in general and combine the results of the above assessment to give a creativity score ranging from 1 to 5, where a score of 1 indicates that the response lacks creativity and a score of 5 indicates that the response has unique ideas and suggestions and is able to adhere to the set-up of the persona.\n\nCreativity:"
+ },
+ "prompt": "You are a good assistant. Please rate the given answer to the \"role-play\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
+ },
+ "Other": {
+ "id": 6,
+ "category": "Other",
+ "metrics": {
+ "relevance": "Relevance (1-5): whether the content of the answer is relevant to the topic, does not answer the wrong question, and strictly follows the requirements of the topic.",
+ "correctness": "Correctness (1-5): whether the answer is correct or not."
+ },
+ "CoT": {
+ "language organization": "1. Read the answers and check for grammatical errors, poor word choice, or other significant mistakes.\n2. Check that the answer is logical, conveys the information in a logical order, and is self-explanatory.\n3. Determine if the answer is relevant to the question or topic and conveys a clear message.\n4. Check that the answer is coherent and that appropriate transitions and switches are used to maintain coherence between sentences and paragraphs.\n5. Check that the answer is clearly structured and organized in such a way that the reader can easily understand the hierarchy and structure of the information.\n6. Evaluate the language organization of the answer based on a combination of the above factors and give a score of 1 to 5, where 5 indicates very good language organization and 1 indicates very poor language organization.\n\nLanguage organization:",
+ "relevance": "1. Read the question to determine what the question asks and what aspects of the question need to be answered.\n2. Read the answers to make sure that they directly answer the question asked.\n3. Check that the answer follows the requirements of the question, including the way it is answered, the length of the answer, the format of the answer, etc.\n4. Evaluate how relevant the answer is based on the above factors and give a score of 1 to 5, where 5 means the answer is very relevant and 1 means the answer is not relevant at all.\n\nRelevance:",
+ "correctness": "1. Read the question carefully and try to answer the question by yourself.\n2. Check the correctness of the answer. You can use known facts or research to verify that the answer is correct. If the answer is correct, you can give a score of 5 for correctness. If the answer is partially correct, an appropriate score, such as 2, 3, or 4, may be assigned. If the answer is completely incorrect, only 1 point is awarded.\n\nCorrectness:"
+ },
+ "prompt": "You are a good assistant. Please rate the given answer to the question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
+ }
+}
diff --git a/applications/ColossalEval/examples/dataset_evaluation/config/evaluation/config.json b/applications/ColossalEval/examples/dataset_evaluation/config/evaluation/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..adb540f60345bae31c1451c62ac24aca0bbc868b
--- /dev/null
+++ b/applications/ColossalEval/examples/dataset_evaluation/config/evaluation/config.json
@@ -0,0 +1,58 @@
+{
+ "model": [
+ {
+ "name": "model1"
+ },
+ {
+ "name": "model2"
+ }
+ ],
+ "dataset": [
+ {
+ "name": "mmlu",
+ "metrics": [
+ "first_token_accuracy",
+ "single_choice_accuracy",
+ "perplexity",
+ "ppl_score",
+ "ppl_score_over_choices"
+ ]
+ },
+ {
+ "name": "cmmlu",
+ "metrics": [
+ "first_token_accuracy",
+ "single_choice_accuracy",
+ "perplexity",
+ "ppl_score",
+ "ppl_score_over_choices"
+ ]
+ },
+ {
+ "name": "agieval",
+ "metrics": [
+ "first_token_accuracy",
+ "single_choice_accuracy",
+ "multi_choice_accuracy",
+ "math_equivalence",
+ "perplexity",
+ "ppl_score_over_choices",
+ "ppl_score"
+ ]
+ },
+ {
+ "name": "gaokaobench",
+ "metrics": [
+ "first_token_accuracy",
+ "single_choice_accuracy",
+ "multi_choice_accuracy",
+ "math_equivalence",
+ "rouge_score",
+ "rouge_zh_score",
+ "perplexity",
+ "ppl_score_over_choices",
+ "ppl_score"
+ ]
+ }
+ ]
+}
diff --git a/applications/ColossalEval/examples/dataset_evaluation/config/inference/config.json b/applications/ColossalEval/examples/dataset_evaluation/config/inference/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..9672c442e647b9fcc186133906b17c102a5ecd08
--- /dev/null
+++ b/applications/ColossalEval/examples/dataset_evaluation/config/inference/config.json
@@ -0,0 +1,84 @@
+{
+ "model": [
+ {
+ "name": "model name",
+ "model_class": "HuggingFaceCausalLM",
+ "parameters": {
+ "path": "path to model",
+ "model_max_length": 4096,
+ "tokenizer_path": "",
+ "tokenizer_kwargs": {
+ "trust_remote_code": true
+ },
+ "peft_path": null,
+ "model_kwargs": {
+ "torch_dtype": "torch.float32",
+ "trust_remote_code": true
+ },
+ "prompt_template": "plain",
+ "batch_size": 4
+ }
+ },
+ {
+ "name": "model2 name",
+ "model_class": "HuggingFaceCausalLM",
+ "parameters": {
+ "path": "path to model2",
+ "model_max_length": 4096,
+ "tokenizer_path": "",
+ "tokenizer_kwargs": {
+ "trust_remote_code": true
+ },
+ "peft_path": null,
+ "model_kwargs": {
+ "torch_dtype": "torch.float32",
+ "trust_remote_code": true
+ },
+ "prompt_template": "plain",
+ "batch_size": 4
+ }
+ }
+ ],
+ "dataset": [
+ {
+ "name": "agieval",
+ "dataset_class": "AGIEvalDataset",
+ "debug": false,
+ "few_shot": false,
+ "path": "path to original dataset (folder)",
+ "save_path": "path to save converted dataset (e.g. inference_data/agieval.json)"
+ },
+ {
+ "name": "ceval",
+ "dataset_class": "CEvalDataset",
+ "debug": false,
+ "few_shot": true,
+ "path": "path to original dataset (folder)",
+ "save_path": "path to save converted dataset (e.g. inference_data/ceval.json)"
+ },
+ {
+ "name": "cmmlu",
+ "dataset_class": "CMMLUDataset",
+ "debug": false,
+ "few_shot": true,
+ "path": "path to original dataset (folder)",
+ "save_path": "path to save converted dataset (e.g. inference_data/cmmlu.json)"
+ },
+ {
+ "name": "gaokaobench",
+ "dataset_class": "GaoKaoBenchDataset",
+ "debug": false,
+ "few_shot": false,
+ "path": "path to original dataset (folder)",
+ "save_path": "path to save converted dataset (e.g. inference_data/gaokaobench.json)"
+ },
+ {
+ "name": "mmlu",
+ "dataset_class": "MMLUDataset",
+ "debug": false,
+ "few_shot": true,
+ "path": "path to original dataset (folder)",
+ "save_path": "path to save converted dataset (e.g. inference_data/mmlu.json)"
+ }
+ ]
+}
diff --git a/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec81cf0cef71dc9dece6e649b6923f48f3c459d5
--- /dev/null
+++ b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py
@@ -0,0 +1,73 @@
+import argparse
+import os
+
+import tabulate
+from colossal_eval.evaluate.dataset_evaluator import DatasetEvaluator
+from colossal_eval.utils import jdump, jload
+
+
+def main(args):
+ config = jload(args.config)
+
+ evaluation_results = {dataset["name"]: {} for dataset in config["dataset"]}
+ evaluation_results_table = {dataset["name"]: {} for dataset in config["dataset"]}
+ evaluator = DatasetEvaluator()
+
+ for dataset_parameter in config["dataset"]:
+ dataset_name = dataset_parameter["name"]
+ metrics = dataset_parameter["metrics"]
+ results_metric_model = {metric: {model["name"]: None for model in config["model"]} for metric in metrics}
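+        # Illustrative: results_metric_model["first_token_accuracy"]["model1"] is
+        # filled below with the evaluator's aggregated "ALL" score for that metric.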
+ for model in config["model"]:
+ model_name = model["name"]
+
+ data = jload(
+ os.path.join(args.inference_results_path, model_name, f"{dataset_name}_inference_results.json")
+ )
+ results = evaluator.get_evaluation_results(data, dataset_name, model_name, metrics)
+
+ for metric, score in results.items():
+ results_metric_model[metric][model_name] = score["ALL"]
+
+ evaluation_results[dataset_name][model_name] = results
+
+ evaluation_results_table[dataset_name] = results_metric_model
+
+ table = []
+ header = ["dataset", "metric"] + [model["name"] for model in config["model"]]
+ table.append(header)
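+    # Illustrative table layout (scores are made up):
+    #   dataset    metric                  model1    model2
+    #   mmlu       first_token_accuracy    45.00     47.50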
+
+ for dataset_parameter in config["dataset"]:
+ dataset_name = dataset_parameter["name"]
+ metrics = dataset_parameter["metrics"]
+
+        for metric, model_results in evaluation_results_table[dataset_name].items():
+            row = [dataset_name, metric]
+            row.extend("{:.02f}".format(score) for score in model_results.values())
+            table.append(row)
+
+ table = tabulate.tabulate(table, headers="firstrow")
+ print(table)
+
+ os.makedirs(args.evaluation_results_save_path, exist_ok=True)
+
+ with open(os.path.join(args.evaluation_results_save_path, "evaluation_results_table.txt"), "w") as file:
+ file.write(table)
+
+ jdump(evaluation_results, os.path.join(args.evaluation_results_save_path, "evaluation_results.json"))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="ColossalEval evaluation process.")
+ parser.add_argument("--config", type=str, default=None, required=True, help="path to config file")
+ parser.add_argument("--inference_results_path", type=str, default=None, help="path to inference results")
+ parser.add_argument(
+ "--evaluation_results_save_path", type=str, default=None, help="path to save evaluation results"
+ )
+ args = parser.parse_args()
+
+ main(args)
diff --git a/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.sh b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ad0bfc03acbbb0455b11ca62d216fcc9b3f75594
--- /dev/null
+++ b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.sh
@@ -0,0 +1,4 @@
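+# Replace the quoted placeholders with real paths before running.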
+python eval_dataset.py \
+ --config "path to config file" \
+ --inference_results_path "path to inference results" \
+ --evaluation_results_save_path "path to save evaluation results"
diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..657fc33bf1ef44e54a49a43f1159b177549bf0a6
--- /dev/null
+++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py
@@ -0,0 +1,171 @@
+import argparse
+import copy
+import os
+from typing import Dict, List
+
+import torch
+import torch.distributed as dist
+from colossal_eval import dataset, models, utils
+
+import colossalai
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger()
+
+
+def rm_and_merge(world_size: int, save_path: str, model_names: List[str], dataset_names: Dict[str, List]) -> None:
+ """
+    Remove per-rank inference results and merge them into one file.
+
+    Args:
+        world_size: Number of processes used for inference.
+        save_path: The folder for storing inference results.
+        model_names: Names of the models used for inference.
+        dataset_names: Mapping from each dataset name to its list of categories.
+
+ """
+
+ for model_name in model_names:
+ for dataset_name, categories in dataset_names.items():
+ all_answers = {}
+ for category in categories:
+ all_answers[category] = {"data": []}
+ answers = {"data": []}
+
+ for r in range(world_size):
+ directory = os.path.join(
+ save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json"
+ )
+ if not os.path.exists(directory):
+                        raise Exception(
+                            f"Result file {directory} not found. There may have been an error during inference."
+                        )
+ else:
+ rank_answers = utils.jload(directory)
+ answers["data"].extend(rank_answers["data"])
+ answers["inference_kwargs"] = rank_answers["inference_kwargs"]
+
+ for r in range(world_size):
+ try:
+ directory = os.path.join(
+ save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json"
+ )
+ os.remove(directory)
+ except Exception as e:
+ print(e)
+
+ all_answers[category] = answers
+
+ logger.info(f"Save inference results of model {model_name} on dataset {dataset_name}.")
+ utils.jdump(all_answers, os.path.join(save_path, model_name, f"{dataset_name}_inference_results.json"))
+
+ logger.info(f"Save inference results of model {model_name} for all dataset.")
+ logger.info(f"Save inference results of all models for all dataset.")
+
+
+def main(args):
+ colossalai.launch_from_torch(config={}, seed=42)
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+
+ inference_data = {}
+ debug_args = {}
+ few_shot_args = {}
+
+ config = utils.jload(args.config)
+
+ model_parameters = config["model"]
+ dataset_parameters = config["dataset"]
+
+ for dataset_parameter in dataset_parameters:
+ path = dataset_parameter["path"]
+ save_path = dataset_parameter["save_path"]
+ dataset_name = dataset_parameter["name"]
+ debug_args[dataset_name] = dataset_parameter["debug"]
+ few_shot_args[dataset_name] = dataset_parameter["few_shot"]
+
+ if not args.load_dataset:
+ if os.path.exists(save_path):
+ dataset_ = utils.jload(save_path)
+ inference_data[dataset_name] = dataset_["test"]
+ else:
+                raise Exception(
+                    "Can't find the converted dataset. Run with --load_dataset first to convert and save it."
+                )
+
+ continue
+
+        dataset_class = getattr(dataset, dataset_parameter["dataset_class"])
+ if not issubclass(dataset_class, dataset.BaseDataset):
+ raise ValueError(f"Dataset class {dataset_parameter['dataset_class']} is not a subclass of BaseDataset.")
+
+ dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"])
+
+ dataset_.save(save_path)
+ inference_data[dataset_name] = dataset_.dataset["test"]
+
+ for model_parameter in model_parameters:
+ model_name = model_parameter["name"]
+        model_class = getattr(models, model_parameter["model_class"])
+        if not issubclass(model_class, models.BaseModel):
+            raise ValueError(f"Model class {model_parameter['model_class']} is not a subclass of BaseModel.")
+
+        parameters = model_parameter["parameters"]
+        parameters.update({"logger": logger})
+        parameters.update({"prompt_template": utils.prompt_templates[parameters["prompt_template"]]})
+
+        model_ = model_class(**parameters)
+
+ for dataset_name, split_data in inference_data.items():
+ start = 0
+ for category, category_data in split_data.items():
+ if few_shot_args[dataset_name] and category_data["inference_kwargs"].get("few_shot_data", None) is None:
+ raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!")
+
+ answers_to_dump = copy.deepcopy(category_data)
+ partition_size = len(category_data["data"]) // world_size
+ redundant = len(category_data["data"]) % world_size
+
+ # Ensure that the amount of data for inference is as consistent as possible across different processes.
+ lengths = [partition_size for _ in range(world_size)]
+ for j in range(redundant):
+ lengths[(j + start) % world_size] += 1
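+            # Illustrative: 10 samples on world_size=4 with start=0 give lengths
+            # [3, 3, 2, 2]; rotating `start` spreads the remainder over different
+            # ranks across categories instead of always overloading rank 0.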
+
+ start = (start + redundant) % world_size
+
+ questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]]
+
+ answers_per_rank = model_.inference(
+ questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
+ )
+
+ answers_to_dump["data"] = answers_per_rank
+
+ utils.jdump(
+ answers_to_dump,
+ os.path.join(
+ args.inference_save_path,
+ model_name,
+ f"{dataset_name}_{category}_inference_results_rank{rank}.json",
+ ),
+ )
+
+ logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")
+
+ del model_
+ torch.cuda.empty_cache()
+
+ dist.barrier()
+ if rank == 0:
+ model_names = [model_parameter["name"] for model_parameter in model_parameters]
+ dataset_names = {key: list(inference_data[key].keys()) for key in inference_data}
+ rm_and_merge(world_size, args.inference_save_path, model_names, dataset_names)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="ColossalEval inference process.")
+ parser.add_argument("--config", type=str, default=None, required=True, help="path to config file")
+ parser.add_argument("--load_dataset", default=False, action="store_true")
+ parser.add_argument("--inference_save_path", type=str, default=None, help="path to save inference results")
+ args = parser.parse_args()
+
+ main(args)
diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.sh b/applications/ColossalEval/examples/dataset_evaluation/inference.sh
new file mode 100644
index 0000000000000000000000000000000000000000..15f9afd560454a61ae719b1ac5cf7c3882ed3a99
--- /dev/null
+++ b/applications/ColossalEval/examples/dataset_evaluation/inference.sh
@@ -0,0 +1,4 @@
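+# Replace the quoted placeholders with real paths; raise --nproc_per_node to shard inference across more GPUs.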
+torchrun --nproc_per_node=1 inference.py \
+ --config "path to config file" \
+ --load_dataset \
+ --inference_save_path "path to save inference results"
diff --git a/applications/ColossalEval/examples/gpt_evaluation/config/evaluation/config.json b/applications/ColossalEval/examples/gpt_evaluation/config/evaluation/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..6ebe3996b1cf72fe75f67c570fad2c857e583158
--- /dev/null
+++ b/applications/ColossalEval/examples/gpt_evaluation/config/evaluation/config.json
@@ -0,0 +1,44 @@
+{
+ "language": "en",
+ "category": {
+ "brainstorming": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "creativity",
+ "practicality",
+ "reasonableness"
+ ]
+ },
+ "chat": {
+ "GPT": [
+ "language organization",
+ "naturalness",
+ "engagingness",
+ "fidelity"
+ ]
+ },
+ "generation": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "diversity"
+ ]
+ },
+ "open_qa": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "correctness"
+ ]
+ },
+ "roleplay": {
+ "GPT": [
+ "language organization",
+ "relevance",
+ "fidelity",
+ "creativity"
+ ]
+ }
+ }
+}
diff --git a/applications/ColossalEval/examples/gpt_evaluation/config/inference/config.json b/applications/ColossalEval/examples/gpt_evaluation/config/inference/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..7ed7491a87c5d7c619987cb09d92faad7a18074c
--- /dev/null
+++ b/applications/ColossalEval/examples/gpt_evaluation/config/inference/config.json
@@ -0,0 +1,33 @@
+{
+ "model": [
+ {
+ "name": "model name",
+ "model_class": "HuggingFaceCausalLM",
+ "parameters": {
+ "path": "path to model",
+ "model_max_length": 4096,
+ "tokenizer_path": "",
+ "tokenizer_kwargs": {
+ "trust_remote_code": true
+ },
+ "peft_path": null,
+ "model_kwargs": {
+ "torch_dtype": "torch.float32",
+ "trust_remote_code": true
+ },
+ "prompt_template": "plain",
+ "batch_size": 4
+ }
+ }
+ ],
+ "dataset": [
+ {
+ "name": "colossal",
+ "dataset_class": "ColossalDataset",
+ "debug": false,
+ "few_shot": false,
+ "path": "../../configs/gpt_evaluation/data/eval_en_examples.json",
+      "save_path": "path to save converted dataset (e.g. inference_data/colossal.json)"
+ }
+ ]
+}
diff --git a/applications/ColossalEval/examples/gpt_evaluation/eval.py b/applications/ColossalEval/examples/gpt_evaluation/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd521af59823b2564ee04ffc8ab7329f55cb5099
--- /dev/null
+++ b/applications/ColossalEval/examples/gpt_evaluation/eval.py
@@ -0,0 +1,139 @@
+import argparse
+import os
+
+import openai
+from colossal_eval.evaluate.evaluator import Evaluator
+from colossal_eval.utils import jload
+
+
+def main(args):
+ assert len(args.answer_file_list) == len(
+ args.model_name_list
+ ), "The number of answer files and model names should be equal!"
+
+ # load config
+ config = jload(args.config_file)
+
+ if config["language"] in ["cn", "en"]:
+ # get metric settings for all categories
+ metrics_per_category = {}
+ for category in config["category"].keys():
+ metrics_all = {}
+ for metric_type, metrics in config["category"][category].items():
+ metrics_all[metric_type] = metrics
+ metrics_per_category[category] = metrics_all
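+        # e.g. with the example evaluation config, metrics_per_category["brainstorming"]
+        # == {"GPT": ["language organization", "relevance", "creativity", "practicality", "reasonableness"]}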
+
+ battle_prompt = None
+ if args.battle_prompt_file:
+ battle_prompt = jload(args.battle_prompt_file)
+
+ gpt_evaluation_prompt = None
+ if args.gpt_evaluation_prompt_file:
+ gpt_evaluation_prompt = jload(args.gpt_evaluation_prompt_file)
+
+ if len(args.model_name_list) == 2 and not battle_prompt:
+ raise Exception("No prompt file for battle provided. Please specify the prompt file for battle!")
+
+ if len(args.model_name_list) == 1 and not gpt_evaluation_prompt:
+ raise Exception(
+ "No prompt file for gpt evaluation provided. Please specify the prompt file for gpt evaluation!"
+ )
+
+ if args.gpt_model == "text-davinci-003" and args.gpt_with_reference:
+ raise Exception(
+ "GPT evaluation with reference is not supported for text-davinci-003. You should specify chat models such as gpt-3.5-turbo or gpt-4."
+ )
+
+ # initialize evaluator
+ evaluator = Evaluator(
+ metrics_per_category,
+ battle_prompt,
+ gpt_evaluation_prompt,
+ args.gpt_model,
+ config["language"],
+ args.gpt_with_reference,
+ )
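+        # Two model names trigger a pairwise battle; a single model name triggers
+        # metric-based evaluation against the target (reference) answers.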
+ if len(args.model_name_list) == 2:
+ answers_1 = jload(args.answer_file_list[0])
+ answers_2 = jload(args.answer_file_list[1])
+
+ answers1 = []
+ for category, value in answers_1.items():
+ answers1.extend(value["data"])
+
+ answers2 = []
+ for category, value in answers_2.items():
+ answers2.extend(value["data"])
+
+ assert len(answers1) == len(answers2), "The number of answers for two models should be equal!"
+
+ evaluator.battle(answers1=answers1, answers2=answers2)
+ evaluator.save(args.save_path, args.model_name_list)
+ elif len(args.model_name_list) == 1:
+ targets = jload(args.target_file)
+ answers = jload(args.answer_file_list[0])
+
+ references = []
+ for category, value in targets["test"].items():
+ references.extend(value["data"])
+
+ predictions = []
+ for category, value in answers.items():
+ predictions.extend(value["data"])
+
+ assert len(references) == len(
+ predictions
+ ), "The number of target answers and model answers should be equal!"
+
+ evaluator.evaluate(
+ answers=predictions, targets=references, save_path=args.save_path, model_name=args.model_name_list[0]
+ )
+ evaluator.save(args.save_path, args.model_name_list)
+ else:
+ raise ValueError("Unsupported number of answer files and model names!")
+ else:
+ raise ValueError(f'Unsupported language {config["language"]}!')
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="ColossalAI LLM evaluation pipeline.")
+ parser.add_argument(
+ "--config_file", type=str, default=None, required=True, help="path to the file of target results"
+ )
+ parser.add_argument("--battle_prompt_file", type=str, default=None, help="path to the prompt file for battle")
+ parser.add_argument(
+ "--gpt_evaluation_prompt_file", type=str, default=None, help="path to the prompt file for gpt evaluation"
+ )
+ parser.add_argument("--target_file", type=str, default=None, help="path to the target answer (ground truth) file")
+ parser.add_argument(
+ "--answer_file_list",
+ type=str,
+ nargs="+",
+ default=[],
+ required=True,
+ help="path to the answer files of at most 2 models",
+ )
+ parser.add_argument(
+ "--model_name_list", type=str, nargs="+", default=[], required=True, help="the names of at most 2 models"
+ )
+ parser.add_argument(
+ "--gpt_model",
+ default="gpt-3.5-turbo-16k",
+ choices=["text-davinci-003", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4"],
+ help="which GPT model to use for evaluation",
+ )
+ parser.add_argument(
+ "--gpt_with_reference",
+ default=False,
+ action="store_true",
+ help="whether to include reference answer in gpt evaluation",
+ )
+ parser.add_argument("--save_path", type=str, default="results", help="path to save evaluation results")
+ parser.add_argument("--openai_key", type=str, default=None, required=True, help="Your openai key")
+ args = parser.parse_args()
+
+ if args.openai_key is not None:
+ os.environ["OPENAI_API_KEY"] = args.openai_key
+ openai.api_key = os.getenv("OPENAI_API_KEY")
+
+ main(args)
diff --git a/applications/ColossalEval/examples/gpt_evaluation/eval.sh b/applications/ColossalEval/examples/gpt_evaluation/eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f5729e6ee5c7249aa9af842c171008ee1e35ace2
--- /dev/null
+++ b/applications/ColossalEval/examples/gpt_evaluation/eval.sh
@@ -0,0 +1,9 @@
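+# Provide --battle_prompt_file when comparing two models, or --gpt_evaluation_prompt_file when scoring one model.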
+python eval.py \
+ --config_file "path to the config file" \
+ --battle_prompt_file "path to the prompt file for battle" \
+ --gpt_evaluation_prompt_file "path to the prompt file for gpt evaluation" \
+ --target_file "path to the target answer file" \
+ --answer_file_list "path to the answer files of at most 2 models" \
+ --model_name_list "the names of at most 2 models" \
+ --save_path "path to save results" \
+    --openai_key "your openai key"
diff --git a/applications/ColossalEval/examples/gpt_evaluation/inference.py b/applications/ColossalEval/examples/gpt_evaluation/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..657fc33bf1ef44e54a49a43f1159b177549bf0a6
--- /dev/null
+++ b/applications/ColossalEval/examples/gpt_evaluation/inference.py
@@ -0,0 +1,171 @@
+import argparse
+import copy
+import os
+from typing import Dict, List
+
+import torch
+import torch.distributed as dist
+from colossal_eval import dataset, models, utils
+
+import colossalai
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger()
+
+
+def rm_and_merge(world_size: int, save_path: str, model_names: List[str], dataset_names: Dict[str, List]) -> None:
+ """
+    Remove per-rank inference results and merge them into one file.
+
+    Args:
+        world_size: Number of processes used for inference.
+        save_path: The folder for storing inference results.
+        model_names: Names of the models used for inference.
+        dataset_names: Mapping from each dataset name to its list of categories.
+
+ """
+
+ for model_name in model_names:
+ for dataset_name, categories in dataset_names.items():
+ all_answers = {}
+ for category in categories:
+ all_answers[category] = {"data": []}
+ answers = {"data": []}
+
+ for r in range(world_size):
+ directory = os.path.join(
+ save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json"
+ )
+ if not os.path.exists(directory):
+                        raise Exception(
+                            f"Result file {directory} not found. There may have been an error during inference."
+                        )
+ else:
+ rank_answers = utils.jload(directory)
+ answers["data"].extend(rank_answers["data"])
+ answers["inference_kwargs"] = rank_answers["inference_kwargs"]
+
+ for r in range(world_size):
+ try:
+ directory = os.path.join(
+ save_path, model_name, f"{dataset_name}_{category}_inference_results_rank{r}.json"
+ )
+ os.remove(directory)
+ except Exception as e:
+ print(e)
+
+ all_answers[category] = answers
+
+ logger.info(f"Save inference results of model {model_name} on dataset {dataset_name}.")
+ utils.jdump(all_answers, os.path.join(save_path, model_name, f"{dataset_name}_inference_results.json"))
+
+ logger.info(f"Save inference results of model {model_name} for all dataset.")
+ logger.info(f"Save inference results of all models for all dataset.")
+
+
+def main(args):
+ colossalai.launch_from_torch(config={}, seed=42)
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+
+ inference_data = {}
+ debug_args = {}
+ few_shot_args = {}
+
+ config = utils.jload(args.config)
+
+ model_parameters = config["model"]
+ dataset_parameters = config["dataset"]
+
+ for dataset_parameter in dataset_parameters:
+ path = dataset_parameter["path"]
+ save_path = dataset_parameter["save_path"]
+ dataset_name = dataset_parameter["name"]
+ debug_args[dataset_name] = dataset_parameter["debug"]
+ few_shot_args[dataset_name] = dataset_parameter["few_shot"]
+
+ if not args.load_dataset:
+ if os.path.exists(save_path):
+ dataset_ = utils.jload(save_path)
+ inference_data[dataset_name] = dataset_["test"]
+ else:
+                raise Exception(
+                    "Can't find the converted dataset. Run with --load_dataset first to convert and save it."
+                )
+
+ continue
+
+        dataset_class = getattr(dataset, dataset_parameter["dataset_class"])
+ if not issubclass(dataset_class, dataset.BaseDataset):
+ raise ValueError(f"Dataset class {dataset_parameter['dataset_class']} is not a subclass of BaseDataset.")
+
+ dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"])
+
+ dataset_.save(save_path)
+ inference_data[dataset_name] = dataset_.dataset["test"]
+
+ for model_parameter in model_parameters:
+ model_name = model_parameter["name"]
+        model_class = getattr(models, model_parameter["model_class"])
+        if not issubclass(model_class, models.BaseModel):
+            raise ValueError(f"Model class {model_parameter['model_class']} is not a subclass of BaseModel.")
+
+        parameters = model_parameter["parameters"]
+        parameters.update({"logger": logger})
+        parameters.update({"prompt_template": utils.prompt_templates[parameters["prompt_template"]]})
+
+        model_ = model_class(**parameters)
+
+ for dataset_name, split_data in inference_data.items():
+ start = 0
+ for category, category_data in split_data.items():
+ if few_shot_args[dataset_name] and category_data["inference_kwargs"].get("few_shot_data", None) is None:
+ raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!")
+
+ answers_to_dump = copy.deepcopy(category_data)
+ partition_size = len(category_data["data"]) // world_size
+ redundant = len(category_data["data"]) % world_size
+
+ # Ensure that the amount of data for inference is as consistent as possible across different processes.
+ lengths = [partition_size for _ in range(world_size)]
+ for j in range(redundant):
+ lengths[(j + start) % world_size] += 1
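+            # Illustrative: 10 samples on world_size=4 with start=0 give lengths
+            # [3, 3, 2, 2]; rotating `start` spreads the remainder over different
+            # ranks across categories instead of always overloading rank 0.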
+
+ start = (start + redundant) % world_size
+
+ questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]]
+
+ answers_per_rank = model_.inference(
+ questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
+ )
+
+ answers_to_dump["data"] = answers_per_rank
+
+ utils.jdump(
+ answers_to_dump,
+ os.path.join(
+ args.inference_save_path,
+ model_name,
+ f"{dataset_name}_{category}_inference_results_rank{rank}.json",
+ ),
+ )
+
+ logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")
+
+ del model_
+ torch.cuda.empty_cache()
+
+ dist.barrier()
+ if rank == 0:
+ model_names = [model_parameter["name"] for model_parameter in model_parameters]
+ dataset_names = {key: list(inference_data[key].keys()) for key in inference_data}
+ rm_and_merge(world_size, args.inference_save_path, model_names, dataset_names)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="ColossalEval inference process.")
+ parser.add_argument("--config", type=str, default=None, required=True, help="path to config file")
+ parser.add_argument("--load_dataset", default=False, action="store_true")
+ parser.add_argument("--inference_save_path", type=str, default=None, help="path to save inference results")
+ args = parser.parse_args()
+
+ main(args)
diff --git a/applications/ColossalEval/examples/gpt_evaluation/inference.sh b/applications/ColossalEval/examples/gpt_evaluation/inference.sh
new file mode 100644
index 0000000000000000000000000000000000000000..15f9afd560454a61ae719b1ac5cf7c3882ed3a99
--- /dev/null
+++ b/applications/ColossalEval/examples/gpt_evaluation/inference.sh
@@ -0,0 +1,4 @@
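+# Replace the quoted placeholders with real paths; raise --nproc_per_node to shard inference across more GPUs.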
+torchrun --nproc_per_node=1 inference.py \
+ --config "path to config file" \
+ --load_dataset \
+ --inference_save_path "path to save inference results"
diff --git a/applications/ColossalEval/requirements.txt b/applications/ColossalEval/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c110606e0303bd7187d2e5a9fea90e4386a9bc2e
--- /dev/null
+++ b/applications/ColossalEval/requirements.txt
@@ -0,0 +1,12 @@
+transformers>=4.32.0
+colossalai>=0.3.1
+peft
+tabulate
+jieba
+fuzzywuzzy
+rouge
+openai
+matplotlib
+pandas
+seaborn
+scikit-learn
diff --git a/applications/ColossalEval/setup.py b/applications/ColossalEval/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f7b1bb5c42e75e2e8bdc736ee0c2da69a372f80
--- /dev/null
+++ b/applications/ColossalEval/setup.py
@@ -0,0 +1,31 @@
+from setuptools import find_packages, setup
+
+
+def fetch_requirements(path):
+ with open(path, "r") as fd:
+ return [r.strip() for r in fd.readlines()]
+
+
+def fetch_readme():
+ with open("README.md", encoding="utf-8") as f:
+ return f.read()
+
+
+setup(
+ name="colossal_eval",
+ version="0.0.1",
+ packages=find_packages(exclude=["examples", "*.egg-info"]),
+ description="Colossal-AI LLM-Evaluation Framework",
+ long_description=fetch_readme(),
+ long_description_content_type="text/markdown",
+ license="Apache Software License 2.0",
+ url="https://github.com/hpcaitech/LLM-Evaluation",
+ install_requires=fetch_requirements("requirements.txt"),
+ python_requires=">=3.6",
+ classifiers=[
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache Software License",
+ "Environment :: GPU :: NVIDIA CUDA",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ ],
+)
diff --git a/applications/README.md b/applications/README.md
index cd0435aae199aade3981c1af6502cf42d0d8578e..f5078e06a73b41300dc1b856711129f885f65883 100644
--- a/applications/README.md
+++ b/applications/README.md
@@ -4,8 +4,10 @@ This directory contains the applications that are powered by Colossal-AI.
The list of applications include:
-- [X] [Chatbot](./Chat/README.md)
-- [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters
+- [X] [Colossal-LLaMA-2](./Colossal-LLaMA-2/): Continual Pre-training of LLaMA-2.
+- [X] [ColossalEval](./ColossalEval): Evaluation Pipeline for LLMs.
+- [X] [ColossalChat](./Chat/README.md): Replication of ChatGPT with RLHF.
+- [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters.
> Please note that the `Chatbot` application is migrated from the original `ChatGPT` folder.
diff --git a/colossalai/__init__.py b/colossalai/__init__.py
index f859161f78108e4c8b5cfba89e12026fee28da43..7da55590305b252ecc568bb91656f6d3ab38cdc5 100644
--- a/colossalai/__init__.py
+++ b/colossalai/__init__.py
@@ -1,11 +1,4 @@
-from .initialize import (
- get_default_parser,
- initialize,
- launch,
- launch_from_openmpi,
- launch_from_slurm,
- launch_from_torch,
-)
+from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch
try:
# .version will be created by setup.py
@@ -13,5 +6,7 @@ try:
except ModuleNotFoundError:
# this will only happen if the user did not run `pip install`
# and directly set PYTHONPATH to use Colossal-AI which is a bad practice
- __version__ = '0.0.0'
- print('please install Colossal-AI from https://www.colossalai.org/download or from source')
+ __version__ = "0.0.0"
+ print("please install Colossal-AI from https://www.colossalai.org/download or from source")
+
+__all__ = ["launch", "launch_from_openmpi", "launch_from_slurm", "launch_from_torch", "__version__"]
diff --git a/tests/test_layers/test_2d/checks_2d/__init__.py b/colossalai/_analyzer/__init__.py
similarity index 100%
rename from tests/test_layers/test_2d/checks_2d/__init__.py
rename to colossalai/_analyzer/__init__.py
diff --git a/colossalai/_analyzer/_subclasses/_meta_registration.py b/colossalai/_analyzer/_subclasses/_meta_registration.py
index 4049be79c70fc1d9c33807d74c2a00fe17c05acd..e8ba88b0406dd8a6ddd9dc00398ccc2b61390d6c 100644
--- a/colossalai/_analyzer/_subclasses/_meta_registration.py
+++ b/colossalai/_analyzer/_subclasses/_meta_registration.py
@@ -3,7 +3,7 @@
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
# for more meta_registrations
-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Union
import torch
from packaging import version
@@ -24,25 +24,23 @@ orig_empty_like = torch.empty_like
def new(*args, **kwargs):
- return orig_empty(*args, **kwargs, device=torch.device('meta'))
+ return orig_empty(*args, **kwargs, device=torch.device("meta"))
def new_strided(*args, **kwargs):
- return orig_empty_strided(*args, **kwargs, device=torch.device('meta'))
+ return orig_empty_strided(*args, **kwargs, device=torch.device("meta"))
def new_like(*args, **kwargs):
- return orig_empty_like(*args, **kwargs, device=torch.device('meta'))
+ return orig_empty_like(*args, **kwargs, device=torch.device("meta"))
def register_meta(op, register_dispatcher=True):
-
def wrapper(f):
-
def add_func(op):
meta_table[op] = f
if register_dispatcher:
- name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
+ name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__
try:
meta_lib.impl(name, f)
except:
@@ -54,7 +52,7 @@ def register_meta(op, register_dispatcher=True):
return wrapper
-if version.parse(torch.__version__) >= version.parse('1.12.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
# ============================== Convolutions ======================================
# https://github.com/pytorch/pytorch/pull/79834
@register_meta(aten.convolution.default)
@@ -69,7 +67,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
output_padding: List[int],
groups: int,
):
-
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
"""
Formula to apply to calculate the length of some dimension of the output
@@ -146,7 +143,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
kernel_size[i],
stride[i],
output_padding_list[i],
- ))
+ )
+ )
else:
ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
return ret_shape
@@ -180,19 +178,39 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
mem_fmt = pick_memory_format()
- out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
+ out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
return out
@register_meta(aten._convolution.default)
- def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
- padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
- *extra_args):
+ def meta__conv(
+ input_tensor: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor,
+ stride: List[int],
+ padding: List[int],
+ dilation: List[int],
+ is_transposed: bool,
+ output_padding: List[int],
+ groups: int,
+ *extra_args,
+ ):
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
return out
@register_meta(aten.convolution_backward.default)
- def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
- padding, dilation, transposed, output_padding, groups, output_mask):
+ def meta_conv_backward(
+ grad_output: torch.Tensor,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ bias_sizes,
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ output_mask,
+ ):
return new_like(input), new_like(weight), new((bias_sizes))
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
@@ -224,7 +242,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
batch_sizes,
dropout_state,
):
-
is_input_packed = len(batch_sizes) != 0
if is_input_packed:
seq_length = len(batch_sizes)
@@ -240,8 +257,11 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
if is_input_packed:
out_shape = [batch_sizes_sum, out_size * num_directions]
else:
- out_shape = ([mini_batch, seq_length, out_size *
- num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
+ out_shape = (
+ [mini_batch, seq_length, out_size * num_directions]
+ if batch_first
+ else [seq_length, mini_batch, out_size * num_directions]
+ )
output = input.new_empty(out_shape)
cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
@@ -257,15 +277,21 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
@register_meta(aten._cudnn_rnn_backward.default)
- def meta_cudnn_rnn_backward(input: torch.Tensor,
- weight: torch.Tensor,
- weight_stride0: int,
- hx: torch.Tensor,
- cx: Optional[torch.Tensor] = None,
- *args,
- **kwargs):
- return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new(
- ()) # (grad_input, grad_weight, grad_hx, grad_cx)
+ def meta_cudnn_rnn_backward(
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ weight_stride0: int,
+ hx: torch.Tensor,
+ cx: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ):
+ return (
+ new_like(input),
+ new_like(weight),
+ new_like(hx),
+ new_like(cx) if cx is not None else new(()),
+ ) # (grad_input, grad_weight, grad_hx, grad_cx)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
# ============================== Activations =======================================
@@ -278,7 +304,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten.hardtanh_backward.default,
]
- if version.parse(torch.__version__) < version.parse('2.0.0'):
+ if version.parse(torch.__version__) < version.parse("2.0.0"):
_unregistered_ewise += [
aten.prelu_backward.default,
]
@@ -296,37 +322,61 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm_backward.default)
- def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
- save_mean, save_invstd, train, eps, output_mask):
- return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
+ def meta_bn_backward(
+ dY: torch.Tensor,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ running_mean,
+ running_var,
+ save_mean,
+ save_invstd,
+ train,
+ eps,
+ output_mask,
+ ):
+ return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.cudnn_batch_norm.default)
def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
n_input = input.size(1)
- return new_like(input), new((n_input)), new((n_input)), new(
- (0), dtype=torch.uint8) # (output, running_mean, running_var, reserve)
+ return (
+ new_like(input),
+ new((n_input)),
+ new((n_input)),
+ new((0), dtype=torch.uint8),
+ ) # (output, running_mean, running_var, reserve)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
# NB: CuDNN only implements the backward algorithm for batchnorm
# in training mode (evaluation mode batchnorm has a different algorithm),
# which is why this doesn't accept a 'training' parameter.
@register_meta(aten.cudnn_batch_norm_backward.default)
- def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
- save_mean, save_invstd, eps, reserve):
- return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
+ def meta_cudnn_bn_backward(
+ dY: torch.Tensor,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ running_mean,
+ running_var,
+ save_mean,
+ save_invstd,
+ eps,
+ reserve,
+ ):
+ return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm.default)
def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
bs, n_input = input.size(0), input.size(1)
- return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var)
+ return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm_backward.default)
- def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
- grad_input_mask):
- return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta)
+ def meta_ln_backward(
+ dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask
+ ):
+ return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta)
# ================================== Misc ==========================================
# Maybe incorrect
@@ -355,8 +405,9 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
@register_meta(aten.embedding_dense_backward.default)
- def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
- scale_grad_by_freq):
+ def meta_embedding_dense_backward(
+ grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq
+ ):
return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout)
# ============================== Dropout ===========================================
@@ -364,14 +415,14 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
@register_meta(aten.native_dropout.default)
def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
# notice that mask is bool
- return new_like(input), new_like(input, dtype=torch.bool) # (output, mask)
+ return new_like(input), new_like(input, dtype=torch.bool) # (output, mask)
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
@register_meta(aten.native_dropout_backward.default)
def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
- return new_like(grad) # (grad_in)
+ return new_like(grad) # (grad_in)
- if version.parse(torch.__version__) < version.parse('1.13.0'):
+ if version.parse(torch.__version__) < version.parse("1.13.0"):
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
@register_meta(aten.eye.m_out)
def meta_eye(n: int, m: int, out: torch.Tensor):
@@ -385,24 +436,28 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
result: List[Optional[torch.Tensor]] = []
for i, index in enumerate(indices):
if index is not None:
- assert index.dtype in [torch.long, torch.int8, torch.bool],\
- "tensors used as indices must be long, byte or bool tensors"
+ assert index.dtype in [
+ torch.long,
+ torch.int8,
+ torch.bool,
+ ], "tensors used as indices must be long, byte or bool tensors"
if index.dtype in [torch.int8, torch.bool]:
nonzero = index.nonzero()
k = len(result)
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
for j in range(index.ndim):
- assert index.shape[j] == self.shape[
- k +
- j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
+ assert (
+ index.shape[j] == self.shape[k + j]
+ ), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
result.append(nonzero.select(1, j))
else:
result.append(index)
else:
result.append(index)
indices = result
- assert len(
- indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
+ assert (
+ len(indices) <= self.ndim
+ ), f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
# expand_outplace
import torch._refs as refs
diff --git a/colossalai/_analyzer/_subclasses/_monkey_patch.py b/colossalai/_analyzer/_subclasses/_monkey_patch.py
index b3ec98f0811f265e370c638369269f73e589dfd6..503981409ccaced2b6371b78c9a8e9bf77e41c06 100644
--- a/colossalai/_analyzer/_subclasses/_monkey_patch.py
+++ b/colossalai/_analyzer/_subclasses/_monkey_patch.py
@@ -1,5 +1,4 @@
import torch
-import torch.distributed as dist
from packaging import version
__all__ = [
@@ -48,7 +47,7 @@ _DistCommMethod = [
"scatter",
]
-if version.parse(torch.__version__) >= version.parse('1.12.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
aten = torch.ops.aten
# TODO: dive deep here
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
diff --git a/colossalai/_analyzer/_subclasses/flop_tensor.py b/colossalai/_analyzer/_subclasses/flop_tensor.py
index 59991dc5091254631dd4c641c528b09aaf97c285..9d52c5593bb8b03c3253e35aa8c87b33ca19ebd4 100644
--- a/colossalai/_analyzer/_subclasses/flop_tensor.py
+++ b/colossalai/_analyzer/_subclasses/flop_tensor.py
@@ -8,7 +8,7 @@ from contextlib import contextmanager
from enum import Enum, auto
from functools import partial, reduce
from numbers import Number
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, List, Union
import torch
from packaging import version
@@ -36,15 +36,15 @@ def _format_flops(flop):
B = 1e9
T = 1e12
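+    # e.g. 1.5e10 falls in [B, T), so it is formatted as "15.00B".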
if flop < K:
- return f'{flop:.2f}'
+ return f"{flop:.2f}"
elif flop < M:
- return f'{flop / K:.2f}K'
+ return f"{flop / K:.2f}K"
elif flop < B:
- return f'{flop / M:.2f}M'
+ return f"{flop / M:.2f}M"
elif flop < T:
- return f'{flop / B:.2f}B'
+ return f"{flop / B:.2f}B"
else:
- return f'{flop / T:.2f}T'
+ return f"{flop / T:.2f}T"
def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: bool = False, **kwargs) -> Number:
@@ -59,11 +59,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
Returns:
Number: The total number of floating point operations (FWD + BWD).
"""
- maybe_inplace = (getattr(module, 'inplace', False) or kwargs.get('inplace', False)
- or getattr(module, '__name__', None) in ('add_', 'mul_', 'div_', 'sub_'))
+ maybe_inplace = (
+ getattr(module, "inplace", False)
+ or kwargs.get("inplace", False)
+ or getattr(module, "__name__", None) in ("add_", "mul_", "div_", "sub_")
+ )
class DummyModule(torch.nn.Module):
-
def __init__(self, func):
super().__init__()
self.func = func
@@ -74,21 +76,20 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
total_flop_count = {Phase.FWD: 0, Phase.BWD: 0}
flop_counts = defaultdict(lambda: defaultdict(int))
- parents = ['Global']
+ parents = ["Global"]
module = module if isinstance(module, torch.nn.Module) else DummyModule(module)
class FlopTensor(MetaTensor):
_tensor: torch.Tensor
def __repr__(self):
- name = 'FlopParameter' if getattr(self, '_is_param', False) else 'FlopTensor'
+ name = "FlopParameter" if getattr(self, "_is_param", False) else "FlopTensor"
if self.grad_fn:
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-
# no_dispatch is only needed if you use enable_python_mode.
# It prevents infinite recursion.
rs = super().__torch_dispatch__(func, types, args, kwargs)
@@ -115,9 +116,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
return isinstance(x, torch.Tensor) and x.is_floating_point()
def create_backwards_push(name):
-
class PushState(torch.autograd.Function):
-
@staticmethod
def forward(ctx, *args):
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
@@ -134,9 +133,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
return PushState.apply
def create_backwards_pop(name):
-
class PopState(torch.autograd.Function):
-
@staticmethod
def forward(ctx, *args):
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
@@ -147,14 +144,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
@staticmethod
def backward(ctx, *grad_outs):
nonlocal parents
- assert (parents[-1] == name)
+ assert parents[-1] == name
parents.pop()
return grad_outs
return PopState.apply
def enter_module(name):
-
def f(module, inputs):
nonlocal parents
parents.append(name)
@@ -165,10 +161,9 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
return f
def exit_module(name):
-
def f(module, inputs, outputs):
nonlocal parents
- assert (parents[-1] == name)
+ assert parents[-1] == name
parents.pop()
outputs = normalize_tuple(outputs)
return create_backwards_push(name)(*outputs)
@@ -189,7 +184,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
for mod in flop_counts.keys():
print(f"Module: ", mod)
for k, v in flop_counts[mod].items():
- print('\t', k, _format_flops(v))
+ print("\t", k, _format_flops(v))
print()
def detach_variables(r):
@@ -201,7 +196,7 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose:
def wrap(r):
if isinstance(r, torch.Tensor):
- data_ptr_fn = getattr(r, '_tensor', r).data_ptr
+ data_ptr_fn = getattr(r, "_tensor", r).data_ptr
r = FlopTensor(detach_variables(r))
if maybe_inplace:
r = r + 0
@@ -375,8 +370,11 @@ def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
# Inputs[0] contains the shape of the input.
input_shape = inputs[input_arg_index].shape
- has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index],
- 'shape') else inputs[affine_arg_index]
+ has_affine = (
+ inputs[affine_arg_index].shape is not None
+ if hasattr(inputs[affine_arg_index], "shape")
+ else inputs[affine_arg_index]
+ )
assert 2 <= len(input_shape) <= 5, input_shape
# 5 is just a rough estimate
flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4)
@@ -390,7 +388,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N
training = inputs[-3]
assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
if training:
- return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
+ return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
has_affine = inputs[1].shape is not None
input_shape = reduce(operator.mul, inputs[0].shape)
return input_shape * (2 if has_affine else 1)
@@ -420,33 +418,30 @@ def ewise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Calla
def zero_flop_jit(*args):
"""
- Count flops for zero flop layers.
+ Count flops for zero flop layers.
"""
return 0
-if version.parse(torch.__version__) >= version.parse('1.12.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
flop_mapping = {
- # gemm
+ # gemm
aten.mm.default: matmul_flop_jit,
aten.matmul.default: matmul_flop_jit,
aten.addmm.default: addmm_flop_jit,
aten.bmm.default: bmm_flop_jit,
-
- # convolution
+ # convolution
aten.convolution.default: conv_flop_jit,
aten._convolution.default: conv_flop_jit,
aten.convolution_backward.default: conv_backward_flop_jit,
-
- # normalization
+ # normalization
aten.native_batch_norm.default: batchnorm_flop_jit,
aten.native_batch_norm_backward.default: batchnorm_flop_jit,
aten.cudnn_batch_norm.default: batchnorm_flop_jit,
aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
aten.native_layer_norm.default: norm_flop_counter(2, 0),
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
-
- # pooling
+ # pooling
aten.avg_pool1d.default: ewise_flop_counter(1, 0),
aten.avg_pool2d.default: ewise_flop_counter(1, 0),
aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
@@ -469,7 +464,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
}
ewise_flop_aten = [
- # basic op
+ # basic op
aten.add.Tensor,
aten.add_.Tensor,
aten.div.Tensor,
@@ -485,8 +480,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten.sum.default,
aten.sum.dim_IntList,
aten.mean.dim,
-
- # activation op
+ # activation op
aten.hardswish.default,
aten.hardswish_.default,
aten.hardswish_backward.default,
@@ -509,15 +503,12 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten.tanh.default,
aten.tanh_backward.default,
aten.threshold_backward.default,
-
- # dropout
+ # dropout
aten.native_dropout.default,
aten.native_dropout_backward.default,
-
- # distribution
+ # distribution
aten.bernoulli_.float,
-
- # where
+ # where
aten.where.self,
]
for op in ewise_flop_aten:
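For context on how these tables are consumed: `flop_count` runs the target under `FlopTensor` dispatch (wrapping bare callables in `DummyModule`) and sums the per-phase counts produced by `flop_mapping`. A minimal usage sketch, assuming `colossalai` is installed and the import path below matches this PR's file layout:

```python
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_count  # path assumed

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU())
x = torch.randn(8, 64)

# Returns total FWD + BWD floating point operations; verbose=True prints
# the per-module breakdown through _format_flops.
total = flop_count(model, x, verbose=True)
print(total)
```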
diff --git a/colossalai/_analyzer/_subclasses/meta_tensor.py b/colossalai/_analyzer/_subclasses/meta_tensor.py
index 2bc212938ee08f1143b5dda354d9ec600dd64662..8be97d01343ebc0dd4c4f79c622a63a47bf8c769 100644
--- a/colossalai/_analyzer/_subclasses/meta_tensor.py
+++ b/colossalai/_analyzer/_subclasses/meta_tensor.py
@@ -3,12 +3,12 @@ from functools import partial
import torch
import torch.distributed as dist
-from torch.types import _bool, _device, _dtype
-from torch.utils._pytree import tree_flatten, tree_map
+from torch.types import _device
+from torch.utils._pytree import tree_map
from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod
-__all__ = ['MetaTensor', 'MetaTensorMode']
+__all__ = ["MetaTensor", "MetaTensorMode"]
def register_storage(r, data_ptr_fn=None):
@@ -28,8 +28,7 @@ def _normalize_tuple(x):
# a hack of inplace execution in PyTorch
def _assert_alias(func):
- return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen # TODO: check if should be this aggressive
- )
+ return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen)  # TODO: check whether it should be this aggressive
class MetaTensor(torch.Tensor):
@@ -65,14 +64,15 @@ class MetaTensor(torch.Tensor):
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
- device=device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')),
- requires_grad=requires_grad) # deceive the frontend for aten selections
+ device=device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
+ requires_grad=requires_grad,
+ ) # deceive the frontend for aten selections
r._tensor = elem
# ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta:
val = elem.data_ptr()
data_ptr_fn = lambda: val
- r._tensor = r._tensor.to(torch.device('meta'))
+ r._tensor = r._tensor.to(torch.device("meta"))
# only tensor not on `meta` should be copied to `meta`
register_storage(r._tensor, data_ptr_fn)
@@ -81,7 +81,7 @@ class MetaTensor(torch.Tensor):
return r
def __repr__(self):
- name = 'MetaParameter' if getattr(self, '_is_param', False) else 'MetaTensor'
+ name = "MetaParameter" if getattr(self, "_is_param", False) else "MetaTensor"
if self.grad_fn:
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
@@ -97,15 +97,15 @@ class MetaTensor(torch.Tensor):
x = x._tensor
elif isinstance(x, torch.Tensor):
device = x.device
- x = x.to(torch.device('meta'))
+ x = x.to(torch.device("meta"))
return x
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
- if 'device' in kwargs:
- device = kwargs['device']
- kwargs['device'] = torch.device('meta')
+ if "device" in kwargs:
+ device = kwargs["device"]
+ kwargs["device"] = torch.device("meta")
# run aten for backend=CPU but actually on backend=Meta
# here we detect whether or not the execution generates a physical copy
@@ -143,21 +143,21 @@ class MetaTensor(torch.Tensor):
nonlocal device
if isinstance(x, str) or isinstance(x, _device):
device = x
- return torch.device('meta')
+ return torch.device("meta")
return x
elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
return MetaTensor(elem, device=device)
def cpu(self, *args, **kwargs):
- if self.device.type == 'cpu':
+ if self.device.type == "cpu":
return self.to(*args, **kwargs)
- return self.to(*args, device='cpu', **kwargs)
+ return self.to(*args, device="cpu", **kwargs)
def cuda(self, device=None, non_blocking=False):
if device is not None:
return self.to(device=device, non_blocking=non_blocking)
- return self.to(device='cuda:0', non_blocking=non_blocking)
+ return self.to(device="cuda:0", non_blocking=non_blocking)
def data_ptr(self):
return self._tensor.data_ptr()
@@ -177,19 +177,17 @@ class MetaTensorMode(object):
"""
def __init__(self):
- self.torch_overrides = {} # override torch.xxx
- self.dist_overrides = {} # override torch.distributed.xxx
+ self.torch_overrides = {} # override torch.xxx
+ self.dist_overrides = {} # override torch.distributed.xxx
def __enter__(self):
-
def _dummy(*args, **kwargs):
pass
def _new(*args, orig_new=torch.empty, **kwargs):
- return MetaTensor(orig_new(*args, **{
- **kwargs, 'device': 'meta'
- }),
- device=kwargs.get('device', torch.device('cpu')))
+ return MetaTensor(
+ orig_new(*args, **{**kwargs, "device": "meta"}), device=kwargs.get("device", torch.device("cpu"))
+ )
for func in _TorchOverrideableFactoryMethod:
self.torch_overrides[func] = getattr(torch, func)
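A sketch of what entering `MetaTensorMode` buys: factory calls are rerouted through `_new`, so storage lands on the `meta` device while the tensor still reports the device the caller asked for (import path assumed):

```python
import torch
from colossalai._analyzer._subclasses.meta_tensor import MetaTensorMode  # path assumed

with MetaTensorMode():
    # No real 4 GiB allocation happens; the payload lives on the meta device.
    x = torch.empty(32768, 32768, device="cuda:0")

print(type(x).__name__)   # MetaTensor
print(x.device)           # cuda:0 -- the frontend-visible device
print(x._tensor.is_meta)  # True
```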
diff --git a/colossalai/_analyzer/fx/codegen.py b/colossalai/_analyzer/fx/codegen.py
index 41d74f2e3719a0e56adc61a6c35684584a2d80c7..cd244b22cac0e143f990e2100a6e080708960b08 100644
--- a/colossalai/_analyzer/fx/codegen.py
+++ b/colossalai/_analyzer/fx/codegen.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, List, Tuple
+from typing import Any, Dict, List, Tuple
import torch
@@ -22,7 +22,7 @@ from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_a
import colossalai
from colossalai.fx._compatibility import compatibility
-_register_custom_builtin('colossalai', 'import colossalai', colossalai)
+_register_custom_builtin("colossalai", "import colossalai", colossalai)
def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
@@ -43,17 +43,17 @@ def _gen_ckpt_usage(label, input_vars, output_vars, use_reentrant=True):
"""
Generate the checkpoint function call code text
"""
- outputs = ', '.join(output_vars)
- inputs = ', '.join(input_vars)
- return f'{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})'
+ outputs = ", ".join(output_vars)
+ inputs = ", ".join(input_vars)
+ return f"{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})"
def _end_of_ckpt(node: Node, ckpt_level: int) -> bool:
"""
Check if the node could end the ckpt region at `ckpt_level`
"""
- if len(node.meta['info'].activation_checkpoint) > ckpt_level:
- return node.meta['info'].activation_checkpoint[ckpt_level] is not None
+ if len(node.meta["info"].activation_checkpoint) > ckpt_level:
+ return node.meta["info"].activation_checkpoint[ckpt_level] is not None
return True
@@ -94,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
current_region = None
for idx, node in enumerate(node_list):
- if len(node.meta['info'].activation_checkpoint) > ckpt_level:
- act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level]
+ if len(node.meta["info"].activation_checkpoint) > ckpt_level:
+ act_ckpt_label = node.meta["info"].activation_checkpoint[ckpt_level]
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
@@ -131,13 +131,9 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
return ckpt_regions
-def emit_ckpt_func(body,
- ckpt_func,
- node_list: List[Node],
- emit_node_func,
- delete_unused_value_func,
- ckpt_level=0,
- in_ckpt=False):
+def emit_ckpt_func(
+ body, ckpt_func, node_list: List[Node], emit_node_func, delete_unused_value_func, ckpt_level=0, in_ckpt=False
+):
"""Emit ckpt function in nested way
Args:
@@ -156,12 +152,12 @@ def emit_ckpt_func(body,
# label given by each layer, e.g. if you are currently at level (0, 1, 1)
# the label will be '0_1_1'
- label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]])
+ label = "_".join([str(idx) for idx in node_list[0].meta["info"].activation_checkpoint[: ckpt_level + 1]])
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
- ckpt_func.append(f'{ckpt_fn_def}\n')
+ ckpt_func.append(f"{ckpt_fn_def}\n")
# if there is more level to fetch
- if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)):
+ if ckpt_level + 1 < max(map(lambda node: len(node.meta["info"].activation_checkpoint), node_list)):
ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
@@ -174,33 +170,40 @@ def emit_ckpt_func(body,
break
if node_idx in start_idx:
- ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
- emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, delete_unused_value_func,
- ckpt_level + 1, True)
+ ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
+ emit_ckpt_func(
+ ckpt_func,
+ ckpt_func_buffer,
+ ckpt_node_list,
+ emit_node_func,
+ delete_unused_value_func,
+ ckpt_level + 1,
+ True,
+ )
node_idx += len(ckpt_node_list)
else:
node = node_list[node_idx]
emit_node_func(node, ckpt_func)
- ckpt_func[-1] = ' ' + ckpt_func[-1]
+ ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
node_idx += 1
- ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
+ ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
ckpt_func += ckpt_func_buffer
# last level
else:
for node in node_list:
emit_node_func(node, ckpt_func)
- ckpt_func[-1] = ' ' + ckpt_func[-1]
+ ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
- ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
+ ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
- usage = _gen_ckpt_usage(label, inputs, outputs, False) + '\n'
+ usage = _gen_ckpt_usage(label, inputs, outputs, False) + "\n"
if in_ckpt:
- usage = ' ' + usage
+ usage = " " + usage
body.append(usage)
@@ -229,7 +232,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# process ckpt_regions
if node_idx in start_idx:
- ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
+ ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
node_idx += len(ckpt_node_list)
@@ -243,7 +246,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
@compatibility(is_backward_compatible=True)
class ActivationCheckpointCodeGen(CodeGen):
-
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
free_vars: List[str] = []
body: List[str] = []
@@ -251,7 +253,7 @@ class ActivationCheckpointCodeGen(CodeGen):
wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference
- maybe_return_annotation: List[str] = ['']
+ maybe_return_annotation: List[str] = [""]
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
@@ -259,7 +261,7 @@ class ActivationCheckpointCodeGen(CodeGen):
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
- if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
+ if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
@@ -281,16 +283,16 @@ class ActivationCheckpointCodeGen(CodeGen):
def type_repr(o: Any):
if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()]
- return '()'
+ return "()"
typename = _type_repr(o)
- if hasattr(o, '__origin__'):
+ if hasattr(o, "__origin__"):
# This is a generic type, e.g. typing.List[torch.Tensor]
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type)
- if hasattr(o, '__args__'):
+ if hasattr(o, "__args__"):
# Assign global names for each of the inner type variables.
args = [type_repr(arg) for arg in o.__args__]
@@ -309,19 +311,18 @@ class ActivationCheckpointCodeGen(CodeGen):
return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
-
def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global.
- if isinstance(arg, tuple) and hasattr(arg, '_fields'):
+ if isinstance(arg, tuple) and hasattr(arg, "_fields"):
qualified_name = _get_qualified_name(type(arg))
global_name = add_global(qualified_name, type(arg))
return f"{global_name}{repr(tuple(arg))}"
return repr(arg)
- args_s = ', '.join(_get_repr(a) for a in args)
- kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
+ args_s = ", ".join(_get_repr(a) for a in args)
+ kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
if args_s and kwargs_s:
- return f'{args_s}, {kwargs_s}'
+ return f"{args_s}, {kwargs_s}"
return args_s or kwargs_s
# Run through reverse nodes and record the first instance of a use
@@ -347,82 +348,94 @@ class ActivationCheckpointCodeGen(CodeGen):
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
- if user.op == 'placeholder':
+ if user.op == "placeholder":
return
- if user.op == 'output':
- body.append('\n')
+ if user.op == "output":
+ body.append("\n")
return
nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete):
- to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
- body.append(f'; {to_delete_str}\n')
+ to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
+ body.append(f"; {to_delete_str}\n")
else:
- body.append('\n')
+ body.append("\n")
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
- maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
- if node.op == 'placeholder':
+ maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
+ if node.op == "placeholder":
assert isinstance(node.target, str)
- maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
- free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
- raw_name = node.target.replace('*', '')
+ maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
+ free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
+ raw_name = node.target.replace("*", "")
if raw_name != repr(node):
- body.append(f'{repr(node)} = {raw_name}\n')
+ body.append(f"{repr(node)} = {raw_name}\n")
return
- elif node.op == 'call_method':
+ elif node.op == "call_method":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
- f'({_format_args(node.args[1:], node.kwargs)})')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
+ f"({_format_args(node.args[1:], node.kwargs)})"
+ )
return
- elif node.op == 'call_function':
+ elif node.op == "call_function":
assert callable(node.target)
# pretty print operators
- if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
+ if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
- body.append(f'{repr(node)}{maybe_type_annotation} = '
- f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
+ )
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
- if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
- body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; '
- f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
+ if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
+ body.append(
+ f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
+ f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
+ )
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
- if global_name == 'getattr' and \
- isinstance(node.args, tuple) and \
- isinstance(node.args[1], str) and \
- node.args[1].isidentifier() and \
- len(node.args) == 2:
+ if (
+ global_name == "getattr"
+ and isinstance(node.args, tuple)
+ and isinstance(node.args[1], str)
+ and node.args[1].isidentifier()
+ and len(node.args) == 2
+ ):
body.append(
- f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
+ )
return
body.append(
- f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
- if node.meta.get('is_wrapped', False):
+ f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
+ )
+ if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name)
return
- elif node.op == 'call_module':
+ elif node.op == "call_module":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = '
- f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
+ )
return
- elif node.op == 'get_attr':
+ elif node.op == "get_attr":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
+ body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
return
- elif node.op == 'output':
+ elif node.op == "output":
if node.type is not None:
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
body.append(self.generate_output(node.args[0]))
return
- raise NotImplementedError(f'node: {node.op} {node.target}')
+ raise NotImplementedError(f"node: {node.op} {node.target}")
# Modified for activation checkpointing
ckpt_func = []
@@ -432,13 +445,13 @@ class ActivationCheckpointCodeGen(CodeGen):
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
- body.append('pass\n')
+ body.append("pass\n")
if len(wrapped_fns) > 0:
- wrap_name = add_global('wrap', torch.fx.wrap)
- wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
+ wrap_name = add_global("wrap", torch.fx.wrap)
+ wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else:
- wrap_stmts = ''
+ wrap_stmts = ""
if self._body_transformer:
body = self._body_transformer(body)
@@ -447,11 +460,11 @@ class ActivationCheckpointCodeGen(CodeGen):
add_global(name, value)
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
- prologue = ''.join(ckpt_func) + prologue
+ prologue = "".join(ckpt_func) + prologue
prologue = prologue
- code = ''.join(body)
- code = '\n'.join(' ' + line for line in code.split('\n'))
+ code = "".join(body)
+ code = "\n".join(" " + line for line in code.split("\n"))
fn_code = f"""
{wrap_stmts}
{prologue}
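To make the emitted structure concrete, the string helpers at the top of this file can be exercised directly; a small sketch (module path assumed):

```python
from colossalai._analyzer.fx.codegen import _gen_ckpt_usage  # path assumed

# A nested region at level (0, 1) becomes a method named checkpoint_0_1,
# and the generated forward routes through torch.utils.checkpoint:
print(_gen_ckpt_usage("0_1", ["x"], ["out"], use_reentrant=False))
# out = torch.utils.checkpoint.checkpoint(self.checkpoint_0_1, x, use_reentrant=False)
```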
diff --git a/colossalai/_analyzer/fx/graph_module.py b/colossalai/_analyzer/fx/graph_module.py
index 1fdedd758c01f8fef9c0ba2e529b528527d8d527..9d3999e322b90619c5d569d48b7c202c8e167296 100644
--- a/colossalai/_analyzer/fx/graph_module.py
+++ b/colossalai/_analyzer/fx/graph_module.py
@@ -13,6 +13,7 @@ from torch.fx.graph import PythonCode
try:
from torch.fx.graph import _PyTreeCodeGen
+
SUPPORT_PT_CODEGEN = True
except ImportError:
SUPPORT_PT_CODEGEN = False
@@ -24,7 +25,6 @@ from torch.nn.modules.module import _addindent
# This is a copy of torch.fx.graph_module._WrappedCall.
# It should be removed when we stop supporting torch < 1.12.0.
class _WrappedCall:
-
def __init__(self, cls, cls_call):
self.cls = cls
self.cls_call = cls_call
@@ -50,12 +50,14 @@ class _WrappedCall:
# constituent substrings of the error message
tb_repr = traceback.format_exc()
- custom_msg = ("Call using an FX-traced Module, "
- f"line {err_lineno} of the traced Module's "
- "generated forward function:")
- before_err = "".join(all_src_lines[err_lineno - 2:err_lineno])
+ custom_msg = (
+ "Call using an FX-traced Module, "
+ f"line {err_lineno} of the traced Module's "
+ "generated forward function:"
+ )
+ before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
marker = "~" * err_line_len + "~~~ <--- HERE"
- err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2])
+ err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])
# joined message
return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
@@ -65,11 +67,14 @@ class _WrappedCall:
if self.cls_call is not None:
return self.cls_call(obj, *args, **kwargs)
else:
- return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
+ return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
except Exception as e:
assert e.__traceback__
- topmost_framesummary: traceback.FrameSummary = \
- traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type]
+ topmost_framesummary: traceback.FrameSummary = traceback.StackSummary.extract(
+ traceback.walk_tb(e.__traceback__)
+ )[
+ -1
+ ] # type: ignore[arg-type]
if "eval_with_key" in topmost_framesummary.filename:
print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr)
raise e.with_traceback(None)
@@ -99,10 +104,9 @@ class ColoGraphModule(torch.fx.GraphModule):
code.
"""
- def __init__(self,
- root: Union[torch.nn.Module, Dict[str, Any]],
- graph: torch.fx.Graph,
- class_name: str = 'GraphModule'):
+ def __init__(
+ self, root: Union[torch.nn.Module, Dict[str, Any]], graph: torch.fx.Graph, class_name: str = "GraphModule"
+ ):
super().__init__(root, graph, class_name)
def bind(self, ckpt_def, globals):
@@ -134,7 +138,7 @@ class ColoGraphModule(torch.fx.GraphModule):
if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
- python_code = self._graph.python_code(root_module='self')
+ python_code = self._graph.python_code(root_module="self")
self._code = python_code.src
# To split ckpt functions code and forward code
@@ -157,8 +161,8 @@ class ColoGraphModule(torch.fx.GraphModule):
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
cls_call = cls.__call__ if "__call__" in vars(cls) else None
- if '_wrapped_call' not in vars(cls):
- cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
+ if "_wrapped_call" not in vars(cls):
+ cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
def call_wrapped(self, *args, **kwargs):
return self._wrapped_call(self, *args, **kwargs)
@@ -182,7 +186,7 @@ class ColoGraphModule(torch.fx.GraphModule):
"""
folder = Path(folder)
Path(folder).mkdir(exist_ok=True)
- torch.save(self.state_dict(), folder / 'state_dict.pt')
+ torch.save(self.state_dict(), folder / "state_dict.pt")
tab = " " * 4
# we add import colossalai here
@@ -208,10 +212,10 @@ class {module_name}(torch.nn.Module):
for module_name, module in self.named_children():
module_str = _gen_model_repr(module_name, module)
if module_str is None:
- module_file = folder / f'{module_name}.pt'
+ module_file = folder / f"{module_name}.pt"
torch.save(module, module_file)
blobified_modules.append(module_name)
- module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
+ module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
module_str = f"torch.load(r'{module_file}') # {module_repr}"
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
@@ -228,12 +232,14 @@ class {module_name}(torch.nn.Module):
model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
model_str += f"{_addindent(self.code, 4)}\n"
- module_file = folder / 'module.py'
+ module_file = folder / "module.py"
module_file.write_text(model_str)
- init_file = folder / '__init__.py'
- init_file.write_text('from .module import *')
+ init_file = folder / "__init__.py"
+ init_file.write_text("from .module import *")
if len(blobified_modules) > 0:
- warnings.warn("Was not able to save the following children modules as reprs -"
- f"saved as pickled files instead: {blobified_modules}")
+ warnings.warn(
+ "Was not able to save the following children modules as reprs -"
+ f"saved as pickled files instead: {blobified_modules}"
+ )
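A usage sketch for `ColoGraphModule`; the `(root, graph)` contract matches `torch.fx.GraphModule`, and the files written (`state_dict.pt`, `module.py`, `__init__.py`) follow the hunk above. The `to_folder` call signature is assumed to mirror upstream `torch.fx`:

```python
import torch
import torch.fx
from colossalai._analyzer.fx.graph_module import ColoGraphModule  # path from this diff

net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
graph = torch.fx.symbolic_trace(net).graph

gm = ColoGraphModule(net, graph)                 # (root, graph[, class_name]) as above
gm.to_folder("exported_model", "ExportedModel")  # writes state_dict.pt + module.py + __init__.py
```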
diff --git a/colossalai/_analyzer/fx/node_util.py b/colossalai/_analyzer/fx/node_util.py
index fbe8400a437eef944b6e8ac0a1126774797da031..d2671787ea63899db3aa71ff10d055a665291f18 100644
--- a/colossalai/_analyzer/fx/node_util.py
+++ b/colossalai/_analyzer/fx/node_util.py
@@ -1,9 +1,9 @@
from dataclasses import dataclass, field
-from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
import torch
-from torch.autograd.profiler_util import _format_memory, _format_time
-from torch.fx import Graph, GraphModule, Node
+from torch.autograd.profiler_util import _format_memory
+from torch.fx import Node
from colossalai._analyzer.envs import MeshConfig
@@ -85,12 +85,12 @@ class MetaInfo:
node: Node
# directory
- mod_dir: str = ''
+ mod_dir: str = ""
# ctx[data_ptr] = Tensor
# mark the storage for ctx.save_for_backward
- global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared
- curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node
+ global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared
+ curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node
# should be updated after each graph manipulation
# ============================== Update ====================================
@@ -100,7 +100,7 @@ class MetaInfo:
inputs: Tuple[torch.Tensor] = ()
outputs: Tuple[torch.Tensor] = ()
- is_alias: Tuple[bool] = () # whether the output is an alias of input
+ is_alias: Tuple[bool] = () # whether the output is an alias of input
# compute cost
fwd_flop: Optional[int] = 0
@@ -112,29 +112,29 @@ class MetaInfo:
# should keep the same whenever manipulated
# ============================= Invariant ==================================
- activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen
+ activation_checkpoint: Tuple[int, ...] = ()  # (region_0, region_1, ...) supports nested codegen
to_offload: Optional[bool] = False
- sharding_spec: str = 'RR'
+ sharding_spec: str = "RR"
def __new__(cls, node: Node, **kwargs):
orig_init = cls.__init__
# if initialized, return the existing one
# should disable the __init__ function
- if node.meta.get('info', None) is not None:
+ if node.meta.get("info", None) is not None:
def _dummy(self, *args, **kwargs):
- if getattr(self, '_is_init', False):
+ if getattr(self, "_is_init", False):
self._is_init = True
orig_init(self, *args, **kwargs)
cls.__init__ = orig_init
cls.__init__ = _dummy
- return node.meta['info']
+ return node.meta["info"]
return super().__new__(cls)
def __post_init__(self):
- self.node.meta['info'] = self
+ self.node.meta["info"] = self
@property
def fwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH):
@@ -188,24 +188,26 @@ class MetaInfo:
return compute_size_in_bytes(self.inputs)
def __repr__(self):
- s = f'Node {self.node.name}'
+ s = f"Node {self.node.name}"
if self.parameters:
- s += f'\n\thas parameter of size {_format_memory(self.param_size)}'
+ s += f"\n\thas parameter of size {_format_memory(self.param_size)}"
if self.buffers:
- s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}'
+ s += f"\n\thas buffer of size {_format_memory(self.buffer_size)}"
if self.output_size:
- s += f'\n\thas output activation of size {_format_memory(self.output_size)}'
+ s += f"\n\thas output activation of size {_format_memory(self.output_size)}"
# if self.total_size:
# s += f'\n\thas total activation of size {_format_memory(self.total_size)}'
if self.temp_size:
- s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}'
+ s += f"\n\thas temp activation of size {_format_memory(self.temp_size)}"
if self.backward_size:
- s += f'\n\thas backward activation of size {_format_memory(self.backward_size)}'
- s += f'\n\tfwd_flop = {self.fwd_flop}'\
- f'\n\tbwd_flop = {self.bwd_flop}'\
- f'\n\tfwd_comm = {self.fwd_comm}'\
- f'\n\tbwd_comm = {self.bwd_comm}'\
- f'\n\tto_recompute = {self.to_recompute}'\
- f'\n\tto_offload = {self.to_offload}'\
- f'\n\tsharding_spec = {self.sharding_spec}'
+ s += f"\n\thas backward activation of size {_format_memory(self.backward_size)}"
+ s += (
+ f"\n\tfwd_flop = {self.fwd_flop}"
+ f"\n\tbwd_flop = {self.bwd_flop}"
+ f"\n\tfwd_comm = {self.fwd_comm}"
+ f"\n\tbwd_comm = {self.bwd_comm}"
+ f"\n\tto_recompute = {self.to_recompute}"
+ f"\n\tto_offload = {self.to_offload}"
+ f"\n\tsharding_spec = {self.sharding_spec}"
+ )
return s
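The `__new__`/`__post_init__` pair above makes `MetaInfo` a per-node singleton stored under `node.meta["info"]`; a minimal sketch (import path assumed):

```python
import torch.fx
from colossalai._analyzer.fx.node_util import MetaInfo  # path assumed

g = torch.fx.Graph()
n = g.placeholder("x")

info = MetaInfo(n, mod_dir="self")  # first construction attaches itself to n.meta["info"]
assert MetaInfo(n) is info          # later constructions return the existing record
assert n.meta["info"] is info
```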
diff --git a/colossalai/_analyzer/fx/passes/graph_profile.py b/colossalai/_analyzer/fx/passes/graph_profile.py
index c3e760b31e96df2cd4f2de52c6309fefa9183417..158ebce219cd62c98c9fa2b064a8489d70868c39 100644
--- a/colossalai/_analyzer/fx/passes/graph_profile.py
+++ b/colossalai/_analyzer/fx/passes/graph_profile.py
@@ -1,8 +1,8 @@
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple
import torch
import torch.fx
-from torch.autograd.profiler_util import _format_memory, _format_time
+from torch.autograd.profiler_util import _format_memory
from torch.fx import GraphModule
from torch.fx.node import Argument, Node, Target
@@ -13,14 +13,14 @@ from colossalai._analyzer.fx.node_util import MetaInfo
def _format_flops(flops: float) -> str:
"""Returns a formatted FLOP size string"""
if flops > 1e12:
- return f'{flops / 1e12:.2f} TFLOPs'
+ return f"{flops / 1e12:.2f} TFLOPs"
elif flops > 1e9:
- return f'{flops / 1e9:.2f} GFLOPs'
+ return f"{flops / 1e9:.2f} GFLOPs"
elif flops > 1e6:
- return f'{flops / 1e6:.2f} MFLOPs'
+ return f"{flops / 1e6:.2f} MFLOPs"
elif flops > 1e3:
- return f'{flops / 1e3:.2f} kFLOPs'
- return f'{flops} FLOPs'
+ return f"{flops / 1e3:.2f} kFLOPs"
+ return f"{flops} FLOPs"
def _denormalize_tuple(t: Tuple[int, ...]) -> Tuple[int, ...]:
@@ -42,10 +42,11 @@ class GraphProfiler(torch.fx.Interpreter):
Fetch shape argument from ``ShapeProp`` without re-executing
the ``GraphModule`` from scratch.
"""
+
_profileable = [
- 'call_function',
- 'call_module',
- 'call_method',
+ "call_function",
+ "call_module",
+ "call_method",
]
def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
@@ -77,14 +78,13 @@ class GraphProfiler(torch.fx.Interpreter):
self.args_iter: Iterator[Any] = iter(args)
for node in self.module.graph.nodes:
-
- self.run_node(node) # No need to store.
+ self.run_node(node) # No need to store.
if self.garbage_collect_values:
for to_delete in self.user_to_last_uses.get(node, []):
del self.env[to_delete]
- if node.op == 'output':
+ if node.op == "output":
output_val = self.env[node]
return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val
@@ -133,9 +133,11 @@ class GraphProfiler(torch.fx.Interpreter):
try:
from tabulate import tabulate
except ImportError:
- print("`summary` relies on the library `tabulate`, "
- "which could not be found on this machine. Run `pip "
- "install tabulate` to install the library.")
+ print(
+ "`summary` relies on the library `tabulate`, "
+ "which could not be found on this machine. Run `pip "
+ "install tabulate` to install the library."
+ )
# Build up a list of summary information for each node
node_summaries: List[List[Any]] = []
@@ -145,36 +147,38 @@ class GraphProfiler(torch.fx.Interpreter):
node: Node
n_info = MetaInfo(node)
last_n_info = last_n_info or n_info
- node_summaries.append([
- node.op,
- str(node),
- _format_memory(n_info.accumulate_size),
- _format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
- _format_memory(n_info.output_size),
- _format_memory(n_info.temp_size),
- _format_memory(n_info.param_size),
- _format_memory(n_info.backward_size),
- _format_flops(n_info.fwd_flop),
- _format_flops(n_info.bwd_flop),
- ])
+ node_summaries.append(
+ [
+ node.op,
+ str(node),
+ _format_memory(n_info.accumulate_size),
+ _format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
+ _format_memory(n_info.output_size),
+ _format_memory(n_info.temp_size),
+ _format_memory(n_info.param_size),
+ _format_memory(n_info.backward_size),
+ _format_flops(n_info.fwd_flop),
+ _format_flops(n_info.bwd_flop),
+ ]
+ )
last_n_info = n_info
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
- 'Op type',
- 'Op',
- 'Accumulate size',
- 'Incremental size',
- 'Output size',
- 'Temp size',
- 'Param size',
- 'Backward size',
- 'Fwd FLOPs',
- 'Bwd FLOPs',
+ "Op type",
+ "Op",
+ "Accumulate size",
+ "Incremental size",
+ "Output size",
+ "Temp size",
+ "Param size",
+ "Backward size",
+ "Fwd FLOPs",
+ "Bwd FLOPs",
]
- return tabulate(node_summaries, headers=headers, stralign='right')
+ return tabulate(node_summaries, headers=headers, stralign="right")
class CommunicationProfiler(GraphProfiler):
@@ -222,6 +226,7 @@ class FlopProfiler(GraphProfiler):
>>> def my_fn_flop_count_impl(*args, **kwargs):
>>> return 0, 0
"""
+
_custom_flop_count_impl = {}
def run_node(self, n: torch.fx.Node) -> Any:
@@ -246,11 +251,13 @@ class FlopProfiler(GraphProfiler):
(
n_info.fwd_flop,
n_info.bwd_flop,
- ) = getattr(self, n.op)(n.target, args, kwargs)
+ ) = getattr(
+ self, n.op
+ )(n.target, args, kwargs)
except Exception as e:
raise RuntimeError(
- f'Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. '
- f'Please refer to function\'s docstring to register the relevant profile_impl for this node!'
+ f"Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. "
+ f"Please refer to function's docstring to register the relevant profile_impl for this node!"
) from e
# retain the autograd graph
@@ -259,7 +266,7 @@ class FlopProfiler(GraphProfiler):
return _denormalize_tuple(n_info.outputs)
- def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node and return the profiling result.
Dispatch to ``_custom_flop_count_impl`` if ``call_function`` should be
@@ -283,7 +290,7 @@ class FlopProfiler(GraphProfiler):
else:
return flop_count(target, *args, **kwargs)
- def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node and return the profiling result.
@@ -301,7 +308,7 @@ class FlopProfiler(GraphProfiler):
assert isinstance(target, str)
return flop_count(getattr(torch.Tensor, target), *args, **kwargs)
- def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node and return the profiling result.
@@ -336,9 +343,10 @@ def graph_profile_pass(module: GraphModule, *args, verbose=False) -> GraphModule
Returns:
GraphModule: The same GraphModule with profiling information
"""
- for profiler_cls in (FlopProfiler,
- # CommunicationProfiler, # TODO: add communication profiling
- ):
+ for profiler_cls in (
+ FlopProfiler,
+ # CommunicationProfiler, # TODO: add communication profiling
+ ):
profiler = profiler_cls(module)
profiler.propagate(*args, device=_current_device(module))
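A sketch of the pass in use; the exact calling convention of `shape_prop_pass` is assumed from its pairing with `graph_profile_pass` elsewhere in this PR:

```python
import torch
from colossalai._analyzer.fx import symbolic_trace                       # export path assumed
from colossalai._analyzer.fx.passes import shape_prop_pass, graph_profile_pass

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU())
data = torch.randn(8, 32)

gm = symbolic_trace(model, meta_args={"input": data})
shape_prop_pass(gm, data)     # annotate MetaInfo shapes first
graph_profile_pass(gm, data)  # then attach fwd/bwd FLOPs to every node
```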
diff --git a/colossalai/_analyzer/fx/passes/shape_prop.py b/colossalai/_analyzer/fx/passes/shape_prop.py
index 23e83013e02fd6a60d016f9b5c86a17d85cdaf81..8d44f1d4b59d33bb5c7bed3cb20b44448198c917 100644
--- a/colossalai/_analyzer/fx/passes/shape_prop.py
+++ b/colossalai/_analyzer/fx/passes/shape_prop.py
@@ -54,7 +54,7 @@ def _current_device(module):
try:
return next(module.parameters()).device
except StopIteration:
- return torch.device('cpu')
+ return torch.device("cpu")
@compatibility(is_backward_compatible=False)
@@ -90,6 +90,7 @@ class ShapeProp(torch.fx.Interpreter):
>>> # do something here
>>> return torch.empty(output_shape, device=output_device)
"""
+
_custom_dispatch_func = {}
_mode = MetaTensorMode()
@@ -115,15 +116,14 @@ class ShapeProp(torch.fx.Interpreter):
r = getattr(self, n.op)(n.target, args, kwargs)
def unwrap_fn(elem):
-
def _convert_meta(t: torch.Tensor):
- if t.device == 'meta':
+ if t.device == "meta":
return t
else:
- return t.to('meta')
+ return t.to("meta")
if isinstance(elem, MetaTensor):
- if getattr(self, '_is_param', False):
+ if getattr(self, "_is_param", False):
return torch.nn.Parameter(_convert_meta(elem._tensor))
return _convert_meta(elem._tensor)
@@ -139,21 +139,24 @@ class ShapeProp(torch.fx.Interpreter):
n_info = MetaInfo(n)
n_info.outputs = _normalize_tuple(r)
- if n.op == 'call_module':
+ if n.op == "call_module":
submod = self.fetch_attr(n.target)
n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()})
n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()})
else:
- n_info.parameters.update({
- k.name: MetaTensor(v)
- for k, v in zip(n.args, args)
- if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
- })
+ n_info.parameters.update(
+ {
+ k.name: MetaTensor(v)
+ for k, v in zip(n.args, args)
+ if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
+ }
+ )
n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)})
- n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
- tuple(v for v in kwargs.values() if is_pure_tensor(v))
+ n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + tuple(
+ v for v in kwargs.values() if is_pure_tensor(v)
+ )
# align with SPMD
if isinstance(r, (tuple, list)):
@@ -168,7 +171,7 @@ class ShapeProp(torch.fx.Interpreter):
n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs))
return r
- def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_function(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node and return the result.
If the target of ``Node`` is registered with ``@register_shape_impl``,
@@ -197,7 +200,7 @@ class ShapeProp(torch.fx.Interpreter):
else:
return res
- def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_method(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node and return the result.
@@ -218,7 +221,8 @@ class ShapeProp(torch.fx.Interpreter):
convert_to_parameter = False
if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance(
- args[0], torch.nn.parameter.Parameter):
+ args[0], torch.nn.parameter.Parameter
+ ):
convert_to_parameter = True
# Execute the method and return the result
assert isinstance(target, str)
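Following the class docstring above, a custom dispatch rule can be registered for ops the meta execution cannot handle. A hypothetical sketch; `my_op` is illustrative, and `register_shape_impl` is the decorator defined in the next file of this PR:

```python
import torch
from colossalai._analyzer.fx.symbolic_profile import register_shape_impl  # path assumed

def my_op(x):  # hypothetical function traced as a call_function node
    return x * 2

@register_shape_impl(my_op)
def my_op_shape(x):
    # Produce output metadata only; no real computation happens.
    return torch.empty_like(x, device="meta")
```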
diff --git a/colossalai/_analyzer/fx/symbolic_profile.py b/colossalai/_analyzer/fx/symbolic_profile.py
index dd7f22c6c98a0d47c946935a991a1f31d1052734..5732a6665f7843a6bd21d11abc002f088a059c44 100644
--- a/colossalai/_analyzer/fx/symbolic_profile.py
+++ b/colossalai/_analyzer/fx/symbolic_profile.py
@@ -1,5 +1,3 @@
-import torch
-import torch.fx
from torch.fx import GraphModule
from .passes import ShapeProp, graph_profile_pass, shape_prop_pass
@@ -7,7 +5,6 @@ from .passes.graph_profile import FlopProfiler
def register_flop_count_impl(func):
-
def wrapper(impl):
FlopProfiler._custom_flop_count_impl[func] = impl
return impl
@@ -16,7 +13,6 @@ def register_flop_count_impl(func):
def register_shape_impl(func):
-
def wrapper(impl):
ShapeProp._custom_dispatch_func[func] = impl
return impl
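And the FLOP-side counterpart, mirroring the `(fwd, bwd)` convention in `FlopProfiler`'s docstring (`fused_noop` is illustrative):

```python
from colossalai._analyzer.fx.symbolic_profile import register_flop_count_impl  # path assumed

def fused_noop(x):  # hypothetical op the profiler cannot cost on its own
    return x

@register_flop_count_impl(fused_noop)
def fused_noop_flops(*args, **kwargs):
    return 0, 0  # (fwd_flop, bwd_flop)
```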
diff --git a/colossalai/_analyzer/fx/tracer/bias_addition.py b/colossalai/_analyzer/fx/tracer/bias_addition.py
index 1e75b47ca5b038aadb9c9bf0779bc3565d91bead..b8b83282b42c24443917a85cf968a22b54cf8156 100644
--- a/colossalai/_analyzer/fx/tracer/bias_addition.py
+++ b/colossalai/_analyzer/fx/tracer/bias_addition.py
@@ -12,7 +12,7 @@ from .tracer import register_tracer_impl
__all__ = []
-@register_tracer_impl(F.linear, name='_bias_addition_impl')
+@register_tracer_impl(F.linear, name="_bias_addition_impl")
def linear_impl(input, weight, bias=None):
if bias is None:
return F.linear(input, weight)
@@ -20,116 +20,130 @@ def linear_impl(input, weight, bias=None):
return F.linear(input, weight) + bias
-@register_tracer_impl(F.conv1d, name='_bias_addition_impl')
+@register_tracer_impl(F.conv1d, name="_bias_addition_impl")
def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
if bias is None:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
- (-1, 1))
+ (-1, 1)
+ )
-@register_tracer_impl(F.conv2d, name='_bias_addition_impl')
+@register_tracer_impl(F.conv2d, name="_bias_addition_impl")
def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
if bias is None:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
- (-1, 1, 1))
+ (-1, 1, 1)
+ )
-@register_tracer_impl(F.conv3d, name='_bias_addition_impl')
+@register_tracer_impl(F.conv3d, name="_bias_addition_impl")
def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
if bias is None:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
- (-1, 1, 1, 1))
-
-
-@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
-def conv_transpose1d_impl(input,
- weight,
- bias=None,
- stride=_single(1),
- padding=_single(0),
- output_padding=_single(0),
- groups=1,
- dilation=_single(1)):
+ (-1, 1, 1, 1)
+ )
+
+
+@register_tracer_impl(F.conv_transpose1d, name="_bias_addition_impl")
+def conv_transpose1d_impl(
+ input,
+ weight,
+ bias=None,
+ stride=_single(1),
+ padding=_single(0),
+ output_padding=_single(0),
+ groups=1,
+ dilation=_single(1),
+):
if bias is None:
- return F.conv_transpose1d(input,
- weight,
- stride=stride,
- padding=padding,
- output_padding=output_padding,
- groups=groups,
- dilation=dilation)
+ return F.conv_transpose1d(
+ input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
+ )
else:
- return F.conv_transpose1d(input,
- weight,
- stride=stride,
- padding=padding,
- output_padding=output_padding,
- groups=groups,
- dilation=dilation) + bias.reshape((-1, 1))
-
-
-@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
-def conv_transpose2d_impl(input,
- weight,
- bias=None,
- stride=_pair(1),
- padding=_pair(0),
- output_padding=_pair(0),
- groups=1,
- dilation=_pair(1)):
+ return F.conv_transpose1d(
+ input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
+ ) + bias.reshape((-1, 1))
+
+
+@register_tracer_impl(F.conv_transpose2d, name="_bias_addition_impl")
+def conv_transpose2d_impl(
+ input, weight, bias=None, stride=_pair(1), padding=_pair(0), output_padding=_pair(0), groups=1, dilation=_pair(1)
+):
if bias is None:
- return F.conv_transpose2d(input,
- weight,
- stride=stride,
- padding=padding,
- output_padding=output_padding,
- groups=groups,
- dilation=dilation)
+ return F.conv_transpose2d(
+ input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
+ )
else:
- return F.conv_transpose2d(input,
- weight,
- stride=stride,
- padding=padding,
- output_padding=output_padding,
- groups=groups,
- dilation=dilation) + bias.reshape((-1, 1, 1))
-
-
-@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
-def conv_transpose3d_impl(input,
- weight,
- bias=None,
- stride=_triple(1),
- padding=_triple(0),
- output_padding=_triple(0),
- groups=1,
- dilation=_triple(1)):
+ return F.conv_transpose2d(
+ input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
+ ) + bias.reshape((-1, 1, 1))
+
+
+@register_tracer_impl(F.conv_transpose3d, name="_bias_addition_impl")
+def conv_transpose3d_impl(
+ input,
+ weight,
+ bias=None,
+ stride=_triple(1),
+ padding=_triple(0),
+ output_padding=_triple(0),
+ groups=1,
+ dilation=_triple(1),
+):
if bias is None:
- return F.conv_transpose3d(input,
- weight,
- stride=stride,
- padding=padding,
- output_padding=output_padding,
- groups=groups,
- dilation=dilation)
+ return F.conv_transpose3d(
+ input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
+ )
else:
- return F.conv_transpose3d(input,
- weight,
- stride=stride,
- padding=padding,
- output_padding=output_padding,
- groups=groups,
- dilation=dilation) + bias.reshape((-1, 1, 1, 1))
-
-
-@register_tracer_impl(torch.addmm, name='_bias_addition_impl')
-@register_tracer_impl(torch.Tensor.addmm, name='_bias_addition_impl')
+ return F.conv_transpose3d(
+ input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation,
+ ) + bias.reshape((-1, 1, 1, 1))
+
+
+@register_tracer_impl(torch.addmm, name="_bias_addition_impl")
+@register_tracer_impl(torch.Tensor.addmm, name="_bias_addition_impl")
def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
if alpha != 1 and beta != 1:
return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input * beta
@@ -141,8 +155,8 @@ def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
return F.linear(mat1, mat2.transpose(0, 1)) + input
-@register_tracer_impl(torch.addbmm, name='_bias_addition_impl')
-@register_tracer_impl(torch.Tensor.addbmm, name='_bias_addition_impl')
+@register_tracer_impl(torch.addbmm, name="_bias_addition_impl")
+@register_tracer_impl(torch.Tensor.addbmm, name="_bias_addition_impl")
def addbmm_impl(input, batch1, batch2, beta=1, alpha=1):
if alpha != 1 and beta != 1:
return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input * beta
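The identity these `_bias_addition_impl` rewrites rely on is easy to check numerically; a sketch:

```python
import torch
import torch.nn.functional as F

x, w, b = torch.randn(2, 8), torch.randn(4, 8), torch.randn(4)

# linear_impl splits the fused op so the add appears as its own graph node:
assert torch.allclose(F.linear(x, w, b), F.linear(x, w) + b)

# conv1d_impl reshapes bias to (-1, 1) so it broadcasts over the length dim:
xc, wc, bc = torch.randn(1, 4, 16), torch.randn(6, 4, 3), torch.randn(6)
assert torch.allclose(F.conv1d(xc, wc, bc), F.conv1d(xc, wc) + bc.reshape((-1, 1)))
```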
diff --git a/colossalai/_analyzer/fx/tracer/custom_leaf_module.py b/colossalai/_analyzer/fx/tracer/custom_leaf_module.py
index 112c7c9637d20e395dbccaace063e3fa7657041f..ff6b55be5117ec74ac7fc73e44c38b31cdd896fa 100644
--- a/colossalai/_analyzer/fx/tracer/custom_leaf_module.py
+++ b/colossalai/_analyzer/fx/tracer/custom_leaf_module.py
@@ -4,6 +4,7 @@ from .tracer import register_leaf_module, register_leaf_module_impl
try:
import apex
+
register_leaf_module(apex.normalization.FusedLayerNorm)
register_leaf_module(apex.normalization.FusedRMSNorm)
register_leaf_module(apex.normalization.MixedFusedLayerNorm)
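The same registration hook works for user modules, not just apex; a hypothetical sketch (`FusedBlock` is illustrative):

```python
import torch
from colossalai._analyzer.fx.tracer.tracer import register_leaf_module  # path from this diff

class FusedBlock(torch.nn.Module):  # stand-in for an opaque fused kernel
    def forward(self, x):
        return x

register_leaf_module(FusedBlock)  # traced as a single call_module node, never inlined
```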
diff --git a/colossalai/_analyzer/fx/tracer/proxy.py b/colossalai/_analyzer/fx/tracer/proxy.py
index ce379efdcf0d7c01d5541cb9ebffac1325fa97ef..e3e210e7d1900729d79b09bd9947f73017e5fe4a 100644
--- a/colossalai/_analyzer/fx/tracer/proxy.py
+++ b/colossalai/_analyzer/fx/tracer/proxy.py
@@ -1,10 +1,8 @@
import operator
-from typing import Any, Callable, Dict, Optional, Set, Union
+from typing import Any, Callable, Dict, Optional, Union
import torch
-import torch.nn as nn
-from torch.fx import Graph, Node, Proxy, Tracer
-from torch.fx.graph import _Namespace
+from torch.fx import Node, Proxy
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor
@@ -32,7 +30,7 @@ class ColoProxy(Proxy):
def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
if orig_method in cls._func_dispatch:
- impl = cls._func_dispatch.pop(orig_method) # avoid recursion
+ impl = cls._func_dispatch.pop(orig_method) # avoid recursion
proxy = impl(*args, **kwargs)
cls._func_dispatch[orig_method] = impl
return proxy
@@ -72,7 +70,7 @@ class ColoProxy(Proxy):
return ColoAttribute(self, k, getattr(self._meta_data, k, None))
def __setitem__(self, key, value):
- proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
+ proxy = self.tracer.create_proxy("call_function", operator.setitem, (self, key, value), {})
proxy.meta_data = self._meta_data
return proxy
@@ -89,7 +87,6 @@ class ColoProxy(Proxy):
class ColoAttribute(ColoProxy):
-
def __init__(self, root, attr: str, data=None):
self.root = root
self.attr = attr
@@ -102,11 +99,11 @@ class ColoAttribute(ColoProxy):
# the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call
if self._node is None:
- self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
+ self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
return self._node
def __call__(self, *args, **kwargs):
- return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
+ return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
def __repr__(self):
return f"ColoAttribute({self.node.name}, attr={self.attr})"
diff --git a/colossalai/_analyzer/fx/tracer/symbolic_trace.py b/colossalai/_analyzer/fx/tracer/symbolic_trace.py
index 2018863f6f5f50de7cc61eafb907a55034711993..7884fd911c864dcab53876950caea8024a6f2544 100644
--- a/colossalai/_analyzer/fx/tracer/symbolic_trace.py
+++ b/colossalai/_analyzer/fx/tracer/symbolic_trace.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, Optional, Union
import torch
from torch.fx import Tracer
@@ -8,6 +8,7 @@ from colossalai._analyzer._subclasses import MetaTensor
try:
from ..codegen import ActivationCheckpointCodeGen
+
SUPPORT_ACTIVATION = True
except ImportError:
SUPPORT_ACTIVATION = False
@@ -16,7 +17,7 @@ from .tracer import ColoTracer
def _default_device():
- return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+ return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def _current_device(module: torch.nn.Module):
@@ -144,10 +145,9 @@ def symbolic_trace(
if meta_args:
device, orig_device = _default_device(), _current_device(root)
wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
- graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
- bias_addition_split=bias_addition_split).trace(root.to(device),
- concrete_args=concrete_args,
- meta_args=tree_map(wrap_fn, meta_args))
+ graph = ColoTracer(trace_act_ckpt=trace_act_ckpt, bias_addition_split=bias_addition_split).trace(
+ root.to(device), concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args)
+ )
if trace_act_ckpt and SUPPORT_ACTIVATION:
graph.set_codegen(ActivationCheckpointCodeGen())
root.to(orig_device)
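Putting the pieces together, a tracing sketch; keyword names are taken from this hunk, and the return type is assumed to be a `ColoGraphModule`:

```python
import torch
from colossalai._analyzer.fx import symbolic_trace  # export path assumed

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())

# meta_args wraps the sample input in MetaTensor, so no real forward pass
# touches device memory; trace_act_ckpt enables ActivationCheckpointCodeGen
# when the installed torch supports it.
gm = symbolic_trace(
    model,
    meta_args={"input": torch.randn(4, 16)},
    trace_act_ckpt=True,
)
print(gm.graph)
```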
diff --git a/colossalai/_analyzer/fx/tracer/tracer.py b/colossalai/_analyzer/fx/tracer/tracer.py
index 6958a00a6a72af16bf6a9736a7c18411ff127b76..17dce767269d4efc435ae9dcc2e666073035f107 100644
--- a/colossalai/_analyzer/fx/tracer/tracer.py
+++ b/colossalai/_analyzer/fx/tracer/tracer.py
@@ -20,11 +20,10 @@ def _truncate_suffix(s: str):
import re
# FIXME: don't know why but torch.fx always gets a suffix like '_1' in the name
- return re.sub(r'_\d+$', '', s)
+ return re.sub(r"_\d+$", "", s)
-def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'):
-
+def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = "_custom_impl"):
def wrapper(impl):
assert hasattr(ColoTracer, name), f"Cannot register {func.__name__} in ColoTracer.{name}"
getattr(ColoTracer, name)[func] = impl
@@ -34,7 +33,6 @@ def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custo
def register_leaf_module_impl(module: nn.Module):
-
def wrapper(impl):
ColoTracer._custom_leaf_module_impl[module] = impl
return impl
@@ -76,7 +74,7 @@ class ColoTracer(Tracer):
self.ckpt_regions = []
self.ckpt_idx = 0
- self.mod_dir = ''
+ self.mod_dir = ""
# whether the tracer should split the bias_add ops into two ops
self.bias_addition_split = bias_addition_split
@@ -87,35 +85,41 @@ class ColoTracer(Tracer):
if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None:
return False
# user can specify which modules are leaf modules and which are not
- return (type(m) not in self._custom_non_leaf_module
- and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)))
+ return type(m) not in self._custom_non_leaf_module and (
+ type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)
+ )
- def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...],
- kwargs: Dict[str, Any]) -> Any:
+ def call_module(
+ self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
+ ) -> Any:
curr_dir = self.mod_dir
- self.mod_dir = 'self.' + self.path_of_module(m)
+ self.mod_dir = "self." + self.path_of_module(m)
rst = super().call_module(m, forward, args, kwargs)
self.mod_dir = curr_dir
return rst
- def proxy(self, node: Node) -> 'ColoProxy':
+ def proxy(self, node: Node) -> "ColoProxy":
return ColoProxy(node, self)
- def create_proxy(self,
- kind: str,
- target: Target,
- args: Tuple[Any, ...],
- kwargs: Dict[str, Any],
- name: Optional[str] = None,
- type_expr: Optional[Any] = None,
- proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
-
+ def create_proxy(
+ self,
+ kind: str,
+ target: Target,
+ args: Tuple[Any, ...],
+ kwargs: Dict[str, Any],
+ name: Optional[str] = None,
+ type_expr: Optional[Any] = None,
+ proxy_factory_fn: Callable[[Node], "Proxy"] = None,
+ ):
proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
- if kind == 'placeholder':
- proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
- _truncate_suffix(target), None)
- elif kind == 'get_attr':
+ if kind == "placeholder":
+ proxy.meta_data = (
+ self.meta_args[target]
+ if target in self.meta_args
+ else self.concrete_args.get(_truncate_suffix(target), None)
+ )
+ elif kind == "get_attr":
self.disable_module_getattr = True
try:
attr_itr = self.root
@@ -125,20 +129,21 @@ class ColoTracer(Tracer):
proxy.meta_data = attr_itr
finally:
self.disable_module_getattr = False
- elif kind == 'call_function':
+ elif kind == "call_function":
proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
- elif kind == 'call_method':
+ elif kind == "call_method":
self.disable_module_getattr = True
try:
- if target == '__call__':
+ if target == "__call__":
proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
else:
if target not in _TensorPropertyMethod:
- proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
- **tree_map(unwrap_fn, kwargs))
+ proxy._meta_data = getattr(unwrap_fn(args[0]), target)(
+ *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)
+ )
finally:
self.disable_module_getattr = False
- elif kind == 'call_module':
+ elif kind == "call_module":
mod = self.root.get_submodule(target)
self.disable_module_getattr = True
try:
@@ -158,11 +163,12 @@ class ColoTracer(Tracer):
n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions))
return node
- def trace(self,
- root: torch.nn.Module,
- concrete_args: Optional[Dict[str, torch.Tensor]] = None,
- meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
-
+ def trace(
+ self,
+ root: torch.nn.Module,
+ concrete_args: Optional[Dict[str, torch.Tensor]] = None,
+ meta_args: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> Graph:
if meta_args is None:
meta_args = {}
@@ -177,9 +183,7 @@ class ColoTracer(Tracer):
non_concrete_arg_names = sig_names - concrete_arg_names
# update concrete args with default values
for k, v in sig.parameters.items():
- if k in sig_names - meta_arg_names and \
- k not in concrete_args and \
- v.default is not inspect.Parameter.empty:
+ if k in sig_names - meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
def _check_arg_name_valid(names: Iterable[str]):
@@ -194,9 +198,9 @@ class ColoTracer(Tracer):
self.meta_args = meta_args
with self._torch_factory_override(), self._tracer_override(), torch.no_grad():
- self.mod_dir = 'self'
+ self.mod_dir = "self"
self.graph = super().trace(root, concrete_args=concrete_args)
- self.mod_dir = ''
+ self.mod_dir = ""
self.graph.lint()
for node in self.graph.nodes:
@@ -266,17 +270,17 @@ class ColoTracer(Tracer):
# override the torch factory functions to create a proxy when the method
# is called during ``symbolic_trace()``.
def wrap_factory_method(target):
-
@functools.wraps(target)
def wrapper(*args, **kwargs):
is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
- isinstance(p, ColoProxy) for p in kwargs.values())
+ isinstance(p, ColoProxy) for p in kwargs.values()
+ )
if is_proxy:
# if the arg is a proxy, then need to record this function called on this proxy
# e.g. torch.ones(size) where size is an input proxy
self.disable_module_getattr = True
try:
- proxy = self.create_proxy('call_function', target, args, kwargs)
+ proxy = self.create_proxy("call_function", target, args, kwargs)
finally:
self.disable_module_getattr = False
return proxy
@@ -341,10 +345,13 @@ class ColoTracer(Tracer):
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
- if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
- kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
- lambda node: ColoProxy(self, node, n, attr_val))
- val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type]
+ if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
+ kwargs["proxy_factory_fn"] = (
+ None
+ if not self.param_shapes_constant
+ else lambda node: ColoProxy(self, node, n, attr_val)
+ )
+ val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
@@ -355,8 +362,9 @@ class ColoTracer(Tracer):
return maybe_buffer_proxy
if isinstance(attr_val, torch.nn.Parameter):
- maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
- parameter_proxy_cache)
+ maybe_parameter_proxy = maybe_get_proxy_for_attr(
+ attr_val, self.root.named_parameters(), parameter_proxy_cache
+ )
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
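
Reviewer note (illustrative, not part of this patch): a minimal sketch of driving the reformatted tracer through `symbolic_trace` with `meta_args`. It assumes the function returns a `torch.fx.GraphModule`-like object and that `"input"` is the placeholder name of `nn.Sequential.forward`'s argument; `meta_args` tensors are wrapped into `MetaTensor` internally, so no real activations are allocated.

```python
# Illustrative sketch only; the returned object's exact type is an assumption.
import torch

from colossalai._analyzer.fx.tracer.symbolic_trace import symbolic_trace

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
# Trace with meta tensors: shapes and dtypes are recorded on every node
# while the tensors themselves stay allocation-free.
gm = symbolic_trace(model, meta_args={"input": torch.rand(4, 16)})
print(gm.graph)
```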
diff --git a/colossalai/amp/__init__.py b/colossalai/amp/__init__.py
index 963215476b6b038b2aa33c124461387e47579d3c..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644
--- a/colossalai/amp/__init__.py
+++ b/colossalai/amp/__init__.py
@@ -1,54 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import torch.nn as nn
-from torch.nn.modules.loss import _Loss
-from torch.optim import Optimizer
-
-from colossalai.context import Config
-
-from .amp_type import AMP_TYPE
-from .apex_amp import convert_to_apex_amp
-from .naive_amp import convert_to_naive_amp
-from .torch_amp import convert_to_torch_amp
-
-__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE']
-
-
-def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
- """A helper function to wrap training components with Torch AMP modules.
-
- Args:
- param model (:class:`torch.nn.Module`): your model object.
- optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
- criterion (:class:`torch.nn.modules.loss._Loss`): your loss function object.
- mode (:class:`colossalai.amp.AMP_TYPE`): amp mode.
- amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for different amp modes.
-
- Returns:
- A tuple (model, optimizer, criterion).
-
- Note:
- ``amp_config`` may vary from different mode you choose. You should check the corresponding amp mode
- for more details about ``amp_config``.
- For ``apex_amp``, please check
- `apex_amp config `_.
- For ``naive_amp``, please check
- `naive_amp config `_.
- For ``torch_amp``, please check
- `torch_amp config `_.
- """
- assert isinstance(mode, AMP_TYPE), \
- f'expected the argument mode be AMP_TYPE, but got {type(mode)}'
-
- if amp_config is None:
- amp_config = Config()
-
- if mode == AMP_TYPE.TORCH:
- model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config)
- elif mode == AMP_TYPE.APEX:
- model, optimizer = convert_to_apex_amp(model, optimizer, amp_config)
- elif mode == AMP_TYPE.NAIVE:
- model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)
-
- return model, optimizer, criterion
diff --git a/colossalai/amp/amp_type.py b/colossalai/amp/amp_type.py
deleted file mode 100644
index 6f322f866cfc813e66e54b0c1006d62ef949e96e..0000000000000000000000000000000000000000
--- a/colossalai/amp/amp_type.py
+++ /dev/null
@@ -1,10 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from enum import Enum
-
-
-class AMP_TYPE(Enum):
- APEX = 'apex'
- TORCH = 'torch'
- NAIVE = 'naive'
diff --git a/colossalai/amp/apex_amp/__init__.py b/colossalai/amp/apex_amp/__init__.py
deleted file mode 100644
index 51b9b97dccce877783251fb3f61f08a87a6a7659..0000000000000000000000000000000000000000
--- a/colossalai/amp/apex_amp/__init__.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import torch.nn as nn
-from torch.optim import Optimizer
-
-from .apex_amp import ApexAMPOptimizer
-
-
-def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config):
- r"""A helper function to wrap training components with Apex AMP modules
-
- Args:
- model (:class:`torch.nn.Module`): your model object.
- optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
- amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for initializing apex_amp.
-
- Returns:
- Tuple: A tuple (model, optimizer).
-
- The ``amp_config`` should include parameters below:
- ::
-
- enabled (bool, optional, default=True)
- opt_level (str, optional, default="O1")
- cast_model_type (``torch.dtype``, optional, default=None)
- patch_torch_functions (bool, optional, default=None)
- keep_batchnorm_fp32 (bool or str, optional, default=None
- master_weights (bool, optional, default=None)
- loss_scale (float or str, optional, default=None)
- cast_model_outputs (torch.dtype, optional, default=None)
- num_losses (int, optional, default=1)
- verbosity (int, default=1)
- min_loss_scale (float, default=None)
- max_loss_scale (float, default=2.**24)
-
- More details about ``amp_config`` refer to `amp_config `_.
- """
- import apex.amp as apex_amp
- model, optimizer = apex_amp.initialize(model, optimizer, **amp_config)
- optimizer = ApexAMPOptimizer(optimizer)
- return model, optimizer
-
-
-__all__ = ['convert_to_apex_amp', 'ApexAMPOptimizer']
diff --git a/colossalai/amp/naive_amp/__init__.py b/colossalai/amp/naive_amp/__init__.py
index 5b2f71d3ced771c43d541843153c6b64613f69e1..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644
--- a/colossalai/amp/naive_amp/__init__.py
+++ b/colossalai/amp/naive_amp/__init__.py
@@ -1,60 +0,0 @@
-import inspect
-
-import torch.nn as nn
-from torch.optim import Optimizer
-
-from colossalai.utils import is_no_pp_or_last_stage
-
-from ._fp16_optimizer import FP16Optimizer
-from .grad_scaler import ConstantGradScaler, DynamicGradScaler
-from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer
-
-
-def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
- """A helper function to wrap training components with naive AMP modules. In this mode,
- we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss,
- which is equivalent to Apex O3.
-
- Args:
- model (:class:`torch.nn.Module`): your model object
- optimizer (:class:`torch.optim.Optimizer`): your optimizer object
- amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp.
-
- Returns:
- Tuple: A tuple (model, optimizer)
-
- The ``amp_config`` should contain parameters below::
-
- verbose (bool, optional): if set to `True`, will print debug info (Default: False).
- clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0).
- Note that clipping is ignored if clip_grad == 0.
- dynamic_grad_scale (bool): whether to use dynamic grad scaler.
- """
- if isinstance(model, nn.ModuleList):
- # interleaved pipeline
- module_list = []
- for chunk, m in enumerate(model):
- output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1
- module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32))
- model = nn.ModuleList(module_list)
- else:
- output_to_fp32 = is_no_pp_or_last_stage()
- model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)
-
- use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True)
- if use_dynamic_grad_scaler:
- scaler_class = DynamicGradScaler
- else:
- scaler_class = ConstantGradScaler
-
- sig = inspect.signature(scaler_class.__init__)
- kwargs = dict()
- for param in sig.parameters.values():
- if param.name in amp_config:
- kwargs[param.name] = amp_config.pop(param.name)
- grad_scaler = scaler_class(**kwargs)
- optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config)
- return model, optimizer
-
-
-__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer']
diff --git a/colossalai/amp/naive_amp/_utils.py b/colossalai/amp/naive_amp/_utils.py
deleted file mode 100644
index 7633705e19fbce24faec87f9691c834279f0d8ad..0000000000000000000000000000000000000000
--- a/colossalai/amp/naive_amp/_utils.py
+++ /dev/null
@@ -1,49 +0,0 @@
-from typing import List
-
-from torch import Tensor
-
-
-def has_inf_or_nan(tensor):
- """Check if tensor has inf or nan values.
-
- Args:
- tensor (:class:`torch.Tensor`): a torch tensor object
-
- Returns:
- bool: Whether the tensor has inf or nan. True for yes and False for no.
- """
- try:
- # if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
- # Pytorch's .sum() creates a one-element tensor of the same type as tensor
- # (which is true for some recent version of pytorch).
- tensor_sum = float(tensor.float().sum())
- # More efficient version that can be used if .sum() returns a Python scalar
- # tensor_sum = float(tensor.sum())
- except RuntimeError as instance:
- # We want to check if inst is actually an overflow exception.
- # RuntimeError could come from a different error.
- # If so, we still want the exception to propagate.
- if "value cannot be converted" not in instance.args[0]:
- raise
- return True
- else:
- if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum:
- return True
- return False
-
-
-def zero_gard_by_list(tensor_list: List[Tensor], set_to_none: bool = True) -> None:
- """Clear the gradient of a list of tensors,
-
- Note: copied from torch.optim.optimizer.
- """
- for param in tensor_list:
- if param.grad is not None:
- if set_to_none:
- param.grad = None
- else:
- if param.grad.grad_fn is not None:
- param.grad.detach_()
- else:
- param.grad.requires_grad_(False)
- param.grad.zero_()
diff --git a/colossalai/amp/naive_amp/grad_scaler/__init__.py b/colossalai/amp/naive_amp/grad_scaler/__init__.py
index dc8499d877e13f8e0eb14317d2cf4a8d54dfcb2a..34a20e8d67d6690a5a6e1e9ca945ab1d2e31adf7 100644
--- a/colossalai/amp/naive_amp/grad_scaler/__init__.py
+++ b/colossalai/amp/naive_amp/grad_scaler/__init__.py
@@ -2,4 +2,4 @@ from .base_grad_scaler import BaseGradScaler
from .constant_grad_scaler import ConstantGradScaler
from .dynamic_grad_scaler import DynamicGradScaler
-__all__ = ['BaseGradScaler', 'ConstantGradScaler', 'DynamicGradScaler']
+__all__ = ["BaseGradScaler", "ConstantGradScaler", "DynamicGradScaler"]
diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
index 0d84384a7f67c6a4521a86d34f71ff03b821c7be..79661a44424fb0c8386ebdd04a7051ddae1517e3 100644
--- a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
@@ -9,7 +9,7 @@ from torch import Tensor
from colossalai.logging import get_dist_logger
-__all__ = ['BaseGradScaler']
+__all__ = ["BaseGradScaler"]
class BaseGradScaler(ABC):
@@ -30,24 +30,21 @@ class BaseGradScaler(ABC):
@property
def scale(self) -> Tensor:
- """Returns the loss scale.
- """
+ """Returns the loss scale."""
return self._scale
@property
def inv_scale(self) -> Tensor:
- """Returns the inverse of the loss scale.
- """
+ """Returns the inverse of the loss scale."""
return self._scale.double().reciprocal().float()
def state_dict(self) -> Dict:
- """Returns the states of the gradient scaler as a dict object.
- """
+ """Returns the states of the gradient scaler as a dict object."""
state_dict = dict()
- state_dict['scale'] = self.scale
+ state_dict["scale"] = self.scale
return state_dict
def load_state_dict(self, state_dict: Dict) -> None:
@@ -57,7 +54,7 @@ class BaseGradScaler(ABC):
state_dict (dict): the states of the gradient scaler
"""
- self._scale = state_dict['scale']
+ self._scale = state_dict["scale"]
@abstractmethod
def update(self, overflow: bool) -> None:
@@ -67,8 +64,6 @@ class BaseGradScaler(ABC):
overflow (bool): whether overflow occurs
"""
- pass
-
def log(self, message, *args, **kwargs):
"""Log messages.
diff --git a/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py
index a2f518c5dd28261f98faadf9134721ef1fd67dc7..2ad8b51ac22c1aa69cc4c48996b8a1098dad5f4a 100644
--- a/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py
@@ -2,7 +2,7 @@
# -*- encoding: utf-8 -*-
from .base_grad_scaler import BaseGradScaler
-__all__ = ['ConstantGradScaler']
+__all__ = ["ConstantGradScaler"]
class ConstantGradScaler(BaseGradScaler):
@@ -23,4 +23,3 @@ class ConstantGradScaler(BaseGradScaler):
Args:
overflow (bool): whether overflow occurs
"""
- pass
diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
index e899b9ca4c89fba16352ce736cb0abc4959e163b..65133a4b3712be6069bfeab5c35bf0a556c585be 100644
--- a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
@@ -7,7 +7,7 @@ import torch
from .base_grad_scaler import BaseGradScaler
-__all__ = ['DynamicGradScaler']
+__all__ = ["DynamicGradScaler"]
class DynamicGradScaler(BaseGradScaler):
@@ -24,15 +24,17 @@ class DynamicGradScaler(BaseGradScaler):
verbose (bool): whether to log messages, defaults to False
"""
- def __init__(self,
- initial_scale: float = 2**16,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- min_scale: Optional[float] = None,
- max_scale: Optional[float] = None,
- hysteresis: int = 2,
- verbose: bool = False):
+ def __init__(
+ self,
+ initial_scale: float = 2**16,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ min_scale: Optional[float] = None,
+ max_scale: Optional[float] = None,
+ hysteresis: int = 2,
+ verbose: bool = False,
+ ):
super().__init__(initial_scale, verbose)
if min_scale:
self._min_scale = torch.cuda.FloatTensor([min_scale])
@@ -53,18 +55,17 @@ class DynamicGradScaler(BaseGradScaler):
self._sanity_checks()
def _sanity_checks(self) -> None:
- """Check if the arguments are correct.
- """
+ """Check if the arguments are correct."""
if self._min_scale:
- assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative'
- assert self._min_scale <= self._scale, 'The minimum gradient scale cannot be greater than the current scale'
+ assert self._min_scale > 0, "The minimum gradient scale cannot be zero or negative"
+ assert self._min_scale <= self._scale, "The minimum gradient scale cannot be greater than the current scale"
if self._max_scale:
- assert self._max_scale > 0, 'The maximum gradient scale cannot be zero or negative'
- assert self._max_scale >= self._scale, 'The maximum gradient scale cannot be smaller than the current scale'
- assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1'
- assert 0 < self._backoff_factor < 1, 'The backoff factor must be between 0 and 1'
- assert self._hysteresis >= 0, 'The hysteresis cannot be negative'
+ assert self._max_scale > 0, "The maximum gradient scale cannot be zero or negative"
+ assert self._max_scale >= self._scale, "The maximum gradient scale cannot be smaller than the current scale"
+ assert self._growth_factor > 1, "The growth factor cannot be equal or smaller than 1"
+ assert 0 < self._backoff_factor < 1, "The backoff factor must be between 0 and 1"
+ assert self._hysteresis >= 0, "The hysteresis cannot be negative"
def update(self, overflow: bool) -> None:
"""Update the loss scale.
@@ -88,19 +89,18 @@ class DynamicGradScaler(BaseGradScaler):
self.log(
f"No overflow for consecutive {self._growth_interval} steps, "
f"the loss scale is adjusted to {self.scale.item()}",
- ranks=[0])
+ ranks=[0],
+ )
def _backoff_scale(self) -> None:
- """Decrease the loss scale
- """
+ """Decrease the loss scale"""
self._scale = self._scale * self._backoff_factor
if self._min_scale:
self._scale = torch.max(self._scale, self._min_scale)
def _grow_scale(self) -> None:
- """Increase the loss scale
- """
+ """Increase the loss scale"""
self._scale = self._scale * self._growth_factor
if self._max_scale:
@@ -108,14 +108,14 @@ class DynamicGradScaler(BaseGradScaler):
def state_dict(self):
state_dict = dict()
- state_dict['scale'] = self._scale
- state_dict['growth_factor'] = self._growth_factor
- state_dict['backoff_factor'] = self._backoff_factor
- state_dict['hysteresis'] = self._hysteresis
+ state_dict["scale"] = self._scale
+ state_dict["growth_factor"] = self._growth_factor
+ state_dict["backoff_factor"] = self._backoff_factor
+ state_dict["hysteresis"] = self._hysteresis
return state_dict
def load_state_dict(self, state_dict):
- self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
- self._growth_factor = state_dict['growth_factor']
- self._backoff_factor = state_dict['backoff_factor']
- self._hysteresis = state_dict['hysteresis']
+ self._scale = state_dict["scale"].cuda(torch.cuda.current_device())
+ self._growth_factor = state_dict["growth_factor"]
+ self._backoff_factor = state_dict["backoff_factor"]
+ self._hysteresis = state_dict["hysteresis"]
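
A quick illustration of the scaler contract reformatted above (a sketch, not part of this patch; it assumes a CUDA device, since the scaler keeps its state in `torch.cuda.FloatTensor`):

```python
# Sketch: with hysteresis=1, one simulated overflow backs the scale off, and
# `growth_interval` consecutive clean steps grow it again.
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler

scaler = DynamicGradScaler(initial_scale=2**16, hysteresis=1, growth_interval=1000)
scaler.update(True)         # overflow: scale shrinks by backoff_factor
for _ in range(1000):
    scaler.update(False)    # clean steps: scale grows by growth_factor
print(scaler.scale.item())  # back to roughly the initial scale
```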
diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py b/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a31811e4a567db889e07830f46c01ec4b1c27a53
--- /dev/null
+++ b/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py
@@ -0,0 +1,9 @@
+from .base import MixedPrecisionMixin
+from .bf16 import BF16MixedPrecisionMixin
+from .fp16 import FP16MixedPrecisionMixin
+
+__all__ = [
+ "MixedPrecisionMixin",
+ "FP16MixedPrecisionMixin",
+ "BF16MixedPrecisionMixin",
+]
diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc7e0b74179a4b18710b30a136a8ed9f957f68b1
--- /dev/null
+++ b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py
@@ -0,0 +1,86 @@
+from abc import ABC, abstractmethod
+
+import torch
+from torch import Tensor
+
+
+class MixedPrecisionMixin(ABC):
+ """A helper class for mixed precision training. This mixin is used in mixed precision optimizers.
+
+ Attributes:
+        dtype (torch.dtype): The expected dtype of the gradients.
+
+ Examples:
+ ```python
+ class MyMixedPrecisionOptimizer(OptimizerWrapper):
+ def __init__(self, optim: Optimizer):
+ super().__init__(optim)
+            self.mixed_precision = MixedPrecisionMixin()  # in practice, a concrete subclass (e.g. BF16MixedPrecisionMixin)
+
+ def backward(self, loss):
+ loss = self.mixed_precision.pre_backward(loss)
+ loss.backward()
+
+ def backward_by_grad(self, tensor, grad):
+ grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
+ tensor.backward(grad)
+
+ def step(self):
+ if self.mixed_precision.should_skip_step():
+ self.zero_grad()
+ return
+ div_scale = self.mixed_precision.get_grad_div_scale()
+ # maybe clip grad here
+ # maybe scale grad here
+ self.optim.step()
+
+ def zero_grad(self):
+ self.mixed_precision.pre_zero_grad()
+ return self.optim.zero_grad()
+ ```
+ """
+
+ dtype: torch.dtype
+
+ @abstractmethod
+ def pre_backward(self, loss: Tensor) -> Tensor:
+ """Called before backward.
+
+ Args:
+ loss (Tensor): Loss value.
+
+ Returns:
+ Tensor: Loss value (possibly scaled).
+ """
+
+ @abstractmethod
+ def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
+ """Called before backward by grad. This is helpful for pipeline parallelism.
+
+ Args:
+ tensor (Tensor): Tensor to backward.
+ grad (Tensor): Gradient of the tensor.
+
+ Returns:
+ Tensor: Gradient of the tensor (possibly scaled).
+ """
+
+ @abstractmethod
+ def should_skip_step(self) -> bool:
+ """Called before step.
+
+ Returns:
+ bool: Whether to skip the step.
+ """
+
+ @abstractmethod
+ def pre_zero_grad(self) -> None:
+ """Called before zero_grad."""
+
+ @abstractmethod
+ def get_grad_div_scale(self) -> float:
+ """Called before step or clip_grad. To keep computation efficiency, this method does not (maybe) unscale grads.
+
+ Returns:
+ float: A divisor for gradient clipping or step.
+ """
diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/bf16.py b/colossalai/amp/naive_amp/mixed_precision_mixin/bf16.py
new file mode 100644
index 0000000000000000000000000000000000000000..9454f6eb84130f083ab893baa34b9277a71dd5c0
--- /dev/null
+++ b/colossalai/amp/naive_amp/mixed_precision_mixin/bf16.py
@@ -0,0 +1,23 @@
+import torch
+from torch import Tensor
+
+from .base import MixedPrecisionMixin
+
+
+class BF16MixedPrecisionMixin(MixedPrecisionMixin):
+ dtype = torch.bfloat16
+
+ def pre_backward(self, loss: Tensor) -> Tensor:
+ return loss
+
+ def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
+ return grad
+
+ def should_skip_step(self) -> bool:
+ return False
+
+ def pre_zero_grad(self) -> None:
+ pass
+
+ def get_grad_div_scale(self) -> float:
+ return 1.0
diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ce272356797db4073b980e6f7713c83a1e0c44d
--- /dev/null
+++ b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py
@@ -0,0 +1,87 @@
+from abc import abstractmethod
+from enum import Enum
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+
+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
+from colossalai.utils import get_current_device
+
+from .base import MixedPrecisionMixin
+
+
+class OptimState(Enum):
+ SCALED = 0
+ UNSCALED = 1
+
+
+class FP16MixedPrecisionMixin(MixedPrecisionMixin):
+ dtype = torch.float16
+
+ def __init__(
+ self,
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ ) -> None:
+ super().__init__()
+ self.grad_scaler = DynamicGradScaler(
+ initial_scale=initial_scale,
+ min_scale=min_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ max_scale=max_scale,
+ )
+ self.optim_state = OptimState.UNSCALED
+ self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())
+
+ @property
+ def loss_scale(self) -> float:
+ return self.grad_scaler.scale.item()
+
+ @abstractmethod
+ def check_local_overflow(self) -> bool:
+ """Check whether there is overflow in the local process. This method should be implemented by subclasses.
+
+ Returns:
+ bool: Whether there is overflow in the local process.
+ """
+
+ def check_overflow(self) -> bool:
+ # clear previous overflow record
+ self.found_overflow.fill_(0.0)
+ if self.check_local_overflow():
+ self.found_overflow.fill_(1.0)
+ dist.all_reduce(self.found_overflow, op=dist.ReduceOp.MAX)
+ return self.found_overflow.item() > 0
+
+ def pre_backward(self, loss: Tensor) -> Tensor:
+ loss = self.loss_scale * loss
+ self.optim_state = OptimState.SCALED
+ return loss
+
+ def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
+ self.optim_state = OptimState.SCALED
+ return grad
+
+ def should_skip_step(self) -> bool:
+ found_inf = self.check_overflow()
+ self.grad_scaler.update(found_inf)
+ if found_inf:
+ self.optim_state = OptimState.UNSCALED
+ return found_inf
+
+ def pre_zero_grad(self) -> None:
+ pass
+
+ def get_grad_div_scale(self) -> float:
+ assert self.optim_state == OptimState.SCALED, "grads should be scaled before clipping"
+ self.optim_state = OptimState.UNSCALED
+ return self.loss_scale
diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..501a843f6992c2f9d23fc4ad93e059d581910854
--- /dev/null
+++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py
@@ -0,0 +1,169 @@
+from typing import Dict, List
+
+import torch
+from torch import Tensor
+from torch.nn import Module, Parameter
+from torch.optim import Optimizer
+
+from colossalai.interface import OptimizerWrapper
+
+from .mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
+
+
+class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
+ def __init__(
+ self,
+ working_params: List[Parameter],
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ ) -> None:
+ super().__init__(
+ initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
+ )
+ self.params = working_params
+
+ def check_local_overflow(self) -> bool:
+ for p in self.params:
+ if p.grad is not None and not torch.isfinite(p.grad).all():
+ return True
+ return False
+
+
+class MixedPrecisionOptimizer(OptimizerWrapper):
+ def __init__(
+ self,
+ optim: Optimizer,
+ precision: str = "fp16",
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ max_norm: float = 0.0,
+ ):
+ super().__init__(optim)
+ if precision == "fp16":
+ working_params = []
+ for group in self.optim.param_groups:
+ for p in group["params"]:
+ working_params.append(p)
+ self.mixed_precision = NaiveFP16MixedPrecisionMixin(
+ working_params,
+ initial_scale=initial_scale,
+ min_scale=min_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ max_scale=max_scale,
+ )
+ elif precision == "bf16":
+ self.mixed_precision = BF16MixedPrecisionMixin()
+ else:
+ raise ValueError(f"Unsupported precision: {precision}")
+ if max_norm > 0.0:
+ raise NotImplementedError("max_norm is not supported yet.")
+ self.max_norm = max_norm
+ self.working_to_master_map: Dict[Parameter, Tensor] = {}
+ self.master_to_working_map: Dict[Tensor, Parameter] = {}
+
+ # create master weights
+ for group in self.optim.param_groups:
+ master_params = []
+ for p in group["params"]:
+ if p.requires_grad:
+ master_p = p
+ if p.dtype != torch.float:
+ master_p = p.detach().float()
+ self.working_to_master_map[p] = master_p
+ self.master_to_working_map[master_p] = p
+ master_params.append(master_p)
+ group["params"] = master_params
+
+ def backward(self, loss: Tensor, *args, **kwargs):
+ loss = self.mixed_precision.pre_backward(loss)
+ loss.backward(*args, **kwargs)
+
+ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+ grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
+ tensor.backward(grad)
+
+ def zero_grad(self, *args, **kwargs):
+ for p in self.working_to_master_map.keys():
+ p.grad = None
+ self.mixed_precision.pre_zero_grad()
+ return super().zero_grad(*args, **kwargs)
+
+ def _unscale_and_clip_grads(self, total_norm: float) -> None:
+ div_scale = 1.0
+ if self.mixed_precision is not None:
+ div_scale = self.mixed_precision.get_grad_div_scale()
+
+ if self.max_norm > 0.0:
+ # norm is in fact norm*scale
+ clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
+ if clip > 1:
+ div_scale = clip * div_scale
+
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.grad is None:
+ continue
+ p.grad.data.mul_(1.0 / div_scale)
+
+ def _compute_grad_norm(self) -> float:
+ if self.max_norm <= 0.0:
+ return 0.0
+ grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None]
+ if len(grads) == 0:
+ return 0.0
+ device = grads[0].device
+ # TODO(ver217): support tp
+ total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
+ return total_norm.item()
+
+ def step(self, *args, **kwargs):
+ if self.mixed_precision.should_skip_step():
+ self.zero_grad()
+ return
+ # prepare grads
+ for group in self.optim.param_groups:
+ for p in group["params"]:
+ working_param = self.master_to_working_map[p]
+ if p is working_param:
+ continue
+ if working_param.grad is not None:
+ p.grad = working_param.grad.data.float()
+ working_param.grad = None
+ total_norm = self._compute_grad_norm()
+ self._unscale_and_clip_grads(total_norm)
+ self.optim.step(*args, **kwargs)
+ # update working params
+ for group in self.optim.param_groups:
+ for p in group["params"]:
+ working_param = self.master_to_working_map[p]
+ if p is working_param:
+ continue
+ working_param.data.copy_(p.data)
+
+ def update_master_params(self, model: Module):
+ # Update master params from working params
+ with torch.no_grad():
+ for p in model.parameters():
+ if (p is None) or (p not in self.working_to_master_map):
+ continue
+ master_param = self.working_to_master_map[p]
+ master_param.data.copy_(p.data)
+
+ def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
+ return {id(working_p): master_p for working_p, master_p in self.working_to_master_map.items()}
+
+ def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
+ return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}
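
To make the new wrapper's life cycle concrete, here is a minimal single-process sketch (not part of this patch). The `bf16` path is chosen because, unlike `fp16`, its overflow check does not all-reduce over a distributed process group, so no `torch.distributed` setup is assumed:

```python
# Sketch: bf16 working weights, fp32 master weights maintained by the wrapper.
import torch
from torch import nn

from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer

model = nn.Linear(8, 8).to(torch.bfloat16)
optim = MixedPrecisionOptimizer(torch.optim.SGD(model.parameters(), lr=1e-2), precision="bf16")

loss = model(torch.randn(4, 8, dtype=torch.bfloat16)).sum()
optim.backward(loss)  # pre_backward is a no-op for bf16
optim.step()          # grads are cast to the fp32 masters, stepped, copied back
optim.zero_grad()
```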
diff --git a/colossalai/amp/torch_amp/__init__.py b/colossalai/amp/torch_amp/__init__.py
deleted file mode 100644
index 893cc890d68e423c643e6dc4bbf6343ff174a8d7..0000000000000000000000000000000000000000
--- a/colossalai/amp/torch_amp/__init__.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from typing import Optional
-
-import torch.nn as nn
-from torch.nn.modules.loss import _Loss
-from torch.optim import Optimizer
-
-from colossalai.context import Config
-
-from .torch_amp import TorchAMPLoss, TorchAMPModel, TorchAMPOptimizer
-
-
-def convert_to_torch_amp(model: nn.Module,
- optimizer: Optimizer,
- criterion: Optional[_Loss] = None,
- amp_config: Optional[Config] = None):
- """A helper function to wrap training components with Pytorch AMP modules
-
- Args:
- model (:class:`torch.nn.Module`): your model object.
- optimizer (:class:`torch.optim.Optimizer`): your optimizer object
- criterion (:class:`torch.nn.modules.loss._Loss`, optional): your loss function object
- amp_config (:class:`colossalai.context.Config` or dict, optional): configuration for Pytorch AMP.
-
- The ``amp_config`` should include parameters below:
- ::
-
- init_scale (float, optional, default=2.**16)
- growth_factor (float, optional, default=2.0)
- backoff_factor (float, optional, default=0.5)
- growth_interval (int, optional, default=2000)
- enabled (bool, optional, default=True)
-
- Returns:
- A tuple (model, optimizer, criterion)
- """
- model = TorchAMPModel(model)
- if amp_config is None:
- amp_config = dict()
- optimizer = TorchAMPOptimizer(optimizer, **amp_config)
- if criterion:
- criterion = TorchAMPLoss(criterion)
- return model, optimizer, criterion
-
-
-__all__ = ['convert_to_torch_amp', 'TorchAMPModel', 'TorchAMPLoss', 'TorchAMPOptimizer']
diff --git a/colossalai/auto_parallel/README.md b/colossalai/auto_parallel/README.md
index 8e47e1bb0b4a6e8e86c1e76d600d3dae3c8be251..f011ec8ccbd7f92821c45e124b85e92eba872b16 100644
--- a/colossalai/auto_parallel/README.md
+++ b/colossalai/auto_parallel/README.md
@@ -16,8 +16,8 @@ A *symbolic profiler* for collecting computing and memory overhead related to st
### Solver
**Solver** is designed to find the optimal execution plan for a given computation graph and cluster in two stages:
-1) *Intra-op parallelism stage* is to find the plan with the minimum total execution time of all nodes with respect to the constraint of the memory budget. The optimaztion goal of intra-op parallelism solver is modified from Alpa 's intra-op parallelsim ILP solver.
-2) *Activation checkpoint stage* is to search for the fastest execution plan that meets the memory budget on the computation graph after inserting the communication nodes by the intra-op parallelism stage. The algorithm to find optimial activation checkpoint is modified from Rotor . The reason we use two-stage optimization is that if the two tasks are formulated together, the solving time will be significantly increased, which will greatly affect the user experience of the system. On the contrary, solving in two hierarchical levels has many advantages. Firstly, compared with the computation graph with activation checkpointing, the original graph has fewer nodes, which can reduce the solving cost of intra-op parallelism solver. In addition, a more optimal solution can be found by adding the communication overhead into the activation checkpoint modeling.
+1) *Intra-op parallelism stage* finds the plan with the minimum total execution time over all nodes, subject to the memory budget. The optimization goal of the intra-op parallelism solver is modified from Alpa's intra-op parallelism ILP solver.
+2) *Activation checkpoint stage* searches for the fastest execution plan that meets the memory budget on the computation graph produced after the intra-op parallelism stage inserts its communication nodes. The algorithm to find the optimal activation checkpoint schedule is modified from Rotor. We use two-stage optimization because formulating the two tasks together significantly increases the solving time, which would hurt the user experience of the system. Solving at two hierarchical levels, by contrast, has several advantages. First, the original graph has fewer nodes than the graph with activation checkpointing, which reduces the solving cost of the intra-op parallelism solver. In addition, a better solution can be found by adding the communication overhead into the activation checkpoint modeling.
### Generator
**Generator** applies the searched execution plan to the computation graph and recompiles it into optimized PyTorch code. It runs *a series of compile passes* to insert communication nodes or substitute kernels as required by the intra-op parallelism solver. Additionally, we implement a *code generation* feature that recognizes the annotations from the activation checkpoint solver and injects activation checkpoint blocks following those annotations.
diff --git a/colossalai/auto_parallel/checkpoint/build_c_ext.py b/colossalai/auto_parallel/checkpoint/build_c_ext.py
index af4349865a7b8dd748e34458eb3d5aeeb359b599..7de56f80525ae347946cdbab9bee7fb2048aa249 100644
--- a/colossalai/auto_parallel/checkpoint/build_c_ext.py
+++ b/colossalai/auto_parallel/checkpoint/build_c_ext.py
@@ -3,14 +3,16 @@ import os
from setuptools import Extension, setup
this_dir = os.path.dirname(os.path.abspath(__file__))
-ext_modules = [Extension(
- 'rotorc',
- sources=[os.path.join(this_dir, 'ckpt_solver_rotor.c')],
-)]
+ext_modules = [
+ Extension(
+ "rotorc",
+ sources=[os.path.join(this_dir, "ckpt_solver_rotor.c")],
+ )
+]
setup(
- name='rotor c extension',
- version='0.1',
- description='rotor c extension for faster dp computing',
+ name="rotor c extension",
+ version="0.1",
+ description="rotor c extension for faster dp computing",
ext_modules=ext_modules,
)
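
For reference, this setup script is what `ckpt_solver_rotor.py` (below) builds on demand via subprocess; an equivalent manual build, sketched here under the assumption that it is run from the package directory, is:

```python
# Equivalent to: python build_c_ext.py build_ext --build-lib=<package dir>
import os
import subprocess
import sys

this_dir = os.path.dirname(os.path.abspath("build_c_ext.py"))  # assumed working directory
subprocess.run(
    [sys.executable, "build_c_ext.py", "build_ext", f"--build-lib={this_dir}"],
    check=True,
)
```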
diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
index b388d00ac553726f577575d5d770b98dfb873f12..8aaa690b333c6d71044b4c40726333bd453d411c 100644
--- a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
+++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
@@ -12,13 +12,13 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import (
)
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
-__all___ = ['CheckpointSolverBase']
+__all___ = ["CheckpointSolverBase"]
def _copy_output(src: Graph, dst: Graph):
"""Copy the output node from src to dst"""
for n_src, n_dst in zip(src.nodes, dst.nodes):
- if n_src.op == 'output':
+ if n_src.op == "output":
n_dst.meta = n_src.meta
@@ -28,7 +28,6 @@ def _get_param_size(module: torch.nn.Module):
class CheckpointSolverBase(ABC):
-
def __init__(
self,
graph: Graph,
@@ -81,13 +80,10 @@ class CheckpointSolverBase(ABC):
@abstractmethod
def solve(self):
- """Solve the checkpointing problem and return the solution.
- """
- pass
+ """Solve the checkpointing problem and return the solution."""
def get_node_list(self):
- """Get the node list.
- """
+ """Get the node list."""
return [[node] for node in self.graph.nodes]
def _linearize_graph(self) -> List[List[Node]]:
@@ -140,8 +136,7 @@ class CheckpointSolverBase(ABC):
"""
def _is_inplace(n: Node):
- """Get the inplace argument from ``torch.fx.Node``
- """
+ """Get the inplace argument from ``torch.fx.Node``"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
@@ -150,19 +145,22 @@ class CheckpointSolverBase(ABC):
return inplace
def _is_shape_consistency(n: Node):
- """Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)
- """
+ """Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)"""
return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]
- return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any(
- map(_is_shape_consistency, n.users))
+ return (
+ not sum([v for _, v in deps.items()])
+ and not any(map(_is_inplace, n.users))
+ and not any(map(_is_shape_consistency, n.users))
+ )
# make sure that item in cnode is valid
if self.cnode:
for name in self.cnode:
try:
- assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
- f"Common node {name} is not an input of the model."
+ assert (
+ next(node for node in self.graph.nodes if node.name == name).op == "placeholder"
+ ), f"Common node {name} is not an input of the model."
except StopIteration:
raise ValueError(f"Common node name {name} not in graph.")
@@ -187,8 +185,9 @@ class CheckpointSolverBase(ABC):
region = []
# propagate common node attr if possible
- if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
- ]) or _is_cop(n.target):
+ if len(n.all_input_nodes) == len(
+ [node for node in n.all_input_nodes if node.name in self.cnode]
+ ) or _is_cop(n.target):
self.cnode.append(n.name)
else:
deps[n] = len([user for user in n.users if user.op != "output"])
diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py
index 19b2ef5987c9ebc160078339b764741a71b34dbf..ab16cc04b7304a60e69a91c8b49af91be898dab9 100644
--- a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py
+++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py
@@ -8,11 +8,10 @@ from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
from .ckpt_solver_base import CheckpointSolverBase
-__all__ = ['CheckpointSolverChen']
+__all__ = ["CheckpointSolverChen"]
class CheckpointSolverChen(CheckpointSolverBase):
-
def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6):
"""
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
@@ -40,14 +39,14 @@ class CheckpointSolverChen(CheckpointSolverBase):
Returns:
graph (Graph): The optimized graph, should be a copy of the original graph.
"""
- checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr']
+ checkpointable_op = ["call_module", "call_method", "call_function", "get_attr"]
ckpt = self.grid_search()
for i, seg in enumerate(ckpt):
for idx in range(*seg):
nodes = self.node_list[idx]
for n in nodes:
if n.op in checkpointable_op:
- n.meta['activation_checkpoint'] = i
+ n.meta["activation_checkpoint"] = i
return deepcopy(self.graph)
def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]:
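
The heuristic behind `grid_search`/`run_chen_greedy` is the sqrt(n) segmentation from Chen et al. (https://arxiv.org/abs/1604.06174); a toy illustration of that idea (not the solver's actual code) is:

```python
# Toy sketch of sqrt(n) checkpointing: split n nodes into segments of roughly
# sqrt(n), so only segment boundaries are kept and the rest is recomputed.
import math

def sqrt_segments(n: int) -> list:
    b = max(int(math.sqrt(n)), 1)  # target segment size
    return [(i, min(i + b, n)) for i in range(0, n, b)]

print(sqrt_segments(10))  # [(0, 3), (3, 6), (6, 9), (9, 10)]
```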
diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
index 21c3bf0da758bd061eaa9bcf08534e9a2df8d6cf..d10c41ae2b962242e09b5b52587f48bc1c80c118 100644
--- a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
+++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
@@ -1,5 +1,5 @@
from copy import deepcopy
-from typing import Any, Dict, List, Tuple
+from typing import Any, List, Tuple
from torch import Tensor
from torch.fx import Graph, Node
@@ -18,17 +18,18 @@ from colossalai.logging import get_dist_logger
from .ckpt_solver_base import CheckpointSolverBase
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence
-__all__ = ['CheckpointSolverRotor']
+__all__ = ["CheckpointSolverRotor"]
class CheckpointSolverRotor(CheckpointSolverBase):
-
- def __init__(self,
- graph: Graph,
- free_memory: float = -1,
- cnode: List[str] = None,
- memory_slots: int = 500,
- optim_multiplier: float = 1.0):
+ def __init__(
+ self,
+ graph: Graph,
+ free_memory: float = -1,
+ cnode: List[str] = None,
+ memory_slots: int = 500,
+ optim_multiplier: float = 1.0,
+ ):
"""This is the simple implementation of dynamic programming algorithm rotor
        in https://hal.inria.fr/hal-02352969. Some code is adapted from
https://gitlab.inria.fr/hiepacs/rotor.
@@ -85,13 +86,14 @@ class CheckpointSolverRotor(CheckpointSolverBase):
# backtrack
try:
- self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table,
- self.back_ptr)
+ self.sequence = self._backtrack(
+ chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table, self.back_ptr
+ )
self._annotate_from_sequence(self.sequence, self.node_list)
except ValueError as e:
            # use the logger to announce that the solver failed
logger = get_dist_logger()
- logger.warning(f'Checkpoint solver failed: {e}')
+ logger.warning(f"Checkpoint solver failed: {e}")
raise ValueError
if verbose:
@@ -100,14 +102,19 @@ class CheckpointSolverRotor(CheckpointSolverBase):
return deepcopy(self.graph)
def print_chain(self):
- print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
+ print("[input]", self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
for idx in range(len(self.node_list) - 1):
- print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx],
- self.chain.btmp[idx])
- print(f'Chain = {self.chain}')
+ print(
+ self.node_list[idx],
+ self.chain.x[idx + 1],
+ self.chain.xbar[idx + 1],
+ self.chain.ftmp[idx],
+ self.chain.btmp[idx],
+ )
+ print(f"Chain = {self.chain}")
def print_sequence(self):
- print(f'Sequence = {self.sequence}')
+ print(f"Sequence = {self.sequence}")
@classmethod
def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain:
@@ -138,14 +145,14 @@ class CheckpointSolverRotor(CheckpointSolverBase):
btime = 0
fwd_mem_peak = 0
for n in node:
- assert isinstance(n, Node), f'{n} is not a Node'
+ assert isinstance(n, Node), f"{n} is not a Node"
if n.target == runtime_apply or n.target == runtime_comm_spec_apply:
                # in this case we need to calculate memory usage directly based on the statistics hooked into node.meta
- xbar += n.meta['fwd_mem_out']
- fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'])
+ xbar += n.meta["fwd_mem_out"]
+ fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta["fwd_mem_tmp"])
else:
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
- fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
+ fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta["fwd_mem_tmp"] + cls._extract_unused_output(n))
# minimum flop count is required
ftime += max(calculate_fwd_time(n), 1.0)
@@ -162,14 +169,14 @@ class CheckpointSolverRotor(CheckpointSolverBase):
"""Extract input tensors from a Graph"""
input_tensors = []
for node in graph.nodes:
- if node.op == 'placeholder':
- input_tensors.append(node.meta['fwd_out'])
+ if node.op == "placeholder":
+ input_tensors.append(node.meta["fwd_out"])
return input_tensors
@staticmethod
def _extract_unused_output(node: Node) -> int:
"""Extract unused output from `torch.fx.Node`"""
- return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node)
+ return activation_size(node.meta["fwd_out"]) - calculate_fwd_out(node)
@staticmethod
def _extract_btmp(node: List[Node]) -> int:
@@ -180,8 +187,8 @@ class CheckpointSolverRotor(CheckpointSolverBase):
for k, v in deps.items():
k: Node
if v > 0:
- deps_size += k.meta['bwd_mem_out']
- if v == float('-inf'):
+ deps_size += k.meta["bwd_mem_out"]
+ if v == float("-inf"):
deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
return deps_size
@@ -190,12 +197,12 @@ class CheckpointSolverRotor(CheckpointSolverBase):
deps = {}
for n in reversed(node):
deps[n] = len(n.all_input_nodes)
- btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp'])
+ btmp = max(btmp, _extract_deps_size() + n.meta["bwd_mem_tmp"])
for child in n.users:
if child in deps:
deps[child] -= 1
if deps[child] <= 0:
- deps[child] = float('-inf') # free
+ deps[child] = float("-inf") # free
return btmp
@staticmethod
@@ -244,10 +251,11 @@ class CheckpointSolverRotor(CheckpointSolverBase):
if m < mmin:
cost_table[m][i][idx] = float("inf")
else:
- leaf_checkpoints = [(j,
- sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
- for j in range(i + 1, idx + 1)
- if m >= x[j]]
+ leaf_checkpoints = [
+ (j, sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
+ for j in range(i + 1, idx + 1)
+ if m >= x[j]
+ ]
if leaf_checkpoints:
best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
else:
@@ -274,13 +282,16 @@ class CheckpointSolverRotor(CheckpointSolverBase):
import os
import subprocess
import sys
+
logger = get_dist_logger()
logger.info("rotorc hasn't been built! Building library...", ranks=[0])
this_dir = os.path.dirname(os.path.abspath(__file__))
result = subprocess.Popen(
[
- f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext",
- f"--build-lib={this_dir}"
+ f"{sys.executable}",
+ f"{os.path.join(this_dir, 'build_c_ext.py')}",
+ "build_ext",
+ f"--build-lib={this_dir}",
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
@@ -294,8 +305,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
return compute_table(chain, mmax)
@staticmethod
- def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any],
- back_ptr: List[Any]) -> "Sequence":
+ def _backtrack(
+ chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any], back_ptr: List[Any]
+ ) -> "Sequence":
"""Backtrack the cost table and retrieve the optimal checkpointing strategy.
Args:
@@ -328,8 +340,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
if back_ptr[budget][lhs][rhs][0]:
sequence += [
ForwardEnable(lhs),
- CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table,
- back_ptr),
+ CheckpointSolverRotor._backtrack(
+ chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table, back_ptr
+ ),
Backward(lhs),
]
else:
@@ -337,8 +350,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
sequence += [ForwardCheck(lhs)]
sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)]
sequence += [
- CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table,
- back_ptr),
+ CheckpointSolverRotor._backtrack(
+ chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table, back_ptr
+ ),
CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr),
]
return sequence
@@ -353,8 +367,8 @@ class CheckpointSolverRotor(CheckpointSolverBase):
"""
op_list = sequence.list_operations()
loss_op = next(op for op in op_list if isinstance(op, Loss))
- fwd_list = op_list[:op_list.index(loss_op)]
- bwd_list = op_list[op_list.index(loss_op) + 1:]
+ fwd_list = op_list[: op_list.index(loss_op)]
+ bwd_list = op_list[op_list.index(loss_op) + 1 :]
ckpt_idx = 0
in_ckpt = False
ckpt_region = []
@@ -369,7 +383,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
in_ckpt = False
for node_idx in ckpt_region:
for n in node_list[node_idx]:
- n.meta['activation_checkpoint'] = [ckpt_idx]
+ n.meta["activation_checkpoint"] = [ckpt_idx]
ckpt_idx += 1
ckpt_region = []
@@ -377,7 +391,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
- n.meta['activation_checkpoint'] = [ckpt_idx]
+ n.meta["activation_checkpoint"] = [ckpt_idx]
ckpt_idx += 1
ckpt_region = [idx]
@@ -397,7 +411,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
elif isinstance(op, ForwardEnable):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
- n.meta['activation_checkpoint'].append(ckpt_idx)
+ n.meta["activation_checkpoint"].append(ckpt_idx)
ckpt_idx += 1
ckpt_region = []
@@ -405,7 +419,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
- n.meta['activation_checkpoint'].append(ckpt_idx)
+ n.meta["activation_checkpoint"].append(ckpt_idx)
ckpt_idx += 1
ckpt_region = [op.index]
@@ -413,7 +427,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
elif isinstance(op, Backward):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
- n.meta['activation_checkpoint'].append(ckpt_idx)
+ n.meta["activation_checkpoint"].append(ckpt_idx)
in_recompute = False
@@ -431,9 +445,11 @@ class CheckpointSolverRotor(CheckpointSolverBase):
for node in node_list:
op_list += node
ckpt_regions = _find_nested_ckpt_regions(op_list)
- for (start_idx, end_idx) in ckpt_regions:
+ for start_idx, end_idx in ckpt_regions:
nested_length = max(
- len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1))
+ len(op_list[idx].meta["activation_checkpoint"]) for idx in range(start_idx, end_idx + 1)
+ )
for idx in range(start_idx, end_idx + 1):
- op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length -
- len(op_list[idx].meta['activation_checkpoint']))
+ op_list[idx].meta["activation_checkpoint"] += [None] * (
+ nested_length - len(op_list[idx].meta["activation_checkpoint"])
+ )
diff --git a/colossalai/auto_parallel/checkpoint/operation.py b/colossalai/auto_parallel/checkpoint/operation.py
index ab0c6c5ad38d171d470931aa9b2bdedf6cd17668..5f80779164339930c6943a9132952ee75f988a14 100644
--- a/colossalai/auto_parallel/checkpoint/operation.py
+++ b/colossalai/auto_parallel/checkpoint/operation.py
@@ -1,20 +1,21 @@
import math
from abc import ABC
-from typing import Any, Iterable, List
+from typing import List
from torch.utils._pytree import tree_map
class Chain:
-
- def __init__(self,
- ftime: List[float],
- btime: List[float],
- x: List[int],
- xbar: List[int],
- ftmp: List[int],
- btmp: List[int],
- check_consistency: bool = True):
+ def __init__(
+ self,
+ ftime: List[float],
+ btime: List[float],
+ x: List[int],
+ xbar: List[int],
+ ftmp: List[int],
+ btmp: List[int],
+ check_consistency: bool = True,
+ ):
"""The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint.
See paper https://hal.inria.fr/hal-02352969 for details.
@@ -37,9 +38,14 @@ class Chain:
raise AttributeError("In Chain, input lists do not have consistent lengths")
def check_lengths(self):
- return ((len(self.ftime) == len(self)) and (len(self.btime) == len(self) + 1) and (len(self.x) == len(self) + 1)
- and (len(self.ftmp) == len(self)) and (len(self.btmp) == len(self) + 1)
- and (len(self.xbar) == len(self) + 1))
+ return (
+ (len(self.ftime) == len(self))
+ and (len(self.btime) == len(self) + 1)
+ and (len(self.x) == len(self) + 1)
+ and (len(self.ftmp) == len(self))
+ and (len(self.btmp) == len(self) + 1)
+ and (len(self.xbar) == len(self) + 1)
+ )
def __repr__(self):
chain_list = []
@@ -100,7 +106,6 @@ class ForwardCheck(Forward):
class Forwards(Operation):
-
def __init__(self, start, end):
self.index = (start, end)
@@ -109,9 +114,9 @@ class Forwards(Operation):
def cost(self, chain: Chain):
if chain is not None:
- return sum(chain.ftime[self.index[0]:self.index[1] + 1])
+ return sum(chain.ftime[self.index[0] : self.index[1] + 1])
else:
- return (self.index[1] - self.index[0] + 1)
+ return self.index[1] - self.index[0] + 1
def isForward(op):
@@ -132,7 +137,6 @@ class Backward(Operation):
class Loss(Operation):
-
def __init__(self):
pass
@@ -166,7 +170,6 @@ class DiscardMemory(MemoryAccess):
class Sequence(list):
-
def __init__(self):
super().__init__()
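
The consistency check above encodes the chain invariants from the referenced paper: a chain of length n has n forward stages (ftime, ftmp) and n + 1 backward stages (btime, btmp), with n + 1 memory entries (x, xbar). A standalone sketch of that invariant with illustrative values; it mirrors check_lengths above but is not the library implementation.

# Standalone sketch of the Chain length invariants (len(chain) == n,
# the number of forward stages).
def check_lengths(ftime, btime, x, xbar, ftmp, btmp) -> bool:
    n = len(ftime)
    return (
        len(btime) == n + 1
        and len(x) == n + 1
        and len(xbar) == n + 1
        and len(ftmp) == n
        and len(btmp) == n + 1
    )

# a length-2 chain: 2 forward stages, 3 backward entries (incl. the loss)
assert check_lengths(
    ftime=[1.0, 1.0], btime=[2.0, 2.0, 2.0],
    x=[4, 4, 4], xbar=[8, 8, 8], ftmp=[0, 0], btmp=[0, 0, 0],
)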
diff --git a/colossalai/auto_parallel/meta_profiler/constants.py b/colossalai/auto_parallel/meta_profiler/constants.py
index 35b8c13ee8fff717df39a96c60fa101eb0b2a781..2f638fa919e445ac4bc24c9b8904bfef27e608c5 100644
--- a/colossalai/auto_parallel/meta_profiler/constants.py
+++ b/colossalai/auto_parallel/meta_profiler/constants.py
@@ -3,8 +3,6 @@ import operator
import torch
import torch.nn as nn
-from ..tensor_shard.constants import *
-
# list of inplace module
INPLACE_MODULE = [nn.ReLU]
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
index 0f2e9e44f91cedfdb888171dd373d4d6163f579f..4234481ae2ca8fa10769d043949d6ecff0d537da 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
@@ -25,28 +25,32 @@ def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0
def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
input_tensor = next(
filter(
- lambda x:
- (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim',
- args)).data
+ lambda x: (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM)
+ and x.name != "softmax_dim",
+ args,
+ )
+ ).data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
- is_inplace = 1 if kwargs.get('inplace', False) else 0
+ is_inplace = 1 if kwargs.get("inplace", False) else 0
flop_counter = elementwise_flop_counter(1, 0)
# calculate compute cost
fwd_compute_cost = flop_counter([input_tensor], [output_tensor])
bwd_compute_cost = flop_counter([output_tensor], [input_tensor])
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
# calculate memory cost
# NOTE: currently the SPMD solver always assumes that a new tensor is created in the forward pass
# NOTE: if inplace is True, no new tensor is created in the forward pass
- fwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) * (2 - is_inplace),
- parameter=0,
- temp=0,
- buffer=activation_size(input_tensor) * buffer_mem_scale)
+ fwd_memory_cost = MemoryCost(
+ activation=activation_size(input_tensor) * (2 - is_inplace),
+ parameter=0,
+ temp=0,
+ buffer=activation_size(input_tensor) * buffer_mem_scale,
+ )
# temp_mem_scale covers situations like softmax backward,
# where the buffer is freed during the backward phase
@@ -54,20 +58,23 @@ def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0
activation=activation_size(input_tensor) - activation_size(input_tensor) * buffer_mem_scale,
parameter=0,
temp=activation_size(input_tensor) * temp_mem_scale + activation_size(input_tensor) * buffer_mem_scale,
- buffer=0)
+ buffer=0,
+ )
# total cost is the sum of forward and backward cost
- total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
- parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
- temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
- buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
+ total_cost = MemoryCost(
+ activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+ parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+ temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
+ buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer,
+ )
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
- fwd_buffer = [torch.zeros_like(output_tensor, device='meta')]
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_buffer = [torch.zeros_like(output_tensor, device="meta")]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
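
To make the accounting above concrete: the forward activation term is (2 - is_inplace) times the tensor size (input plus a fresh output, or just the input when inplace), the buffer term survives the forward pass, and during backward the buffer reappears as temp because it is freed. A toy sketch under the assumption of fp32 tensors; Mem stands in for MemoryCost and is not the library class.

from dataclasses import dataclass

@dataclass
class Mem:  # stand-in for MemoryCost
    activation: int = 0
    parameter: int = 0
    temp: int = 0
    buffer: int = 0

def elementwise_mem(numel: int, is_inplace: bool,
                    temp_mem_scale: float = 0.0,
                    buffer_mem_scale: float = 0.0):
    size = numel * 4  # fp32 bytes
    fwd = Mem(activation=int(size * (2 - is_inplace)),
              buffer=int(size * buffer_mem_scale))
    # the buffer is freed during backward, so it shows up there as temp
    bwd = Mem(activation=int(size - size * buffer_mem_scale),
              temp=int(size * temp_mem_scale + size * buffer_mem_scale))
    return fwd, bwd

fwd, bwd = elementwise_mem(numel=1024, is_inplace=False)
assert (fwd.activation, bwd.activation) == (8192, 4096)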
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
index e451748512b9abebcc4f63ad854be3f129ee52bd..0b7b51a719551710041471a95c0d620124abf8ef 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
@@ -6,10 +6,10 @@ from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
-from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION
+from ..constants import BCAST_FUNC_OP
from ..registry import meta_register
-__all__ = ['binary_elementwise_meta_info']
+__all__ = ["binary_elementwise_meta_info"]
@meta_register.register(BCAST_FUNC_OP)
@@ -61,6 +61,6 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = []
- fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]
+ fwd_out = [torch.zeros_like(output_op_data.data, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
index 4336bf68363c8a708b877c8d8116c44986a85592..2f630995cdbce49dafa2199ce85c31ebdc07b5b7 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py
@@ -1,22 +1,14 @@
-from typing import Callable, Dict, List, Tuple, Union
+from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- MemoryCost,
- OperationData,
- OperationDataType,
- ShardingStrategy,
- StrategiesVector,
- TrainCycleItem,
-)
-from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..registry import meta_register
-__all__ = ['convnd_meta_info']
+__all__ = ["convnd_meta_info"]
@meta_register.register(torch.nn.Conv1d)
@@ -103,35 +95,47 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,))
- bwd_compute_cost = flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) if has_bias else \
- flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor))
+ bwd_compute_cost = (
+ flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor))
+ if has_bias
+ else flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor))
+ )
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
# TODO: use profiler to check conv temp memory
# NOTE: currently the SPMD solver always assumes that a new tensor is created in the forward pass
- fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
- if has_bias else compute_size_in_bytes(weight_tensor),
- temp=0,
- buffer=0)
-
- bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
- if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
- if has_bias else compute_size_in_bytes(weight_tensor),
- temp=0,
- buffer=0)
+ fwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, output_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
+ if has_bias
+ else compute_size_in_bytes(weight_tensor),
+ temp=0,
+ buffer=0,
+ )
+
+ bwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
+ if has_bias
+ else compute_size_in_bytes([input_tensor, weight_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
+ if has_bias
+ else compute_size_in_bytes(weight_tensor),
+ temp=0,
+ buffer=0,
+ )
# total cost is the sum of forward and backward cost
- total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
- parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
+ total_cost = MemoryCost(
+ activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+ parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+ )
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
- fwd_in = [torch.zeros_like(input_tensor, device='meta')]
+ fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = []
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py
index d5d80f5b3700b6b644c0b630496bd907c0b5aac2..7c9add810fd87cef0491f0ce68c3055ebddd5820 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py
@@ -24,8 +24,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.embedding.default]([weight_tensor, input_tensor], [output_tensor])
- bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default]([output_tensor, weight_tensor],
- [weight_tensor])
+ bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default](
+ [output_tensor, weight_tensor], [weight_tensor]
+ )
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
@@ -34,10 +35,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# NOTE: during the backward phase of torch.nn.Embedding, a temporary memory allocation appears when
# the input is large enough, for reasons that are not yet understood; we currently assume there is no
# temp memory, as it is significantly smaller than the gradient memory
- fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
- parameter=0,
- temp=0,
- buffer=0)
+ fwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, output_tensor]), parameter=0, temp=0, buffer=0
+ )
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0)
total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation)
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
index 7697fc6c383d8154acfe76dba7d8baec225930ac..d731f9cb4436f0b054beef9a7326a21fea8233a6 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
@@ -1,23 +1,15 @@
from functools import reduce
-from typing import Callable, Dict, List, Tuple, Union
+from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- MemoryCost,
- OperationData,
- OperationDataType,
- ShardingStrategy,
- StrategiesVector,
- TrainCycleItem,
-)
-from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register
-__all__ = ['linear_meta_info', 'matmul_meta_info']
+__all__ = ["linear_meta_info", "matmul_meta_info"]
@meta_register.register(torch.nn.functional.linear)
@@ -100,32 +92,43 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default](
- [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
- bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
- flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \
- flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)
+ )
+ bwd_compute_cost = (
+ flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,))
+ + flop_mapping[torch.ops.aten.mm.default](
+ [torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)
+ )
+ + flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
+ )
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
# calculate memory cost
# NOTE: Linear has no buffer or temp memory in the forward and backward phases
# the forward activation cost is the size of output_tensor; the parameter cost is the size of weight_tensor and bias_tensor
# NOTE: currently the SPMD solver always assumes that a new tensor is created in the forward pass
- fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
- temp=0,
- buffer=0)
+ fwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, output_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+ temp=0,
+ buffer=0,
+ )
# the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor; the parameter cost is the size of weight_tensor and bias_tensor
- bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
- temp=0,
- buffer=0)
+ bwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+ temp=0,
+ buffer=0,
+ )
# total cost is the sum of the forward and backward costs
- total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
- parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
+ total_cost = MemoryCost(
+ activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+ parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+ )
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
@@ -136,39 +139,49 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
- [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
- bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
- flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,))
+ [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)
+ )
+ bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
+ [output_tensor, weight_tensor], (input_tensor,)
+ ) + flop_mapping[torch.ops.aten.mm.default](
+ [torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)
+ )
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
# calculate memory cost
# NOTE: Linear has no buffer or temp memory in the forward and backward phases
# the forward activation cost is the size of output_tensor; the parameter cost is the size of weight_tensor
# NOTE: currently the SPMD solver always assumes that a new tensor is created in the forward pass
- fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
- parameter=compute_size_in_bytes(weight_tensor),
- temp=0,
- buffer=0)
+ fwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, output_tensor]),
+ parameter=compute_size_in_bytes(weight_tensor),
+ temp=0,
+ buffer=0,
+ )
# the backward activation cost is the size of input_tensor and weight_tensor; the parameter cost is the size of weight_tensor
- bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]),
- parameter=compute_size_in_bytes(weight_tensor),
- temp=0,
- buffer=0)
+ bwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, weight_tensor]),
+ parameter=compute_size_in_bytes(weight_tensor),
+ temp=0,
+ buffer=0,
+ )
# total cost is the sum of the forward and backward costs
- total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
- parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
+ total_cost = MemoryCost(
+ activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+ parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+ )
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
- fwd_in = [torch.zeros_like(input_tensor, device='meta')]
+ fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = []
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
@@ -222,15 +235,16 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# batched gemv case 1: batched matrix-vector multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
- [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors)
+ [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors
+ )
# combine the dimensions of output
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
- [output_tensors[0].reshape(-1), input_tensors[1]],
- output_tensors) + \
- flop_mapping[torch.ops.aten.matmul.default](
- [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
- output_tensors)
+ [output_tensors[0].reshape(-1), input_tensors[1]], output_tensors
+ ) + flop_mapping[torch.ops.aten.matmul.default](
+ [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
+ output_tensors,
+ )
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
@@ -239,93 +253,111 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# gemv case 2: vector-matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)
- bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \
- flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)
+ bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
+ [output_tensors[0], input_tensors[0]], output_tensors
+ ) + flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
- bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors),
- parameter=0,
- temp=compute_size_in_bytes(input_tensors[1]),
- buffer=0)
+ bwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(input_tensors),
+ parameter=0,
+ temp=compute_size_in_bytes(input_tensors[1]),
+ buffer=0,
+ )
elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
# batched gemv case 2: vector-batched matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
- [output_tensors[0].reshape(-1)])
+ [output_tensors[0].reshape(-1)],
+ )
# combine the dimensions of output
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
- [output_tensors[0].reshape(-1), input_tensors[0]],
- output_tensors
- ) + \
- flop_mapping[torch.ops.aten.matmul.default](
- [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)],
- output_tensors
- )
+ [output_tensors[0].reshape(-1), input_tensors[0]], output_tensors
+ ) + flop_mapping[torch.ops.aten.matmul.default](
+ [
+ input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1),
+ output_tensors[0].reshape(-1),
+ ],
+ output_tensors,
+ )
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]]))
- bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
- parameter=0,
- temp=compute_size_in_bytes(input_tensors[1]),
- buffer=0)
+ bwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(input_tensors[0]),
+ parameter=0,
+ temp=compute_size_in_bytes(input_tensors[1]),
+ buffer=0,
+ )
elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
# gemm & batched gemm case 1: batched matrix-matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]],
- [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])])
+ [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
+ )
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
- [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
- [input_tensors[1]]
- ) + \
- flop_mapping[torch.ops.aten.mm.default](
- [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
- [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])]
- )
+ [
+ input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1),
+ output_tensors[0].reshape(-1, output_tensors[0].shape[-1]),
+ ],
+ [input_tensors[1]],
+ ) + flop_mapping[torch.ops.aten.mm.default](
+ [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
+ [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])],
+ )
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
# batched gemm case 2: matrix-batched matrix multiplication
- fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([
- input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0].transpose(
- 0, 1)
- ], [output_tensors[0].transpose(-2, -1)])
+ fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
+ [
+ input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]),
+ input_tensors[0].transpose(0, 1),
+ ],
+ [output_tensors[0].transpose(-2, -1)],
+ )
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
- [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
- [input_tensors[0]]
- ) + \
- flop_mapping[torch.ops.aten.mm.default](
- [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
- [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])]
- )
-
- fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) +
- compute_size_in_bytes(input_tensors[1]),
- temp=compute_size_in_bytes(output_tensors))
- bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
- parameter=0,
- temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors))
+ [
+ output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1),
+ input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]),
+ ],
+ [input_tensors[0]],
+ ) + flop_mapping[torch.ops.aten.mm.default](
+ [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
+ [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
+ )
+
+ fwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(output_tensors) + compute_size_in_bytes(input_tensors[1]),
+ temp=compute_size_in_bytes(output_tensors),
+ )
+ bwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(input_tensors[0]),
+ parameter=0,
+ temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors),
+ )
elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
# Batched matrix-batched matrix multiplication
# Fetch shape of the two inputs and see if the batch dimensions are the same
_is_batch_dims_same = True
if len(input_tensors[0].shape) == len(input_tensors[1].shape):
- for (shape_0, shape_1) in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
+ for shape_0, shape_1 in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
if shape_0 != shape_1:
_is_batch_dims_same = False
break
else:
_is_batch_dims_same = False
- # retireve dimensions
+ # retrieve dimensions
input_dim_00 = input_tensors[0].shape[-2]
input_dim_01 = input_tensors[0].shape[-1]
input_dim_10 = input_tensors[1].shape[-2]
@@ -337,20 +369,28 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# Case 1: batch dimensions are the same
# Forward compute cost: C = A * B
- fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]([
- input_tensors[0].reshape(-1, input_dim_00, input_dim_01), input_tensors[1].reshape(
- -1, input_dim_10, input_dim_11)
- ], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
+ fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
+ [
+ input_tensors[0].reshape(-1, input_dim_00, input_dim_01),
+ input_tensors[1].reshape(-1, input_dim_10, input_dim_11),
+ ],
+ [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
+ )
# Backward compute cost: dB = A^T * dC, dA = dC * B^T
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
- [input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
- [input_tensors[1].reshape(-1, input_dim_11, input_dim_10)]
- ) + \
- flop_mapping[torch.ops.aten.bmm.default](
- [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10)],
- [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)]
- )
+ [
+ input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00),
+ output_tensors[0].reshape(-1, output_dim_0, output_dim_1),
+ ],
+ [input_tensors[1].reshape(-1, input_dim_11, input_dim_10)],
+ ) + flop_mapping[torch.ops.aten.bmm.default](
+ [
+ output_tensors[0].reshape(-1, output_dim_0, output_dim_1),
+ input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10),
+ ],
+ [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)],
+ )
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors))
@@ -358,43 +398,46 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
else:
# Case 2: batch dimensions are different
batch_dims = output_tensors[0].shape[:-2]
- extended_input_0 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
- input_dim_00,
- input_dim_01,
- device="meta")
- extended_input_1 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
- input_dim_10,
- input_dim_11,
- device="meta")
+ extended_input_0 = torch.rand(
+ reduce(lambda x, y: x * y, batch_dims), input_dim_00, input_dim_01, device="meta"
+ )
+ extended_input_1 = torch.rand(
+ reduce(lambda x, y: x * y, batch_dims), input_dim_10, input_dim_11, device="meta"
+ )
# Forward compute cost: C = A * B
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
- [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
+ [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)]
+ )
# Backward compute cost: dB = A^T * dC, dA = dC * B^T
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
- [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
- [extended_input_1]
- ) + \
- flop_mapping[torch.ops.aten.bmm.default](
- [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
- [extended_input_0]
- )
+ [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
+ [extended_input_1],
+ ) + flop_mapping[torch.ops.aten.bmm.default](
+ [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
+ [extended_input_0],
+ )
fwd_mem_cost = MemoryCost(
- activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1]))
- bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) -
- compute_size_in_bytes([extended_input_0, extended_input_1]),
- temp=compute_size_in_bytes([extended_input_0, extended_input_1]))
+ activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1])
+ )
+ bwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(input_tensors)
+ - compute_size_in_bytes([extended_input_0, extended_input_1]),
+ temp=compute_size_in_bytes([extended_input_0, extended_input_1]),
+ )
# compute cost
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# memory cost
- total_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
- parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
- temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
- buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
+ total_cost = MemoryCost(
+ activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
+ parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
+ temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
+ buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost)
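
The matmul_meta_info branches above dispatch purely on the ranks of the two operands. A hypothetical helper that names the same six cases, useful for following the cost derivations; classify_matmul is an illustration, not part of the module.

def classify_matmul(shape_a: tuple, shape_b: tuple) -> str:
    ra, rb = len(shape_a), len(shape_b)
    if ra >= 2 and rb == 1:
        return "batched gemv case 1: (batched) matrix-vector"
    if ra == 1 and rb == 2:
        return "gemv case 2: vector-matrix"
    if ra == 1 and rb >= 3:
        return "batched gemv case 2: vector-batched matrix"
    if ra >= 2 and rb == 2:
        return "gemm & batched gemm case 1: (batched) matrix-matrix"
    if ra == 2 and rb >= 3:
        return "batched gemm case 2: matrix-batched matrix"
    if ra >= 3 and rb >= 3:
        return "batched matrix-batched matrix"
    return "unsupported"

print(classify_matmul((8, 4, 16), (16,)))       # batched gemv case 1
print(classify_matmul((2, 4, 16), (2, 16, 8)))  # batched matrix-batched matrix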
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py
index 12874810b13e252c0597e2adf124ab7875e992a3..b1bb1d872c3549e8804976b03397b385963ecfc8 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py
@@ -3,7 +3,7 @@ from typing import List, Tuple
import torch
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
index b872fdc8bdcd19717e7b81d436fffd860ec88519..99aaa752d0a1b05b543a03eccfa19d5d15d096e7 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
@@ -1,22 +1,14 @@
-from typing import Callable, Dict, List, Tuple, Union
+from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- MemoryCost,
- OperationData,
- OperationDataType,
- ShardingStrategy,
- StrategiesVector,
- TrainCycleItem,
-)
-from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..registry import meta_register
-__all__ = ['batchnormnd_meta_info', 'layernorm_meta_info']
+__all__ = ["batchnormnd_meta_info", "layernorm_meta_info"]
@meta_register.register(torch.nn.BatchNorm1d)
@@ -65,7 +57,15 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
# saved inv std and some other args indicating the status of the module
# the bwd outputs are input grad, weight grad and bias grad
bwd_in_args = [
- output_tensor, output_tensor, weight_tensor, mean_tensor, var_tensor, mean_tensor, var_tensor, 1e-5, num_batch
+ output_tensor,
+ output_tensor,
+ weight_tensor,
+ mean_tensor,
+ var_tensor,
+ mean_tensor,
+ var_tensor,
+ 1e-5,
+ num_batch,
]
bwd_out_args = [input_tensor, weight_tensor, bias_tensor]
@@ -77,29 +77,34 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
# calculate memory cost
# the fwd activation cost is output plus saved mean and saved inv std
# NOTE: currently the SPMD solver always assumes that a new tensor is created in the forward pass
- fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
- [input_tensor, output_tensor, mean_tensor, var_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
- temp=0,
- buffer=compute_size_in_bytes([mean_tensor, var_tensor]))
+ fwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, output_tensor, mean_tensor, var_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+ temp=0,
+ buffer=compute_size_in_bytes([mean_tensor, var_tensor]),
+ )
# the bwd memory cost is quite tricky here: BatchNorm frees the saved mean
# and saved inv std during the backward phase
- bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
- temp=compute_size_in_bytes([mean_tensor, var_tensor]),
- buffer=compute_size_in_bytes([mean_tensor, var_tensor]))
+ bwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+ temp=compute_size_in_bytes([mean_tensor, var_tensor]),
+ buffer=compute_size_in_bytes([mean_tensor, var_tensor]),
+ )
# total cost is the sum of forward and backward cost
- total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
- parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
+ total_cost = MemoryCost(
+ activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+ parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+ )
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
- fwd_in = [torch.zeros_like(input_tensor, device='meta')]
- fwd_buffer = [torch.zeros_like(mean_tensor, device='meta'), torch.zeros_like(var_tensor, device='meta')]
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_in = [torch.zeros_like(input_tensor, device="meta")]
+ fwd_buffer = [torch.zeros_like(mean_tensor, device="meta"), torch.zeros_like(var_tensor, device="meta")]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
@@ -116,8 +121,8 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
weight_tensor = next(filter(lambda x: x.name == "weight", args)).data
bias_tensor = next(filter(lambda x: x.name == "bias", args)).data
- running_mean = torch.rand(input_tensor.shape[0], 1, device='meta')
- running_var = torch.rand(input_tensor.shape[0], 1, device='meta')
+ running_mean = torch.rand(input_tensor.shape[0], 1, device="meta")
+ running_var = torch.rand(input_tensor.shape[0], 1, device="meta")
# construct args
fwd_in_args = [input_tensor, [input_tensor.shape[0]], weight_tensor]
@@ -132,27 +137,32 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# memory cost
# NOTE: currently the SPMD solver always assumes that a new tensor is created in the forward pass
- fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
- [input_tensor, output_tensor, weight_tensor, bias_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
- temp=0,
- buffer=compute_size_in_bytes([running_mean, running_var]))
-
- bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
- parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
- temp=compute_size_in_bytes([running_mean, running_var]),
- buffer=compute_size_in_bytes([running_mean, running_var]))
-
- total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
- parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
- temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
- buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
+ fwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, output_tensor, weight_tensor, bias_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+ temp=0,
+ buffer=compute_size_in_bytes([running_mean, running_var]),
+ )
+
+ bwd_memory_cost = MemoryCost(
+ activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
+ parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+ temp=compute_size_in_bytes([running_mean, running_var]),
+ buffer=compute_size_in_bytes([running_mean, running_var]),
+ )
+
+ total_cost = MemoryCost(
+ activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+ parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+ temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
+ buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer,
+ )
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
- fwd_in = [torch.zeros_like(input_tensor, device='meta')]
- fwd_buffer = [torch.zeros_like(running_mean, device='meta'), torch.zeros_like(running_var, device='meta')]
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_in = [torch.zeros_like(input_tensor, device="meta")]
+ fwd_buffer = [torch.zeros_like(running_mean, device="meta"), torch.zeros_like(running_var, device="meta")]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
index d785dfcca9bacb46e129adda5f83486090975859..21aa524bed084f4684ac5414e7b8bf19c78048c7 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
@@ -63,7 +63,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = []
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
@@ -117,8 +117,10 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, index_matrix]))
# temp memory for backward is the index matrix to be discarded
- bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix),
- temp=compute_size_in_bytes(index_matrix))
+ bwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix),
+ temp=compute_size_in_bytes(index_matrix),
+ )
# total cost
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp)
@@ -126,8 +128,8 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store fwd_in, fwd_buffer, fwd_out
- fwd_in = [torch.zeros_like(input_tensor, device='meta')]
- fwd_buffer = [torch.zeros_like(index_matrix, device='meta')]
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_in = [torch.zeros_like(input_tensor, device="meta")]
+ fwd_buffer = [torch.zeros_like(index_matrix, device="meta")]
+ fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py b/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py
index 97fe3c6196f591af7bbfcbdcf59ff3afd114175f..9a2df1bd7c870fbff108c41e1a943c20acb43c63 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py
@@ -2,7 +2,6 @@ from typing import Callable, List, Tuple
import torch
-from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
@@ -37,15 +36,19 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f
# NOTE: currently the SPMD solver always assumes that a new tensor is created in the forward pass
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * 2, parameter=0, temp=0, buffer=0)
- bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor,
- parameter=0,
- temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor,
- buffer=0)
+ bwd_mem_cost = MemoryCost(
+ activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor,
+ parameter=0,
+ temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor,
+ buffer=0,
+ )
- total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
- parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
- temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
- buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
+ total_mem_cost = MemoryCost(
+ activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
+ parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
+ temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
+ buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
@@ -66,14 +69,24 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f
# register torch.Tensor related metainfo
# (0, 0)
-meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze,
- torch.arange])(tensor_related_metainfo(0, 0))
+meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze, torch.arange])(
+ tensor_related_metainfo(0, 0)
+)
# (1, 0)
-meta_register.register([
- torch.Tensor.flatten, torch.flatten, torch.Tensor.transpose, torch.transpose, torch.Tensor.permute, torch.permute,
- torch.Tensor.split, torch.split, torch.Tensor.view
-])(tensor_related_metainfo(1, 0))
+meta_register.register(
+ [
+ torch.Tensor.flatten,
+ torch.flatten,
+ torch.Tensor.transpose,
+ torch.transpose,
+ torch.Tensor.permute,
+ torch.permute,
+ torch.Tensor.split,
+ torch.split,
+ torch.Tensor.view,
+ ]
+)(tensor_related_metainfo(1, 0))
# (1, 1)
meta_register.register([torch.Tensor.type, torch.Tensor.contiguous])(tensor_related_metainfo(1, 1))
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/where.py b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py
index 5cba1b5b6e2b16521ed2a0df2fbab98b19492c53..107851b80d7c1a45744d0817dc36f18928eb1444 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/where.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py
@@ -4,7 +4,7 @@ import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register
@@ -39,16 +39,21 @@ def where_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Li
# gradient matrix for input x and input y, remove the temp memory and condition tensor generated in the forward phase
# NOTE: currently the SPMD solver always assumes that a new input tensor is created in the forward pass
fwd_mem_cost = MemoryCost(activation=activation_size([condition_tensor, x_tensor, y_tensor, output_tensor]))
- bwd_mem_cost = MemoryCost(activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
- parameter=0,
- temp=activation_size([output_tensor]) * 3 + activation_size([condition_tensor]) -
- activation_size([x_tensor, y_tensor]),
- buffer=0)
-
- total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
- parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
- temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
- buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
+ bwd_mem_cost = MemoryCost(
+ activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
+ parameter=0,
+ temp=activation_size([output_tensor]) * 3
+ + activation_size([condition_tensor])
+ - activation_size([x_tensor, y_tensor]),
+ buffer=0,
+ )
+
+ total_mem_cost = MemoryCost(
+ activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
+ parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
+ temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
+ buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
diff --git a/colossalai/auto_parallel/meta_profiler/registry.py b/colossalai/auto_parallel/meta_profiler/registry.py
index 46350c4dd406691c344eb92a933636d6b029b8bd..c29086f7f9d16fdcae11d56f10a1f2ac63a82e5b 100644
--- a/colossalai/auto_parallel/meta_profiler/registry.py
+++ b/colossalai/auto_parallel/meta_profiler/registry.py
@@ -1,14 +1,12 @@
-__all__ = ['Registry']
+__all__ = ["Registry"]
class Registry:
-
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
-
def wrapper(func):
if isinstance(source, (list, tuple)):
# support register a list of items for this func
@@ -21,7 +19,7 @@ class Registry:
return wrapper
def get(self, source):
- assert source in self.store, f'{source} not found in the {self.name} registry'
+ assert source in self.store, f"{source} not found in the {self.name} registry"
target = self.store[source]
return target
@@ -29,4 +27,4 @@ class Registry:
return source in self.store
-meta_register = Registry('meta')
+meta_register = Registry("meta")
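
A short usage sketch of the Registry pattern above, with hypothetical string keys (the real meta_register maps torch callables to meta functions). The class is restated compactly so the sketch runs standalone; it mirrors the code above but is not the library import.

class Registry:
    def __init__(self, name):
        self.name = name
        self.store = {}

    def register(self, source):
        def wrapper(func):
            if isinstance(source, (list, tuple)):
                for element in source:  # one func registered under many keys
                    self.store[element] = func
            else:
                self.store[source] = func
            return func
        return wrapper

    def get(self, source):
        assert source in self.store, f"{source} not found in the {self.name} registry"
        return self.store[source]

demo_register = Registry("demo")

@demo_register.register(["relu", "gelu"])  # register under two keys at once
def activation_meta():
    return "activation meta info"

assert demo_register.get("relu") is demo_register.get("gelu")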
diff --git a/colossalai/auto_parallel/meta_profiler/shard_metainfo.py b/colossalai/auto_parallel/meta_profiler/shard_metainfo.py
index 0eee908b48b73d9d1cd5e0e35fbf50b8d844e3a6..109b8a220ac760174883a38726efe87d8a6cfb52 100644
--- a/colossalai/auto_parallel/meta_profiler/shard_metainfo.py
+++ b/colossalai/auto_parallel/meta_profiler/shard_metainfo.py
@@ -2,20 +2,13 @@ from typing import Callable, List
import torch
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- MemoryCost,
- OperationData,
- OperationDataType,
- ShardingStrategy,
- StrategiesVector,
- TrainCycleItem,
-)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, ShardingStrategy, TrainCycleItem
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
from .registry import meta_register
-__all__ = ['ShardMetaInfo']
+__all__ = ["ShardMetaInfo"]
class ShardMetaInfo:
@@ -76,10 +69,12 @@ class ShardMetaInfo:
"""
if isinstance(sharding_spec, ShardingSpec):
- op_data = OperationData(name=operation_data.name,
- data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
- type=operation_data.type,
- logical_shape=operation_data.logical_shape)
+ op_data = OperationData(
+ name=operation_data.name,
+ data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
+ type=operation_data.type,
+ logical_shape=operation_data.logical_shape,
+ )
elif isinstance(sharding_spec, (list, tuple)):
data = operation_data.data
assert isinstance(data, (list, tuple)), f"Data Should be list or tuple, but got {type(data)}."
@@ -97,8 +92,9 @@ class ShardMetaInfo:
"""
Compute meta info based on sharding strategy and the given target function.
"""
- assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \
- f"Meta info for {self._target} is not registered."
+ assert meta_register.has(self._target.__class__) or meta_register.has(
+ self._target
+ ), f"Meta info for {self._target} is not registered."
if meta_register.has(self._target.__class__):
# module
meta_func = meta_register.get(self._target.__class__)
@@ -117,11 +113,11 @@ class ShardMetaInfo:
# construct kwargs
if self.target in INPLACE_MODULE:
- kwargs = {'inplace': self.target.inplace}
+ kwargs = {"inplace": self.target.inplace}
elif self.target in INPLACE_OPS:
- kwargs = {'inplace': True}
+ kwargs = {"inplace": True}
else:
- kwargs = {'inplace': False}
+ kwargs = {"inplace": False}
# compute metainfo with meta_func
self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs)
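
ShardMetaInfo decides the inplace kwarg from whether the target is an inplace module (whose instances carry their own inplace flag) or an always-inplace op. A hypothetical standalone sketch of that dispatch; note it checks type(target) for clarity, which may differ from the library's exact membership test.

import torch.nn as nn

INPLACE_MODULE = [nn.ReLU]  # module classes with an `inplace` attribute
INPLACE_OPS = []            # placeholder for always-inplace functionals

def select_kwargs(target) -> dict:
    if type(target) in INPLACE_MODULE:
        return {"inplace": target.inplace}  # honour the module's own flag
    if target in INPLACE_OPS:
        return {"inplace": True}
    return {"inplace": False}

print(select_kwargs(nn.ReLU(inplace=True)))  # {'inplace': True}
print(select_kwargs(nn.Linear(4, 4)))        # {'inplace': False}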
diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py
index a79e5006e7d2ac264f00b5a597e7b869f6f580eb..601bf2926d991a4f8c844f69ffbc6e09f551e6f3 100644
--- a/colossalai/auto_parallel/offload/amp_optimizer.py
+++ b/colossalai/auto_parallel/offload/amp_optimizer.py
@@ -1,24 +1,25 @@
-from typing import Dict, Tuple
from enum import Enum
+from typing import Dict, Tuple
+
import torch
from torch.optim import Optimizer
-from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
+from colossalai.interface import OptimizerWrapper
+from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from .base_offload_module import BaseOffloadModule
-from .region_manager import RegionManager
from .region import Region
+from .region_manager import RegionManager
class OptimState(Enum):
SCALED = 0
UNSCALED = 1
-class AMPOptimizer(ColossalaiOptimizer):
+class AMPOptimizer(OptimizerWrapper):
"""
A wrapper for Optimizer.
Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
@@ -36,19 +37,20 @@ class AMPOptimizer(ColossalaiOptimizer):
norm_type (float, optional): norm_type used for `clip_grad_norm`.
"""
- def __init__(self,
- optimizer: Optimizer,
- module: BaseOffloadModule,
- initial_scale: float = 2**16,
- growth_factor: float = 2,
- backoff_factor: float = 0.5,
- growth_interval: int = 1000,
- hysteresis: int = 2,
- min_scale: float = 1,
- max_scale: float = 2**32,
- clipping_norm: float = 0.0,
- norm_type: float = 2.0):
-
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ module: BaseOffloadModule,
+ initial_scale: float = 2**16,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ min_scale: float = 1,
+ max_scale: float = 2**32,
+ clipping_norm: float = 0.0,
+ norm_type: float = 2.0,
+ ):
super().__init__(optimizer)
self.module = module
@@ -68,19 +70,21 @@ class AMPOptimizer(ColossalaiOptimizer):
self.__init__optimizer()
# Grad scaler
- self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
- min_scale=min_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval,
- hysteresis=hysteresis,
- max_scale=max_scale)
+ self.grad_scaler = DynamicGradScaler(
+ initial_scale=initial_scale,
+ min_scale=min_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ max_scale=max_scale,
+ )
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
self._logger = get_dist_logger()
def _set_grad_ptr(self):
for group in self.param_groups:
- for fake_param in group['params']:
+ for fake_param in group["params"]:
region = self.param_to_region[fake_param]
begin, end = self.param_to_range[fake_param]
@@ -91,7 +95,7 @@ class AMPOptimizer(ColossalaiOptimizer):
def _update_fp16_params(self):
none_tensor = torch.empty([0])
for group in self.param_groups:
- for fake_param in group['params']:
+ for fake_param in group["params"]:
assert fake_param.grad is None
fake_param.data = none_tensor
self.param_to_region[fake_param].cpu_grad = None
@@ -129,10 +133,10 @@ class AMPOptimizer(ColossalaiOptimizer):
found_inf = self._check_overflow()
if found_inf:
- self.optim_state = OptimState.UNSCALED # no need to unscale grad
- self.grad_scaler.update(found_inf) # update gradient scaler
- self._logger.info(f'Found overflow. Skip step')
- self.zero_grad() # reset all gradients
+ self.optim_state = OptimState.UNSCALED # no need to unscale grad
+ self.grad_scaler.update(found_inf) # update gradient scaler
+ self._logger.info("Found overflow. Skipping step")
+ self.zero_grad() # reset all gradients
self._update_fp16_params()
return
@@ -155,11 +159,10 @@ class AMPOptimizer(ColossalaiOptimizer):
self.module.backward(loss)
def __init__optimizer(self):
-
for group in self.optim.param_groups:
fake_params_list = list()
- for param in group['params']:
+ for param in group["params"]:
region = self.region_manager.get_region(param)
fake_param = torch.nn.Parameter(torch.empty([0]))
self.param_to_range[fake_param] = region.param_to_range[param]
@@ -170,8 +173,8 @@ class AMPOptimizer(ColossalaiOptimizer):
if param in self.optim.state:
self.optim.state[fake_param] = self.optim.state.pop(param)
- group['params'] = fake_params_list
+ group["params"] = fake_params_list
# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
- self.optim.load_state_dict(self.optim.state_dict())
\ No newline at end of file
+ self.optim.load_state_dict(self.optim.state_dict())
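
The overflow path above skips the optimizer step and lets the grad scaler back off. A toy sketch of the dynamic loss-scale policy implied by the constructor defaults (back off on overflow, grow after a streak of clean steps); this is an assumption about DynamicGradScaler's semantics, not its API.

class ToyScaler:
    def __init__(self, initial_scale=2**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=1000,
                 min_scale=1.0, max_scale=2**32):
        self.scale = float(initial_scale)
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.min_scale, self.max_scale = min_scale, max_scale
        self._good_steps = 0

    def update(self, found_overflow: bool) -> None:
        if found_overflow:
            # the optimizer step is skipped; reduce the scale
            self.scale = max(self.scale * self.backoff_factor, self.min_scale)
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps % self.growth_interval == 0:
                self.scale = min(self.scale * self.growth_factor, self.max_scale)

scaler = ToyScaler()
scaler.update(found_overflow=True)
print(scaler.scale)  # 32768.0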
diff --git a/colossalai/auto_parallel/offload/base_offload_module.py b/colossalai/auto_parallel/offload/base_offload_module.py
index d0c328e134ff5696ea2f8c17d5fc3468e4c891a2..f5e8e31f5e9798cb777c8ad8d9e5ad298d0fcb8b 100644
--- a/colossalai/auto_parallel/offload/base_offload_module.py
+++ b/colossalai/auto_parallel/offload/base_offload_module.py
@@ -4,7 +4,7 @@ from typing import Optional, Set
import torch
import torch.nn as nn
-from colossalai.nn.parallel.data_parallel import _cast_float
+from colossalai.utils import _cast_float
from colossalai.zero.legacy.gemini.tensor_utils import free_storage
from .region_manager import RegionManager
@@ -22,7 +22,6 @@ class BaseOffloadModule:
"""
def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True):
-
self.model = model
self.region_manager = region_manager
self.grad_hook_list = []
@@ -91,17 +90,16 @@ class BaseOffloadModule:
def parameters(self, recurse: bool = True):
return self.model.parameters(recurse)
- def named_parameters(self, prefix: str = '', recurse: bool = True):
+ def named_parameters(self, prefix: str = "", recurse: bool = True):
return self.model.named_parameters(prefix, recurse)
- def named_buffers(self, prefix: str = '', recurse: bool = True):
+ def named_buffers(self, prefix: str = "", recurse: bool = True):
return self.model.named_buffers(prefix, recurse)
def named_children(self):
return self.model.named_children()
- def named_modules(self,
- memo: Optional[Set[torch.nn.Module]] = None,
- prefix: str = '',
- remove_duplicate: bool = True):
+ def named_modules(
+ self, memo: Optional[Set[torch.nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
+ ):
return self.model.named_modules(memo, prefix, remove_duplicate)
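
The import fix above takes _cast_float from colossalai.utils. Its contract — cast only the floating-point leaves of a nested input structure — can be sketched like this (cast_float is a hypothetical stand-in, not the library function):

import torch

def cast_float(x, dtype=torch.half):
    """Recursively cast floating-point tensors; leave ints/bools untouched."""
    if isinstance(x, torch.Tensor) and x.is_floating_point():
        return x.to(dtype)
    if isinstance(x, (list, tuple)):
        return type(x)(cast_float(t, dtype) for t in x)
    if isinstance(x, dict):
        return {k: cast_float(v, dtype) for k, v in x.items()}
    return x

batch = {"ids": torch.arange(4), "emb": torch.randn(4, 8)}
half_batch = cast_float(batch)  # only "emb" becomes fp16
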
diff --git a/colossalai/auto_parallel/offload/mem_optimize.py b/colossalai/auto_parallel/offload/mem_optimize.py
index d56166dea982288bdea160e1347c8ca3f67ed297..74501c18451845c1f62c81ccbb757766928005b5 100644
--- a/colossalai/auto_parallel/offload/mem_optimize.py
+++ b/colossalai/auto_parallel/offload/mem_optimize.py
@@ -14,11 +14,9 @@ from .runtime import runtime_asyn_offload_apply_pass, runtime_syn_offload_apply_
from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem
-def memory_optimize(model: torch.nn.Module,
- inps: Dict[str, torch.Tensor],
- memory_budget: float = -1.0,
- solver_name: str = 'asyn'):
-
+def memory_optimize(
+ model: torch.nn.Module, inps: Dict[str, torch.Tensor], memory_budget: float = -1.0, solver_name: str = "asyn"
+):
model = model.cpu().half()
tracer = ColoTracer()
assert is_compatible_with_meta()
@@ -40,13 +38,13 @@ def memory_optimize(model: torch.nn.Module,
f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}"
)
- if solver_name == 'syn':
+ if solver_name == "syn":
gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list)
- elif solver_name == 'asyn':
+ elif solver_name == "asyn":
gm = runtime_asyn_offload_apply_pass(gm, region_manager.region_list)
else:
raise TypeError(f"Unknown solver name {solver_name}!")
gm.recompile()
- optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn')
+ optimized_model = BaseOffloadModule(gm, region_manager, solver_name == "syn")
return optimized_model
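
For orientation, a hedged usage sketch of this entry point; the model and inputs below are placeholders, and the call itself is commented out because it needs a CUDA device and the solver stack:

import torch

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU())
inps = {"input": torch.randn(4, 1024)}

# from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
# optimized = memory_optimize(model, inps, memory_budget=2 * 1024**3, solver_name="asyn")
# out = optimized(**inps)
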
diff --git a/colossalai/auto_parallel/offload/region.py b/colossalai/auto_parallel/offload/region.py
index 819ffbd96eb19098f168519ca6e3e0036fa3a638..ea92c714ce31d185b2da86963cada7e1bb2a4ac4 100644
--- a/colossalai/auto_parallel/offload/region.py
+++ b/colossalai/auto_parallel/offload/region.py
@@ -55,13 +55,13 @@ class Region:
Map the parameters in the region to a contiguous memory space.
"""
- self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda')
+ self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device="cuda")
offset = 0
for param in self.fp16_params:
param.data = param.data.cuda()
p_num = param.data.numel()
- self.fp16_data[offset:offset + p_num].copy_(param.data.flatten())
- param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape)
+ self.fp16_data[offset : offset + p_num].copy_(param.data.flatten())
+ param.data = self.fp16_data[offset : offset + p_num].view(param.data.shape)
self.param_to_range[param] = (offset, offset + p_num)
offset += p_num
@@ -83,7 +83,7 @@ class Region:
self.temp_fp32_data.record_stream(torch.cuda.current_stream())
if not self.in_mem_pool_flag:
alloc_storage(self.fp16_data)
- self.fp16_data[:self.param_num].copy_(self.temp_fp32_data)
+ self.fp16_data[: self.param_num].copy_(self.temp_fp32_data)
self.fp16_data.record_stream(torch.cuda.current_stream())
self.__update_params_ptr()
@@ -94,7 +94,7 @@ class Region:
"""
self.cpu_grad = torch.empty(self.param_num, dtype=torch.half, pin_memory=True)
- self.cpu_grad.copy_(self.fp16_data[:self.param_num], non_blocking=True)
+ self.cpu_grad.copy_(self.fp16_data[: self.param_num], non_blocking=True)
self.fp16_data.record_stream(torch.cuda.current_stream())
if not self.in_mem_pool_flag:
self.free_cuda_data()
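
The Region code above packs the fp16 params into a single contiguous buffer and re-views each param into it. A CPU-only sketch of that packing (flatten_params is hypothetical):

import torch

def flatten_params(params):
    """Pack tensors into one contiguous fp16 buffer; each param becomes a view."""
    flat = torch.zeros(sum(p.numel() for p in params), dtype=torch.half)
    offset, ranges = 0, {}
    for p in params:
        n = p.numel()
        flat[offset : offset + n].copy_(p.data.flatten())
        p.data = flat[offset : offset + n].view(p.shape)  # alias into the buffer
        ranges[p] = (offset, offset + n)
        offset += n
    return flat, ranges

layer = torch.nn.Linear(4, 4)
flat, ranges = flatten_params(list(layer.parameters()))
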
diff --git a/colossalai/auto_parallel/offload/region_manager.py b/colossalai/auto_parallel/offload/region_manager.py
index 30bfaf00d4939afadc3c2aaaec3f27ce70db4a20..146dd267967d8b26c16d950245381c699e86c5ca 100644
--- a/colossalai/auto_parallel/offload/region_manager.py
+++ b/colossalai/auto_parallel/offload/region_manager.py
@@ -1,10 +1,11 @@
-from typing import List, Any, Dict, Tuple
+from typing import Any, Dict, List, Tuple
+
import torch
from torch.fx import Graph, Node
+from .region import Region
from .solver import SolverFactory
from .training_simulator import TrainingSimulator
-from .region import Region
from .util import NodeInfo
@@ -19,14 +20,9 @@ class RegionManager:
cnode (List[str], optional): Common node List, should be the subset of input.
"""
- def __init__(self,
- graph: Graph,
- solver_name: str = 'asyn',
- memory_budget: float = -1.0,
- cnode: List[str] = None):
-
+ def __init__(self, graph: Graph, solver_name: str = "asyn", memory_budget: float = -1.0, cnode: List[str] = None):
self.graph = graph
- assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
+        assert graph.owning_module is not None, "The given graph is not associated with an owning_module"
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.cnode = cnode
@@ -39,7 +35,7 @@ class RegionManager:
self.memory_budget = memory_budget
self.solver_name = solver_name
- self.require_pool: bool = solver_name == 'asyn'
+ self.require_pool: bool = solver_name == "asyn"
self.reg_to_block: Dict[int, int] = dict()
@@ -61,22 +57,19 @@ class RegionManager:
self._post_process(solver.best_ts)
def _pre_process(self):
-
init_region_list = self._linearize_graph()
if len(self.shared_region_pairs) > 1:
- raise NotImplementedError(
- 'The current version only considers at most one pair of parameter sharing.')
+            raise NotImplementedError("The current version only considers at most one pair of shared parameters.")
elif len(self.shared_region_pairs) == 1:
shared_regs = self.shared_region_pairs[0]
- assert shared_regs[0].shared_rid == shared_regs[1].r_id \
- and shared_regs[1].shared_rid == shared_regs[0].r_id
+ assert shared_regs[0].shared_rid == shared_regs[1].r_id and shared_regs[1].shared_rid == shared_regs[0].r_id
fst_id = shared_regs[0].r_id
lst_id = shared_regs[1].r_id
- regs_left_out = init_region_list[:fst_id + 1]
+ regs_left_out = init_region_list[: fst_id + 1]
regs_right_out = init_region_list[lst_id:]
- hold_regs = init_region_list[fst_id + 1:lst_id]
+ hold_regs = init_region_list[fst_id + 1 : lst_id]
else:
regs_left_out = []
regs_right_out = []
@@ -122,12 +115,9 @@ class RegionManager:
it may not find a suitable region placement strategy for the given execution flow.
"""
- reg_flow = torch.cat(
- [ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
- mem_block_num = torch.max(
- torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
- coexist_matrix = torch.logical_or(
- ts.fwd_reg_flow, ts.bwd_reg_flow)
+ reg_flow = torch.cat([ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
+ mem_block_num = torch.max(torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
+ coexist_matrix = torch.logical_or(ts.fwd_reg_flow, ts.bwd_reg_flow)
block_to_regs = {}
for block_idx in range(mem_block_num):
@@ -135,8 +125,7 @@ class RegionManager:
for reg in self.region_list:
if reg.r_id in self.rid_in_pool:
cur_reg_appears = coexist_matrix[:, reg.r_id]
- cur_reg_coexists = torch.sum(
- coexist_matrix[cur_reg_appears], dim=0).bool()
+ cur_reg_coexists = torch.sum(coexist_matrix[cur_reg_appears], dim=0).bool()
for block_idx in range(mem_block_num):
if not any(cur_reg_coexists[block_to_regs[block_idx]]):
block_to_regs[block_idx].append(reg.r_id)
@@ -145,9 +134,12 @@ class RegionManager:
if reg.r_id not in self.reg_to_block:
raise NotImplementedError(
- f'can not find a block from the memory pool to store parameters of the region')
- self.memory_pool = torch.chunk(torch.zeros(int(
- mem_block_num * self.mem_block_size / 2), dtype=torch.half, device='cuda'), chunks=int(mem_block_num))
+                    "cannot find a block in the memory pool to store the parameters of the region"
+ )
+ self.memory_pool = torch.chunk(
+ torch.zeros(int(mem_block_num * self.mem_block_size / 2), dtype=torch.half, device="cuda"),
+ chunks=int(mem_block_num),
+ )
def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]:
"""
@@ -178,10 +170,9 @@ class RegionManager:
return region_list
- def _search_block_size(self,
- region_list: List[Region],
- search_interval_byte: int = 1024,
- search_range_byte: int = 128 * 1024 ** 2) -> int:
+ def _search_block_size(
+ self, region_list: List[Region], search_interval_byte: int = 1024, search_range_byte: int = 128 * 1024**2
+ ) -> int:
"""
Search for a suitable memory block size.
@@ -208,11 +199,10 @@ class RegionManager:
acc_wasted += blk_size - left
return acc_wasted
- param_size_list = [
- region.param_size for region in region_list if region.r_id == region.shared_rid]
+ param_size_list = [region.param_size for region in region_list if region.r_id == region.shared_rid]
start_size = max(param_size_list)
- min_mem_waste = float('+inf')
+ min_mem_waste = float("+inf")
best_block_size = start_size
for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
@@ -229,7 +219,7 @@ class RegionManager:
Initialize region data, which maps the parameters in the region to a contiguous memory space.
"""
- self.temp_fp32_data = torch.zeros(self.max_param_num, device='cuda', dtype=torch.float32)
+ self.temp_fp32_data = torch.zeros(self.max_param_num, device="cuda", dtype=torch.float32)
for region in self.region_list:
pre_alloc_tensor = None
@@ -244,8 +234,7 @@ class RegionManager:
region.fp16_data = shared_region.fp16_data
region.fp32_data = shared_region.fp32_data
region.param_to_range = shared_region.param_to_range
- region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach(
- )
+ region.temp_fp32_data = self.temp_fp32_data[: region.param_num].detach()
torch.cuda.empty_cache()
@@ -259,13 +248,14 @@ class RegionManager:
former_reg, latter_reg = self.shared_region_pairs[0]
assert latter_reg.param_num >= former_reg.param_num
embedding_node = former_reg.nodes[-1]
- assert embedding_node.op == 'call_module' and isinstance(
- self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding)
+ assert embedding_node.op == "call_module" and isinstance(
+ self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding
+ )
if latter_reg.param_num > former_reg.param_num:
for idx, n in enumerate(latter_reg.nodes):
- if (n.op == 'call_module' and isinstance(self.root_module.get_submodule(n.target),
- torch.nn.Linear)) or \
- (n.op == 'call_function' and n.target is torch.nn.functional.linear):
+ if (
+ n.op == "call_module" and isinstance(self.root_module.get_submodule(n.target), torch.nn.Linear)
+ ) or (n.op == "call_function" and n.target is torch.nn.functional.linear):
cut_node_idx = idx + 1
break
assert len(latter_reg.fp16_params) == 2
@@ -273,7 +263,7 @@ class RegionManager:
for p in new_reg.fp16_params:
self.param_region_map[p] = new_reg
self.region_list.insert(new_reg.r_id, new_reg)
- for reg in self.region_list[new_reg.r_id + 1:]:
+ for reg in self.region_list[new_reg.r_id + 1 :]:
reg.r_id += 1
latter_reg.shared_rid = former_reg.r_id
former_reg.shared_rid = latter_reg.r_id
@@ -344,8 +334,8 @@ class RegionManager:
target = n.target
submod = self.root_module.get_submodule(target)
if (
- len(list(submod.named_parameters(recurse=False))) != 0
- or len(list(submod.named_buffers(recurse=False))) != 0
+ len(list(submod.named_parameters(recurse=False))) != 0
+ or len(list(submod.named_buffers(recurse=False))) != 0
):
label = True
@@ -362,14 +352,12 @@ class RegionManager:
"""
def _is_inplace(n: Node):
- """Get the inplace argument from ``torch.fx.Node``
- """
+ """Get the inplace argument from ``torch.fx.Node``"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
elif n.op == "call_module":
- inplace = getattr(n.graph.owning_module.get_submodule(
- n.target), "inplace", False)
+ inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
return inplace
label = False
@@ -378,28 +366,30 @@ class RegionManager:
target = n.target
submod = self.root_module.get_submodule(target)
if (
- len(list(submod.named_parameters(recurse=False))) != 0
- or len(list(submod.named_buffers(recurse=False))) != 0
+ len(list(submod.named_parameters(recurse=False))) != 0
+ or len(list(submod.named_buffers(recurse=False))) != 0
):
label = True
elif n.op == "call_function":
label = any(map(lambda x: x.name in self.only_param_ops, n.all_input_nodes)) and any(
- map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes))
+ map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes)
+ )
return label and not sum([v for _, v in param_op_deps.items()]) and not any(map(_is_inplace, n.users))
def _exception_node_handling():
# TODO meta info prop bug
- if n.name.__contains__("transpose") and n.meta['fwd_out'][0].dim() <= 2:
- n.meta['fwd_out'] = []
+ if n.name.__contains__("transpose") and n.meta["fwd_out"][0].dim() <= 2:
+ n.meta["fwd_out"] = []
# make sure that every item in cnode is valid
if self.cnode:
for name in self.cnode:
try:
- assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
- f"Common node {name} is not an input of the model."
+ assert (
+ next(node for node in self.graph.nodes if node.name == name).op == "placeholder"
+ ), f"Common node {name} is not an input of the model."
except StopIteration:
raise ValueError(f"Common node name {name} not in graph.")
else:
@@ -428,8 +418,8 @@ class RegionManager:
ns = []
border_n_idx = region.nodes.index(act_n)
if border_n_idx < len(region.nodes):
- ns = region.nodes[border_n_idx + 1:]
- region.nodes = region.nodes[:border_n_idx + 1]
+ ns = region.nodes[border_n_idx + 1 :]
+ region.nodes = region.nodes[: border_n_idx + 1]
region_list.append(region)
region_id += 1
region = Region(r_id=region_id)
@@ -448,19 +438,21 @@ class RegionManager:
region = Region(r_id=region_id)
# propagate common node attr if possible
- if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
- ]) or _is_cop(n.target):
+ if len(n.all_input_nodes) == len(
+ [node for node in n.all_input_nodes if node.name in self.cnode]
+ ) or _is_cop(n.target):
self.cnode.append(n.name)
else:
- deps[n] = len(
- [user for user in n.users if user.op != "output"])
+ deps[n] = len([user for user in n.users if user.op != "output"])
# propagate param node attr if possible
- if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.only_param_ops
- ]) or n.op == "get_attr":
+ if (
+ len(n.all_input_nodes)
+ == len([node for node in n.all_input_nodes if node.name in self.only_param_ops])
+ or n.op == "get_attr"
+ ):
self.only_param_ops.append(n.name)
- param_op_deps[n] = len(
- [user for user in n.users if user.op != "output"])
+ param_op_deps[n] = len([user for user in n.users if user.op != "output"])
# record last activation node
if _is_act(n._meta_data):
@@ -472,19 +464,16 @@ class RegionManager:
return region_list
def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region):
-
cur_n.node_info = NodeInfo(node_id)
- if cur_n.op == 'call_module':
+ if cur_n.op == "call_module":
target = cur_n.target
submod = self.root_module.get_submodule(target)
for p in list(submod.parameters(recurse=False)):
-
if p in self.param_region_map:
cur_reg.shared_rid = self.param_region_map[p].r_id
self.param_region_map[p].shared_rid = cur_reg.r_id
- self.shared_region_pairs.append(
- (self.param_region_map[p], cur_reg))
+ self.shared_region_pairs.append((self.param_region_map[p], cur_reg))
else:
self.param_region_map[p] = cur_reg
@@ -499,12 +488,10 @@ class RegionManager:
attr_itr = getattr(attr_itr, atom)
if isinstance(attr_itr, torch.nn.Parameter):
-
if attr_itr in self.param_region_map:
cur_reg.shared_rid = self.param_region_map[attr_itr].r_id
self.param_region_map[attr_itr].shared_rid = cur_reg.r_id
- self.shared_region_pairs.append(
- (self.param_region_map[attr_itr], cur_reg))
+ self.shared_region_pairs.append((self.param_region_map[attr_itr], cur_reg))
else:
self.param_region_map[attr_itr] = cur_reg
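
The _post_process step above assigns pool regions to a fixed number of memory blocks using the forward/backward coexistence matrix. A minimal standalone sketch of that greedy placement (assign_blocks is hypothetical):

import torch

def assign_blocks(coexist: torch.Tensor, pool_ids, num_blocks: int):
    """Greedy placement: regions ever resident together never share a block."""
    block_to_regs = {b: [] for b in range(num_blocks)}
    reg_to_block = {}
    for rid in pool_ids:
        appears = coexist[:, rid]                      # steps where rid is resident
        coexists = coexist[appears].sum(dim=0).bool()  # regions resident alongside it
        for b in range(num_blocks):
            if not any(coexists[block_to_regs[b]]):    # no conflict with occupants
                block_to_regs[b].append(rid)
                reg_to_block[rid] = b
                break
    return reg_to_block

# two time steps, three regions; regions 0 and 1 coexist at step 0
flow = torch.tensor([[1, 1, 0], [0, 0, 1]]).bool()
print(assign_blocks(flow, pool_ids=[0, 1, 2], num_blocks=2))  # {0: 0, 1: 1, 2: 0}
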
diff --git a/colossalai/auto_parallel/offload/runtime.py b/colossalai/auto_parallel/offload/runtime.py
index 764ac608826b2860f3294ec9a5883745dfa10f57..cc790dfb089129a6f2cb101753f3d6431052b8a0 100644
--- a/colossalai/auto_parallel/offload/runtime.py
+++ b/colossalai/auto_parallel/offload/runtime.py
@@ -22,13 +22,13 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, fwd_info, bwd_info):
ctx.bwd_info = bwd_info
- d2h_rid = fwd_info.get('d2h_rid', None)
+ d2h_rid = fwd_info.get("d2h_rid", None)
if d2h_rid is not None:
free_region = GlobalRuntimeInfo().region_list[d2h_rid]
assert isinstance(free_region, Region)
free_region.free_cuda_data()
- h2d_rid = fwd_info.get('h2d_rid', None)
+ h2d_rid = fwd_info.get("h2d_rid", None)
if h2d_rid is not None:
h2d_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(h2d_region, Region)
@@ -38,8 +38,7 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
-
- h2d_rid = ctx.bwd_info.get('h2d_rid', None)
+ h2d_rid = ctx.bwd_info.get("h2d_rid", None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
@@ -64,13 +63,13 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
def forward(ctx, input_, fwd_info, bwd_info):
ctx.bwd_info = bwd_info
- sync_rid = fwd_info.get('sync_rid', None)
+ sync_rid = fwd_info.get("sync_rid", None)
if sync_rid is not None:
prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None)
if prefetch_event:
prefetch_event.wait()
- h2d_rid = fwd_info.get('h2d_rid', None)
+ h2d_rid = fwd_info.get("h2d_rid", None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
@@ -87,8 +86,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
-
- sync_rid = ctx.bwd_info.get('sync_rid', None)
+ sync_rid = ctx.bwd_info.get("sync_rid", None)
if sync_rid is not None:
wait_region = GlobalRuntimeInfo().region_list[sync_rid]
assert isinstance(wait_region, Region)
@@ -98,7 +96,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
else:
wait_region.move_param_to_cuda()
- h2d_rid = ctx.bwd_info.get('h2d_rid', None)
+ h2d_rid = ctx.bwd_info.get("h2d_rid", None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
@@ -114,7 +112,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
- '''
+ """
Convert Upload and Offload operations into runtime actions.
Arguments:
@@ -123,14 +121,14 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
that need to be uploaded, or freed during forward pass.
bwd_info(dict): information dict, which contains region indices
that need to be uploaded during backward pass.
- '''
+ """
with torch._C.DisableTorchFunction():
ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
return ret
def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
- '''
+ """
Convert Prefetch and Offload operations into runtime actions.
Argument:
@@ -139,7 +137,7 @@ def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
that need to be prefetched, waited, or freed during forward pass.
bwd_info(dict): information dict, which contains region indices
that need to be prefetched or waited during backward pass.
- '''
+ """
with torch._C.DisableTorchFunction():
ret = AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
return ret
@@ -176,22 +174,22 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R
# forward upload
fwd_info = {}
if requires_upload_p_in_fwd(region_list[region.shared_rid]):
- fwd_info['h2d_rid'] = region.r_id
+ fwd_info["h2d_rid"] = region.r_id
# forward offload
if r_idx > 0 and region_list[r_idx - 1].need_offload:
- fwd_info['d2h_rid'] = r_idx - 1
+ fwd_info["d2h_rid"] = r_idx - 1
bwd_info = {}
# backward upload
if r_idx > 0 and region_list[r_idx - 1].need_offload:
- bwd_info['h2d_rid'] = region_list[r_idx - 1].r_id
+ bwd_info["h2d_rid"] = region_list[r_idx - 1].r_id
if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node):
- new_node = mod_graph.create_node('call_function',
- convert_fwd_upload_bwd_offload_to_action,
- args=(last_inp_node, fwd_info, bwd_info))
+ new_node = mod_graph.create_node(
+ "call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, bwd_info)
+ )
replace_node_users(last_inp_node, new_node)
last_inp_node = region.nodes[-1]
@@ -210,9 +208,9 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
first_region_with_p = [region for region in region_list if region.param_size][0]
fwd_info = {"h2d_rid": first_region_with_p.r_id}
with mod_graph.inserting_after(last_inp_node):
- upload_apply_node = mod_graph.create_node('call_function',
- convert_fwd_upload_bwd_offload_to_action,
- args=(last_inp_node, fwd_info, {}))
+ upload_apply_node = mod_graph.create_node(
+ "call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, {})
+ )
replace_node_users(last_inp_node, upload_apply_node)
last_inp_node = upload_apply_node
@@ -220,37 +218,39 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
# forward prefetch
fwd_info = {}
if region.param_size:
- fwd_info['sync_rid'] = region.r_id
+ fwd_info["sync_rid"] = region.r_id
fwd_prefetch_region = region.fwd_prefetch_region
if fwd_prefetch_region and requires_upload_p_in_fwd(region_list[fwd_prefetch_region.shared_rid]):
- fwd_info['h2d_rid'] = fwd_prefetch_region.r_id
+ fwd_info["h2d_rid"] = fwd_prefetch_region.r_id
# forward offload
if r_idx > 0 and region_list[r_idx - 1].need_offload:
- fwd_info['d2h_rid'] = r_idx - 1
+ fwd_info["d2h_rid"] = r_idx - 1
bwd_info = {}
# backward prefetch
if r_idx > 0 and region_list[r_idx - 1].need_offload:
- bwd_info['sync_rid'] = r_idx - 1
+ bwd_info["sync_rid"] = r_idx - 1
if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region:
- bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id
+ bwd_info["h2d_rid"] = region_list[r_idx - 1].bwd_prefetch_region.r_id
if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node):
- new_node = mod_graph.create_node('call_function',
- convert_fwd_prefetch_bwd_offload_to_action,
- args=(last_inp_node, fwd_info, bwd_info))
+ new_node = mod_graph.create_node(
+ "call_function",
+ convert_fwd_prefetch_bwd_offload_to_action,
+ args=(last_inp_node, fwd_info, bwd_info),
+ )
replace_node_users(last_inp_node, new_node)
last_inp_node = region.nodes[-1]
if region.bwd_prefetch_region:
- bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id}
+ bwd_info = {"h2d_rid": region.bwd_prefetch_region.r_id}
with mod_graph.inserting_after(last_inp_node):
- new_node = mod_graph.create_node('call_function',
- convert_fwd_prefetch_bwd_offload_to_action,
- args=(last_inp_node, {}, bwd_info))
+ new_node = mod_graph.create_node(
+ "call_function", convert_fwd_prefetch_bwd_offload_to_action, args=(last_inp_node, {}, bwd_info)
+ )
replace_node_users(last_inp_node, new_node)
# gm.graph.print_tabular()
return gm
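
Both OP classes above rely on the same trick: an identity torch.autograd.Function whose forward and backward run data-movement side effects at exactly the right points. Distilled to its core:

import torch

class PreFwdPostBwd(torch.autograd.Function):
    """Identity on data; fires callbacks before forward and during backward."""

    @staticmethod
    def forward(ctx, x, fwd_action, bwd_action):
        ctx.bwd_action = bwd_action
        fwd_action()  # e.g. upload or prefetch a region
        return x

    @staticmethod
    def backward(ctx, grad_output):
        ctx.bwd_action()  # e.g. prefetch the region needed next
        return grad_output, None, None

x = torch.randn(3, requires_grad=True)
y = PreFwdPostBwd.apply(x, lambda: print("pre-forward"), lambda: print("pre-backward"))
y.sum().backward()
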
diff --git a/colossalai/auto_parallel/offload/solver.py b/colossalai/auto_parallel/offload/solver.py
index 161f7ff868981d913047d7073bb144694ab68591..a6b4904f2617eff77216dfaf3a4fb5fd2449e061 100644
--- a/colossalai/auto_parallel/offload/solver.py
+++ b/colossalai/auto_parallel/offload/solver.py
@@ -1,6 +1,6 @@
import time
-from typing import List, Dict, Type
from abc import ABC, abstractmethod
+from typing import Dict, List, Type
NOT_NVML = False
try:
@@ -10,10 +10,11 @@ except:
import torch
from torch.fx.node import Node
+
from colossalai.utils.cuda import get_current_device
-from .training_simulator import TrainingSimulator, SynTrainingSimulator, AsynTrainingSimulator
from .region import Region
+from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
from .util import NodeInfo, NvDevicePower
@@ -49,19 +50,14 @@ class Solver(ABC):
It is used to reduce the memory budget, compensating for errors in the estimation of peak memory and execution time.
"""
- def __init__(self,
- region_list: List[Region],
- memory_budget: float = -1.0,
- error_factor: float = 0.95) -> None:
-
+ def __init__(self, region_list: List[Region], memory_budget: float = -1.0, error_factor: float = 0.95) -> None:
self.region_list = region_list
self.error_factor: float = error_factor
if memory_budget > 0:
self.memory_budget = memory_budget * self.error_factor
else:
- self.memory_budget = torch.cuda.get_device_properties(
- get_current_device()).total_memory * self.error_factor
+ self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor
self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
self.comp_power: float = self._extract_computing_power()
@@ -94,7 +90,7 @@ class Solver(ABC):
if extra_cost == 0:
# means data transfer overhead can be completely overlapped
- return (float('inf'), total_mem_saving, peak_mem_saving)
+ return (float("inf"), total_mem_saving, peak_mem_saving)
return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving)
def _compare_profit(self, profit_a: tuple, profit_b: tuple) -> bool:
@@ -122,9 +118,7 @@ class Solver(ABC):
self.best_ts = best_ts
self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem)
- def _update_node_mem_info(self,
- fwd_mem_info: Dict[Node, float],
- bwd_mem_info: Dict[Node, float]):
+ def _update_node_mem_info(self, fwd_mem_info: Dict[Node, float], bwd_mem_info: Dict[Node, float]):
"""
Update the runtime memory information of the node.
@@ -134,12 +128,10 @@ class Solver(ABC):
"""
for node, mem in fwd_mem_info.items():
- assert hasattr(node, 'node_info') and isinstance(
- node.node_info, NodeInfo)
+ assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo)
node.node_info.runtime_fwd_mem = mem
for node, mem in bwd_mem_info.items():
- assert hasattr(node, 'node_info') and isinstance(
- node.node_info, NodeInfo)
+ assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo)
node.node_info.runtime_bwd_mem = mem
def _extract_computing_power(self):
@@ -159,12 +151,12 @@ class Solver(ABC):
return NvDevicePower.RTX3080_FP16 * units
elif device_name.__contains__("RTX 3090"):
return NvDevicePower.RTX3090_FP16 * units
- elif device_name.__contains__('V100'):
+ elif device_name.__contains__("V100"):
return NvDevicePower.V100_FP16 * units
elif device_name.__contains__("A100"):
return NvDevicePower.A100_FP16 * units
else:
- raise TypeError(f'Unknown NVIDIA GPU device name {device_name}')
+ raise TypeError(f"Unknown NVIDIA GPU device name {device_name}")
def _profile_bandwidth(self):
"""
@@ -172,9 +164,9 @@ class Solver(ABC):
using data volumes ranging from 1KB to 1GB.
"""
- print('profiling bandwidth ......')
+ print("profiling bandwidth ......")
link_to_bandwidth = {}
- links = ['h2d', 'd2h']
+ links = ["h2d", "d2h"]
for link in links:
t_size = 1024
@@ -182,24 +174,22 @@ class Solver(ABC):
# from 1KB to 1GB
for i in range(21):
- if link == 'h2d':
- src_tensor = torch.ones(
- int(t_size), dtype=torch.int8, pin_memory=True)
- dst_tensor = torch.ones(
- (int(t_size)), dtype=torch.int8, device='cuda')
- elif link == 'd2h':
- src_tensor = torch.ones(
- int(t_size), dtype=torch.int8, device='cuda')
- dst_tensor = torch.ones(
- (int(t_size)), dtype=torch.int8, pin_memory=True)
+ if link == "h2d":
+ src_tensor = torch.ones(int(t_size), dtype=torch.int8, pin_memory=True)
+ dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, device="cuda")
+ elif link == "d2h":
+ src_tensor = torch.ones(int(t_size), dtype=torch.int8, device="cuda")
+ dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, pin_memory=True)
def func():
dst_tensor.copy_(src_tensor)
size_to_bandwidth[t_size] = t_size / benchmark_func(func, number=5, repeat=3)
- print(f'size: {t_size / 1024 ** 2:.3f} MB, '
- f'{src_tensor.device.type}-to-{dst_tensor.device.type} '
- f'bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s')
+ print(
+ f"size: {t_size / 1024 ** 2:.3f} MB, "
+ f"{src_tensor.device.type}-to-{dst_tensor.device.type} "
+ f"bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s"
+ )
t_size *= 2
@@ -208,10 +198,7 @@ class Solver(ABC):
class SynGreedySolver(Solver):
-
- def __init__(self,
- region_list: List[Region],
- memory_budget: float = -1.0) -> None:
+ def __init__(self, region_list: List[Region], memory_budget: float = -1.0) -> None:
super().__init__(region_list, memory_budget)
self.best_ts: SynTrainingSimulator = None
@@ -258,7 +245,8 @@ class SynGreedySolver(Solver):
else:
raise NotImplementedError(
f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
- f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!")
+ f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!"
+ )
def _call_solver_l2l(self):
"""
@@ -270,7 +258,6 @@ class SynGreedySolver(Solver):
region.is_syn = True
def _try_to_offload(self, offload_region: Region):
-
# record previous information
orig_need_offload = offload_region.need_offload
assert not orig_need_offload
@@ -297,23 +284,17 @@ class SynGreedySolver(Solver):
ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
ts.execute()
- extra_comm_cost = 2.0 * \
- ts._get_communication_overhead('h2d', offload_region.param_size)
+ extra_comm_cost = 2.0 * ts._get_communication_overhead("h2d", offload_region.param_size)
# the shared region needs to be moved twice
if offload_region.r_id < offload_region.shared_rid:
extra_comm_cost *= 2.0
- profit = self._compute_offload_profit(
- ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
+ profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
return ts, profit
class AsynGreedySolver(Solver):
-
- def __init__(self,
- region_list: List[Region],
- memory_budget: float = -1.0,
- search_window_size: int = 3):
+ def __init__(self, region_list: List[Region], memory_budget: float = -1.0, search_window_size: int = 3):
super().__init__(region_list, memory_budget)
self.search_window_size = search_window_size
@@ -331,7 +312,7 @@ class AsynGreedySolver(Solver):
ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
ts.execute()
self._update_state(ts)
- print("init peak memory", self.best_ts.peak_mem / 1024 ** 2, "MB")
+ print("init peak memory", self.best_ts.peak_mem / 1024**2, "MB")
def _call_solver(self):
"""
@@ -358,18 +339,17 @@ class AsynGreedySolver(Solver):
best_pref_ts = None
# search when to prefetch the region offloaded
- for host_region in self.region_list[region.r_id + 1:region.r_id + 1 + self.search_window_size]:
+ for host_region in self.region_list[region.r_id + 1 : region.r_id + 1 + self.search_window_size]:
if host_region.bwd_prefetch_region is not None:
continue
- temp_ts, profit = self._try_to_offload(
- host_region, region)
+ temp_ts, profit = self._try_to_offload(host_region, region)
if self._compare_profit(profit, max_prefetch_profit):
region_to_region_map[region.r_id] = host_region
max_prefetch_profit = profit
best_pref_ts = temp_ts
- if profit[0] == float('inf'):
+ if profit[0] == float("inf"):
break
if self._compare_profit(max_prefetch_profit, max_offload_profit):
@@ -392,7 +372,8 @@ class AsynGreedySolver(Solver):
else:
raise NotImplementedError(
f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
- f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!")
+ f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!"
+ )
region_to_region_map.clear()
@@ -452,7 +433,6 @@ class AsynGreedySolver(Solver):
peak_mem_saving = 0
while len(self.region_to_region_map) and peak_mem_saving <= 0:
-
max_profit = (0,)
best_ts = None
undo_host_region = None
@@ -464,8 +444,7 @@ class AsynGreedySolver(Solver):
assert offload_region.need_offload
assert not offload_region.is_syn
- ts, profit = self._try_convert_to_syn_upload(host_region,
- offload_region)
+ ts, profit = self._try_convert_to_syn_upload(host_region, offload_region)
if self._compare_profit(profit, max_profit):
undo_host_region = host_region
@@ -474,7 +453,7 @@ class AsynGreedySolver(Solver):
best_ts = ts
if best_ts is None:
- raise NotImplementedError('repair error!')
+ raise NotImplementedError("repair error!")
assert not undo_offload_region.is_syn
undo_offload_region.is_syn = True
@@ -500,17 +479,13 @@ class AsynGreedySolver(Solver):
ts.execute()
extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0)
- profit = self._compute_offload_profit(
- ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
+ profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
return ts, profit
class SolverFactory:
- solvers: Dict[str, Type[Solver]] = {
- 'syn': SynGreedySolver,
- 'asyn': AsynGreedySolver
- }
+ solvers: Dict[str, Type[Solver]] = {"syn": SynGreedySolver, "asyn": AsynGreedySolver}
@staticmethod
def create(solver_name: str) -> Type[Solver]:
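
_profile_bandwidth above times host-to-device and device-to-host copies over sizes from 1KB to 1GB. The core measurement, reduced to one direction (a sketch that assumes a CUDA device; measure_h2d_bandwidth is hypothetical):

import time

import torch

def measure_h2d_bandwidth(nbytes: int, repeat: int = 3) -> float:
    """Time pinned-host -> device copies; return bytes per second."""
    src = torch.ones(nbytes, dtype=torch.int8, pin_memory=True)
    dst = torch.empty(nbytes, dtype=torch.int8, device="cuda")
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeat):
        dst.copy_(src)
    torch.cuda.synchronize()
    return nbytes * repeat / (time.perf_counter() - start)
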
diff --git a/colossalai/auto_parallel/offload/training_simulator.py b/colossalai/auto_parallel/offload/training_simulator.py
index de58023ec2d6a4b4247ad838f4e9c2e9a56da692..728d8daf9a467ef393c2291bbc9c0fbbe674e579 100644
--- a/colossalai/auto_parallel/offload/training_simulator.py
+++ b/colossalai/auto_parallel/offload/training_simulator.py
@@ -1,7 +1,7 @@
import bisect
-from typing import List, Dict
-from collections import OrderedDict
from abc import ABC, abstractmethod
+from collections import OrderedDict
+from typing import Dict, List
from torch.fx.node import Node
@@ -26,10 +26,7 @@ class TrainingSimulator(ABC):
link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth.
"""
- def __init__(self,
- region_list: List[Region],
- comp_power: float,
- link_to_bw: Dict[str, Dict[float, float]]) -> None:
+ def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
self.region_list = region_list
self.region_num = len(region_list)
@@ -87,11 +84,7 @@ class TrainingSimulator(ABC):
class SynTrainingSimulator(TrainingSimulator):
-
- def __init__(self,
- region_list: List[Region],
- comp_power: float,
- link_to_bw: Dict[str, Dict[float, float]]) -> None:
+ def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
super().__init__(region_list, comp_power, link_to_bw)
def execute(self):
@@ -115,8 +108,7 @@ class SynTrainingSimulator(TrainingSimulator):
self.runtime_mem += region.param_size
for node in region.nodes:
- self.runtime_mem += calculate_fwd_tmp(node) + \
- calculate_fwd_out(node)
+ self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node)
self.fwd_node_mem[node] = self.runtime_mem
self.peak_mem = max(self.runtime_mem, self.peak_mem)
self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
@@ -141,18 +133,15 @@ class SynTrainingSimulator(TrainingSimulator):
self.runtime_mem += region.param_size
for node in region.nodes.__reversed__():
-
self.runtime_mem -= calculate_fwd_out(node)
- self.runtime_mem += node.meta['bwd_mem_tmp'] + \
- node.meta['bwd_mem_out']
+ self.runtime_mem += node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
self.peak_mem = max(self.runtime_mem, self.peak_mem)
# The memory savings of a node may be negative due to parameter prefetch.
self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem
self.bwd_node_mem[node] = self.runtime_mem
- self.runtime_mem -= (node.meta['bwd_mem_tmp'] +
- calculate_fwd_tmp(node))
+ self.runtime_mem -= node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(node)
# free bwd_mem_out
self.bwd_node_deps[node] = len(node.all_input_nodes)
@@ -160,12 +149,14 @@ class SynTrainingSimulator(TrainingSimulator):
if user_node in self.bwd_node_deps:
self.bwd_node_deps[user_node] -= 1
if self.bwd_node_deps[user_node] <= 0:
- self.runtime_mem -= user_node.meta['bwd_mem_out']
+ self.runtime_mem -= user_node.meta["bwd_mem_out"]
if self.runtime_mem < 0:
- raise ValueError(f"region id: {region.r_id}, node name: {node.name}, "
- f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
- f"runtime memory computed less than 0, which is miscalculated!")
+ raise ValueError(
+ f"region id: {region.r_id}, node name: {node.name}, "
+ f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
+                        "runtime memory computed to be less than 0, which indicates a miscalculation!"
+ )
# release parameter and offload gradient in region
if region.r_id == region.shared_rid:
@@ -177,23 +168,16 @@ class SynTrainingSimulator(TrainingSimulator):
class AsynTrainingSimulator(TrainingSimulator):
-
- def __init__(self,
- region_list: List[Region],
- comp_power: float,
- link_to_bw: Dict[str, Dict[float, float]]) -> None:
+ def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
super().__init__(region_list, comp_power, link_to_bw)
self.iter_end_time: int = 0
# the last computation execution period
- self.last_comp: ExecutionPeriod = ExecutionPeriod(
- start_time=0, end_time=0)
+ self.last_comp: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
# the last parameter prefetch execution period
- self.last_h2d: ExecutionPeriod = ExecutionPeriod(
- start_time=0, end_time=0)
+ self.last_h2d: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
# the last gradient offload execution period
- self.last_d2h: ExecutionPeriod = ExecutionPeriod(
- start_time=0, end_time=0)
+ self.last_d2h: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
# the forward computation execution period of the region
self.fwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict()
# the forward parameter prefetch execution period of the region
@@ -204,10 +188,8 @@ class AsynTrainingSimulator(TrainingSimulator):
self.bwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict()
# the gradient offload execution period of the region
# which is divided into those that are waiting and those that have been released
- self.bwd_reg_to_offl_waiting: OrderedDict[int,
- ExecutionPeriod] = OrderedDict()
- self.bwd_reg_to_offl_freed: OrderedDict[int,
- ExecutionPeriod] = OrderedDict()
+ self.bwd_reg_to_offl_waiting: OrderedDict[int, ExecutionPeriod] = OrderedDict()
+ self.bwd_reg_to_offl_freed: OrderedDict[int, ExecutionPeriod] = OrderedDict()
# the region buffer, which records regions that are offloaded but not released
self.reg_buffer_to_free: List[int] = []
@@ -217,10 +199,8 @@ class AsynTrainingSimulator(TrainingSimulator):
# the region execution flow,
# where fwd_reg_flow[i,j] denotes whether the parameters of j-th region are in the GPU
# when the execution reaches the i-th region.
- self.fwd_reg_flow = torch.zeros(
- (self.region_num, self.region_num)).bool()
- self.bwd_reg_flow = torch.zeros(
- (self.region_num, self.region_num)).bool()
+ self.fwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool()
+ self.bwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool()
def execute(self):
"""
@@ -232,7 +212,7 @@ class AsynTrainingSimulator(TrainingSimulator):
for reg in self.region_list:
if reg.param_size and reg.r_id < self.region_num - 1:
- for nr in self.region_list[reg.r_id + 1:]:
+ for nr in self.region_list[reg.r_id + 1 :]:
if nr.param_size and requires_upload_p_in_fwd(self.region_list[nr.shared_rid]):
reg.fwd_prefetch_region = nr
break
@@ -249,8 +229,7 @@ class AsynTrainingSimulator(TrainingSimulator):
self.runtime_mem -= self.region_list[reg_id].param_size
self.bwd_reg_to_offl_waiting.clear()
- self.iter_end_time = max(
- self.last_comp.end_time, self.last_d2h.end_time)
+ self.iter_end_time = max(self.last_comp.end_time, self.last_d2h.end_time)
def _insert_h2d_exec(self, region: Region, is_fwd: bool = True):
"""
@@ -258,10 +237,8 @@ class AsynTrainingSimulator(TrainingSimulator):
"""
pref_start_time = max(self.last_h2d.end_time, self.last_comp.end_time)
- pref_end_time = pref_start_time + \
- 2.0 * self._get_communication_overhead('h2d', region.param_size)
- pref_ep = ExecutionPeriod(
- start_time=pref_start_time, end_time=pref_end_time)
+ pref_end_time = pref_start_time + 2.0 * self._get_communication_overhead("h2d", region.param_size)
+ pref_ep = ExecutionPeriod(start_time=pref_start_time, end_time=pref_end_time)
if is_fwd:
self.fwd_reg_to_pref[region.r_id] = pref_ep
else:
@@ -276,18 +253,16 @@ class AsynTrainingSimulator(TrainingSimulator):
if is_fwd:
reg_to_comp = self.fwd_reg_to_comp
reg_to_pref = self.fwd_reg_to_pref
- flop_key = 'fwd_flop'
+ flop_key = "fwd_flop"
else:
reg_to_comp = self.bwd_reg_to_comp
reg_to_pref = self.bwd_reg_to_pref
- flop_key = 'bwd_flop'
- comp_start_time = max(self.last_comp.end_time, reg_to_pref.get(
- region.r_id, ExecutionPeriod(0, 0)).end_time)
- comp_end_time = comp_start_time + \
- sum([self._get_computing_overhead(node.meta.get(flop_key, 0))
- for node in region.nodes])
- comp_ep = ExecutionPeriod(
- start_time=comp_start_time, end_time=comp_end_time)
+ flop_key = "bwd_flop"
+ comp_start_time = max(self.last_comp.end_time, reg_to_pref.get(region.r_id, ExecutionPeriod(0, 0)).end_time)
+ comp_end_time = comp_start_time + sum(
+ [self._get_computing_overhead(node.meta.get(flop_key, 0)) for node in region.nodes]
+ )
+ comp_ep = ExecutionPeriod(start_time=comp_start_time, end_time=comp_end_time)
reg_to_comp[region.r_id] = comp_ep
self.last_comp = comp_ep
@@ -297,10 +272,8 @@ class AsynTrainingSimulator(TrainingSimulator):
"""
offl_start_time = max(self.last_d2h.end_time, self.last_comp.end_time)
- offl_end_time = offl_start_time + \
- self._get_communication_overhead('d2h', region.param_size)
- offl_ep = ExecutionPeriod(
- start_time=offl_start_time, end_time=offl_end_time)
+ offl_end_time = offl_start_time + self._get_communication_overhead("d2h", region.param_size)
+ offl_ep = ExecutionPeriod(start_time=offl_start_time, end_time=offl_end_time)
self.bwd_reg_to_offl_waiting[region.r_id] = offl_ep
self.last_d2h = offl_ep
@@ -332,20 +305,17 @@ class AsynTrainingSimulator(TrainingSimulator):
self.fwd_reg_flow[region.r_id, region.r_id] = True
else:
self.fwd_reg_flow[region.r_id] = self.fwd_reg_flow[region.r_id - 1]
- self.fwd_reg_flow[region.r_id,
- self.reg_buffer_to_free] = False
+ self.fwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False
self.reg_buffer_to_free.clear()
# prefetch parameters of the next region
fwd_prefetch_region = region.fwd_prefetch_region
if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]):
self.runtime_mem += fwd_prefetch_region.param_size
- self.fwd_reg_flow[region.r_id,
- fwd_prefetch_region.r_id] = True
+ self.fwd_reg_flow[region.r_id, fwd_prefetch_region.r_id] = True
for node in region.nodes:
- self.runtime_mem += calculate_fwd_tmp(node) + \
- calculate_fwd_out(node)
+ self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node)
self.peak_mem = max(self.runtime_mem, self.peak_mem)
self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
@@ -354,8 +324,7 @@ class AsynTrainingSimulator(TrainingSimulator):
if region.need_offload:
self.runtime_mem -= region.param_size
- assert len(
- self.reg_buffer_to_free) <= 1, f'{len(self.reg_buffer_to_free)}'
+ assert len(self.reg_buffer_to_free) <= 1, f"{len(self.reg_buffer_to_free)}"
self.reg_buffer_to_free.append(region.r_id)
def _eval_bwd_cost_per_region(self, region: Region):
@@ -398,8 +367,7 @@ class AsynTrainingSimulator(TrainingSimulator):
self.bwd_reg_flow[region.r_id] = self.bwd_reg_flow[region.r_id + 1]
else:
self.bwd_reg_flow[region.r_id] = self.fwd_reg_flow[-1]
- self.bwd_reg_flow[region.r_id,
- self.reg_buffer_to_free] = False
+ self.bwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False
# free gradients in the buffer
while len(self.reg_buffer_to_free):
@@ -415,8 +383,7 @@ class AsynTrainingSimulator(TrainingSimulator):
bwd_prefetch_region = region.bwd_prefetch_region
if bwd_prefetch_region:
self.runtime_mem += bwd_prefetch_region.param_size
- self.bwd_reg_flow[region.r_id,
- bwd_prefetch_region.r_id] = True
+ self.bwd_reg_flow[region.r_id, bwd_prefetch_region.r_id] = True
# add the gradient of the parameter
if region.r_id < region.shared_rid:
@@ -426,10 +393,8 @@ class AsynTrainingSimulator(TrainingSimulator):
self.runtime_mem += region.param_size
for node in region.nodes.__reversed__():
-
self.runtime_mem -= calculate_fwd_out(node)
- self.runtime_mem += node.meta['bwd_mem_tmp'] + \
- node.meta['bwd_mem_out']
+ self.runtime_mem += node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
self.peak_mem = max(self.runtime_mem, self.peak_mem)
# The memory savings of a node may be negative due to parameter prefetch.
@@ -437,8 +402,7 @@ class AsynTrainingSimulator(TrainingSimulator):
self.bwd_node_mem[node] = self.runtime_mem
- self.runtime_mem -= (node.meta['bwd_mem_tmp'] +
- calculate_fwd_tmp(node))
+ self.runtime_mem -= node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(node)
# free bwd_mem_out
self.bwd_node_deps[node] = len(node.all_input_nodes)
@@ -446,12 +410,14 @@ class AsynTrainingSimulator(TrainingSimulator):
if user_node in self.bwd_node_deps:
self.bwd_node_deps[user_node] -= 1
if self.bwd_node_deps[user_node] <= 0:
- self.runtime_mem -= user_node.meta['bwd_mem_out']
+ self.runtime_mem -= user_node.meta["bwd_mem_out"]
if self.runtime_mem < 0:
- raise ValueError(f"region id: {region.r_id}, node name: {node.name}, "
- f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
- f"runtime memory computed less than 0, which is miscalculated!")
+ raise ValueError(
+ f"region id: {region.r_id}, node name: {node.name}, "
+ f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
+                        "runtime memory computed to be less than 0, which indicates a miscalculation!"
+ )
# release parameters of the region
if requires_release_p_in_bwd(self.region_list[region.shared_rid]):
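
The simulator serializes transfers and computation into ExecutionPeriods; the scheduling rule behind _insert_h2d_exec above reduces to "start when both the previous transfer and the current computation are done":

from dataclasses import dataclass

@dataclass
class Period:
    start: float
    end: float

def schedule_prefetch(last_h2d: Period, last_comp: Period, transfer_time: float) -> Period:
    begin = max(last_h2d.end, last_comp.end)
    return Period(begin, begin + transfer_time)

print(schedule_prefetch(Period(0.0, 1.0), Period(0.5, 2.0), transfer_time=0.3))
# Period(start=2.0, end=2.3)
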
diff --git a/colossalai/auto_parallel/offload/util.py b/colossalai/auto_parallel/offload/util.py
index 6b010512cc9c99b6dff4acf2e7bd97d12d8146c0..cb65da79c5a27ac369367d72d6429fa4e02cb62a 100644
--- a/colossalai/auto_parallel/offload/util.py
+++ b/colossalai/auto_parallel/offload/util.py
@@ -35,7 +35,6 @@ class NvDevicePower:
class GlobalRuntimeInfo(metaclass=SingletonMeta):
-
def __init__(self):
self.h2d_stream = torch.cuda.Stream()
self.d2h_stream = torch.cuda.Stream()
@@ -50,21 +49,18 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
# forward
for region in region_list:
for node in region.nodes:
- runtime_mem = runtime_mem + \
- calculate_fwd_tmp(node) + calculate_fwd_out(node)
+ runtime_mem = runtime_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
act_peak_mem = max(runtime_mem, act_peak_mem)
# backward
bwd_deps = {}
for region in region_list.__reversed__():
for node in region.nodes.__reversed__():
runtime_mem -= calculate_fwd_out(node)
- runtime_mem = runtime_mem + \
- node.meta['bwd_mem_tmp'] + node.meta['bwd_mem_out']
+ runtime_mem = runtime_mem + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
act_peak_mem = max(runtime_mem, act_peak_mem)
- runtime_mem = runtime_mem - \
- node.meta['bwd_mem_tmp'] - calculate_fwd_tmp(node)
+ runtime_mem = runtime_mem - node.meta["bwd_mem_tmp"] - calculate_fwd_tmp(node)
# free bwd_mem_out
bwd_deps[node] = len(node.all_input_nodes)
@@ -72,7 +68,7 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
if user_node in bwd_deps:
bwd_deps[user_node] -= 1
if bwd_deps[user_node] <= 0:
- runtime_mem -= user_node.meta['bwd_mem_out']
+ runtime_mem -= user_node.meta["bwd_mem_out"]
return act_peak_mem
@@ -86,13 +82,15 @@ def compute_total_param_mem(region_list: List[Region]) -> float:
def requires_upload_p_in_fwd(shared_reg: Region):
- return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid
- and shared_reg.need_offload)
+ return (shared_reg.r_id >= shared_reg.shared_rid) or (
+ shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload
+ )
def requires_release_p_in_bwd(shared_reg: Region):
- return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid
- and shared_reg.need_offload)
+ return (shared_reg.r_id >= shared_reg.shared_rid) or (
+ shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload
+ )
def requires_offload_g_in_bwd(region: Region):
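
requires_upload_p_in_fwd above encodes when a (possibly shared) region's parameters must be brought back to the GPU in forward. A few concrete cases (Reg is a hypothetical stand-in for Region):

class Reg:
    def __init__(self, r_id, shared_rid, need_offload):
        self.r_id, self.shared_rid, self.need_offload = r_id, shared_rid, need_offload

def needs_upload(reg):
    return reg.r_id >= reg.shared_rid or (reg.r_id < reg.shared_rid and reg.need_offload)

assert needs_upload(Reg(3, 3, False))      # region owns its parameters
assert not needs_upload(Reg(1, 5, False))  # shares with a later region still on GPU
assert needs_upload(Reg(1, 5, True))       # shared, but it was offloaded
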
diff --git a/colossalai/auto_parallel/passes/comm_metainfo_pass.py b/colossalai/auto_parallel/passes/comm_metainfo_pass.py
index ffda58e0689f1b0e7fa962dd328eee7453e559ec..ba290ee839d8bbe053c8af32a7935387c6618c6a 100644
--- a/colossalai/auto_parallel/passes/comm_metainfo_pass.py
+++ b/colossalai/auto_parallel/passes/comm_metainfo_pass.py
@@ -14,18 +14,20 @@ from colossalai.tensor.sharding_spec import ShardingSpec
shape_consistency_manager = ShapeConsistencyManager()
-def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
- target_sharding_spec: ShardingSpec) -> ShardMetaInfo:
+def _construct_shard_meta_info(
+ node: Node, origin_sharding_spec: ShardingSpec, target_sharding_spec: ShardingSpec
+) -> ShardMetaInfo:
# get comm_action_sequence and total_cost from shape_consistency_manager
_, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
- origin_sharding_spec, target_sharding_spec)
+ origin_sharding_spec, target_sharding_spec
+ )
meta_info = ShardMetaInfo()
# NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
# get mem cost for ShardMetaInfo
mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
# extract an input node that has _meta_data and read its element size
- input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
+ input_node = next(n for n in node._input_nodes if hasattr(n, "_meta_data"))
element_length = input_node._meta_data.element_size()
mem_cost.fwd.activation *= element_length
@@ -37,9 +39,11 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
meta_info.memory_cost = mem_cost
# get computation cost for ShardMetaInfo
- meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
- total_cost['backward'] * element_length,
- total_cost['total'] * element_length)
+ meta_info.compute_cost = TrainCycleItem(
+ total_cost["forward"] * element_length,
+ total_cost["backward"] * element_length,
+ total_cost["total"] * element_length,
+ )
# get tensor shape for ShardMetaInfo
origin_sharding_spec: ShardingSpec
@@ -47,9 +51,9 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
input_shape = origin_sharding_spec.get_sharded_shape_per_device()
output_shape = target_sharding_spec.get_sharded_shape_per_device()
- meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+ meta_info.fwd_in = [torch.rand(input_shape, device="meta")]
meta_info.fwd_buffer = []
- meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+ meta_info.fwd_out = [torch.rand(output_shape, device="meta")]
return meta_info
@@ -62,8 +66,10 @@ def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -
# extract node index and user node index
args = node.args
node_index, user_node_index = args[3], args[4]
- origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][
- user_node_index]
+ origin_sharding_spec, target_sharding_spec = (
+ origin_spec_dict[node_index],
+ sharding_spec_dict[node_index][user_node_index],
+ )
return _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)
@@ -77,37 +83,42 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> S
# this case is for all_reduce, there will be no memory cost
meta_info = ShardMetaInfo()
meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost())
- output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
+ output_node = next(n for n in node.users if hasattr(n, "_meta_data"))
element_length = output_node._meta_data.element_size()
total_cost = comm_action.comm_spec.get_comm_cost()
- meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
- total_cost['backward'] * element_length,
- total_cost['total'] * element_length)
+ meta_info.compute_cost = TrainCycleItem(
+ total_cost["forward"] * element_length,
+ total_cost["backward"] * element_length,
+ total_cost["total"] * element_length,
+ )
input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
- meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+ meta_info.fwd_in = [torch.rand(input_shape, device="meta")]
meta_info.fwd_buffer = []
- meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+ meta_info.fwd_out = [torch.rand(output_shape, device="meta")]
else:
# this case will be handled by shape consistency manager
- origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[
- 'tgt_spec']
+ origin_sharding_spec, target_sharding_spec = (
+ comm_action.comm_spec["src_spec"],
+ comm_action.comm_spec["tgt_spec"],
+ )
meta_info = _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)
return meta_info
-def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict,
- comm_actions_dict: Dict) -> GraphModule:
+def comm_metainfo_pass(
+ gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict, comm_actions_dict: Dict
+) -> GraphModule:
"""
The method manages all the metainfo of the communication nodes (runtime_apply, runtime_comm_spec_apply) in the graph.
"""
for node in gm.graph.nodes:
if node.target == runtime_apply:
- setattr(node, 'best_strategy_info', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
+ setattr(node, "best_strategy_info", _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
elif node.target == runtime_comm_spec_apply:
- setattr(node, 'best_strategy_info', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
+ setattr(node, "best_strategy_info", _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
else:
pass
return gm
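
As the NOTE above says, shape_consistency_manager.mem_cost counts numel, so the pass rescales costs by the element size of a representative tensor:

import torch

cost_in_numel = 1024
element_length = torch.empty(0, dtype=torch.float16).element_size()  # 2 bytes
cost_in_bytes = cost_in_numel * element_length
assert cost_in_bytes == 2048
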
diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py
index bc0960483980b9242af531d4b309a9c74f735c16..9b000549de6ca8f80de4955353f42920de31fc64 100644
--- a/colossalai/auto_parallel/passes/meta_info_prop.py
+++ b/colossalai/auto_parallel/passes/meta_info_prop.py
@@ -21,16 +21,15 @@ def _normalize_tuple(x):
@compatibility(is_backward_compatible=False)
class MetaInfoProp:
-
def __init__(self, module: GraphModule) -> None:
self.module = module
self.func_dict = {
- 'placeholder': self.placeholder_handler,
- 'get_attr': self.get_attr_handler,
- 'output': self.output_handler,
- 'call_function': self.node_handler,
- 'call_module': self.node_handler,
- 'call_method': self.node_handler,
+ "placeholder": self.placeholder_handler,
+ "get_attr": self.get_attr_handler,
+ "output": self.output_handler,
+ "call_function": self.node_handler,
+ "call_module": self.node_handler,
+ "call_method": self.node_handler,
}
def _set_data_ptr(self, x):
@@ -46,7 +45,7 @@ class MetaInfoProp:
"""
Check if the node is inplace operation.
"""
- if node.op == 'call_module':
+ if node.op == "call_module":
return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
elif node.op == "call_function":
return node.target in OUTPUT_SAVED_OPS
@@ -66,7 +65,7 @@ class MetaInfoProp:
Handle the placeholder node.
"""
graph_info = GraphInfo()
- out = _normalize_tuple(getattr(node, '_meta_data', None))
+ out = _normalize_tuple(getattr(node, "_meta_data", None))
graph_info.fwd_out = list(out) if out[0] is not None else []
node.meta = {**asdict(graph_info)}
@@ -96,7 +95,7 @@ class MetaInfoProp:
"""
     Handle other kinds of nodes.
"""
- assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}"
+ assert hasattr(node, "best_strategy_info"), f"Cannot find best_strategy_info in node {node}, {node.op}"
graph_info = GraphInfo()
meta_info = node.best_strategy_info
meta_info: ShardMetaInfo
@@ -126,7 +125,8 @@ class MetaInfoProp:
for tensor in par.meta.get("fwd_out", []):
tensor: torch.Tensor
target_input_tensor = next(
- (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
+ (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None
+ )
if target_input_tensor is not None:
target_input_tensor.data_ptr = tensor.data_ptr
@@ -148,7 +148,7 @@ class MetaInfoProp:
graph_info.fwd_tmp = buffer_tensors
graph_info.fwd_out = output_tensors
- # fetch other memory informations
+ # fetch other memory information
memory_cost = meta_info.memory_cost
graph_info.fwd_mem_tmp = memory_cost.fwd.temp
graph_info.fwd_mem_out = memory_cost.fwd.activation
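`MetaInfoProp` above is essentially a dispatch table keyed by `node.op`. A stripped-down sketch of the same structure, with placeholder handler bodies standing in for the real cost bookkeeping:

```python
import torch
import torch.fx as fx

class TinyProp:
    def __init__(self, module: fx.GraphModule):
        self.module = module
        # one handler per torch.fx node kind, looked up by node.op
        self.func_dict = {
            "placeholder": self.placeholder_handler,
            "get_attr": self.node_handler,
            "output": self.node_handler,
            "call_function": self.node_handler,
            "call_module": self.node_handler,
            "call_method": self.node_handler,
        }

    def placeholder_handler(self, node):
        node.meta["fwd_out"] = [getattr(node, "_meta_data", None)]

    def node_handler(self, node):
        node.meta.setdefault("fwd_out", [])

    def run(self):
        for node in self.module.graph.nodes:
            self.func_dict[node.op](node)

gm = fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()))
TinyProp(gm).run()
print([(n.op, n.meta["fwd_out"]) for n in gm.graph.nodes])
```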
diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py
index a473bb6e973de453f47c7bc16c2e83b8e7fe86df..27afe72c0db84f506aab3ba544e71937703b2a4c 100644
--- a/colossalai/auto_parallel/passes/runtime_apply_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py
@@ -1,18 +1,10 @@
-from copy import deepcopy
from typing import Dict, List
import torch
from torch.fx.node import Node
from colossalai._analyzer.fx.node_util import MetaInfo
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
- CommType,
- OperationData,
- OperationDataType,
- TrainCycleItem,
-)
-from colossalai.device.device_mesh import DeviceMesh
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType
from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
@@ -30,19 +22,22 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
-def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int,
- user_node_index: int):
+def runtime_apply_for_iterable_object(
+ node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, user_node_index: int
+):
"""
     This method will be invoked during runtime to perform shape consistency, which makes sure activations of tuple or list type
     are converted into the form expected by the user node.
"""
rst = []
- for index, (origin_sharding_spec,
- target_sharding_spec) in enumerate(zip(origin_dict[node_index],
- input_dict[node_index][user_node_index])):
+ for index, (origin_sharding_spec, target_sharding_spec) in enumerate(
+ zip(origin_dict[node_index], input_dict[node_index][user_node_index])
+ ):
rst.append(
- shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec,
- target_sharding_spec))
+ shape_consistency_manager.apply_for_autoparallel_runtime(
+ node[index], origin_sharding_spec, target_sharding_spec
+ )
+ )
rst = type(node)(rst)
return rst
@@ -55,8 +50,8 @@ def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_
if isinstance(comm_action.comm_spec, CommSpec):
rst = comm_action.comm_spec.covert_spec_to_action(tensor)
else:
- origin_sharding_spec = comm_action.comm_spec['src_spec']
- tgt_sharding_spec = comm_action.comm_spec['tgt_spec']
+ origin_sharding_spec = comm_action.comm_spec["src_spec"]
+ tgt_sharding_spec = comm_action.comm_spec["tgt_spec"]
rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec)
return rst
@@ -70,16 +65,16 @@ def _preprocess_graph(nodes: List[Node]):
node_to_index_dict = {}
index = 0
for node in nodes:
- if node.target == 'sharding_spec_convert_dict':
+ if node.target == "sharding_spec_convert_dict":
input_dict_node = node
continue
- if node.target == 'origin_node_sharding_spec_dict':
+ if node.target == "origin_node_sharding_spec_dict":
origin_dict_node = node
continue
- if node.target == 'comm_actions_dict':
+ if node.target == "comm_actions_dict":
comm_actions_dict_node = node
continue
- if not hasattr(node, 'best_strategy'):
+ if not hasattr(node, "best_strategy"):
continue
node_to_index_dict[node] = index
index += 1
@@ -97,41 +92,46 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes)
for node in nodes:
- if not hasattr(node, 'best_strategy') or node.op == 'output':
+ if not hasattr(node, "best_strategy") or node.op == "output":
continue
for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
if isinstance(node.sharding_spec, (list, tuple)):
assert isinstance(
- node.target_sharding_specs,
- (list,
- tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list'
+ node.target_sharding_specs, (list, tuple)
+ ), "target sharding specs should be tuple or list when node.sharding_spec is tuple or list"
total_difference = 0
- for sharding_spec, target_sharding_spec in zip(node.sharding_spec,
- node.target_sharding_specs[user_node_index]):
+ for sharding_spec, target_sharding_spec in zip(
+ node.sharding_spec, node.target_sharding_specs[user_node_index]
+ ):
total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec)
if total_difference == 0:
continue
with mod_graph.inserting_before(user_node):
- shape_consistency_node = mod_graph.create_node('call_function',
- runtime_apply_for_iterable_object,
- args=(node, origin_dict_node, input_dict_node,
- node_to_index_dict[node], user_node_index))
+ shape_consistency_node = mod_graph.create_node(
+ "call_function",
+ runtime_apply_for_iterable_object,
+ args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index),
+ )
else:
- assert isinstance(node.sharding_spec,
- ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.'
+ assert isinstance(
+ node.sharding_spec, ShardingSpec
+ ), "node.sharding_spec should be type of ShardingSpec, tuple or list."
if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
continue
with mod_graph.inserting_before(user_node):
- shape_consistency_node = mod_graph.create_node('call_function',
- runtime_apply,
- args=(node, origin_dict_node, input_dict_node,
- node_to_index_dict[node], user_node_index))
- if hasattr(user_node.meta['info'], 'activation_checkpoint'):
- MetaInfo(shape_consistency_node,
- mod_dir=user_node.meta['info'].mod_dir,
- activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint))
+ shape_consistency_node = mod_graph.create_node(
+ "call_function",
+ runtime_apply,
+ args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index),
+ )
+ if hasattr(user_node.meta["info"], "activation_checkpoint"):
+ MetaInfo(
+ shape_consistency_node,
+ mod_dir=user_node.meta["info"].mod_dir,
+ activation_checkpoint=tuple(user_node.meta["info"].activation_checkpoint),
+ )
new_args = list(user_node.args)
new_kwargs = dict(user_node.kwargs)
             # the origin node may be a positional argument or a keyword argument of the user node
@@ -158,12 +158,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
_, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes)
for node in nodes:
- if not hasattr(node, 'best_strategy') or node.op == 'output':
+ if not hasattr(node, "best_strategy") or node.op == "output":
continue
comm_actions = node.best_strategy.communication_actions
for op_data, comm_action in comm_actions.items():
-
if comm_action.comm_type == CommType.HOOK:
continue
if comm_action.comm_type == CommType.BEFORE:
@@ -174,10 +173,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
else:
comm_object = node.args[comm_action.arg_index]
with mod_graph.inserting_before(node):
- comm_spec_apply_node = mod_graph.create_node('call_function',
- runtime_comm_spec_apply,
- args=(comm_object, comm_actions_dict_node,
- node_to_index_dict[node], op_data.name))
+ comm_spec_apply_node = mod_graph.create_node(
+ "call_function",
+ runtime_comm_spec_apply,
+ args=(comm_object, comm_actions_dict_node, node_to_index_dict[node], op_data.name),
+ )
             # the origin node may be a positional argument or a keyword argument of the user node
if comm_action.key_for_kwarg is not None:
# substitute the origin node with comm_spec_apply_node
@@ -192,10 +192,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
elif comm_action.comm_type == CommType.AFTER:
with mod_graph.inserting_after(node):
- comm_spec_apply_node = mod_graph.create_node('call_function',
- runtime_comm_spec_apply,
- args=(node, comm_actions_dict_node,
- node_to_index_dict[node], op_data.name))
+ comm_spec_apply_node = mod_graph.create_node(
+ "call_function",
+ runtime_comm_spec_apply,
+ args=(node, comm_actions_dict_node, node_to_index_dict[node], op_data.name),
+ )
user_list = list(node.users.keys())
for user in user_list:
if user == comm_spec_apply_node:
@@ -211,15 +212,17 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
# substitute the origin node with comm_spec_apply_node
new_kwargs[str(node)] = comm_spec_apply_node
user.kwargs = new_kwargs
- if hasattr(node.meta['info'], 'activation_checkpoint'):
- MetaInfo(comm_spec_apply_node,
- mod_dir=node.meta['info'].mod_dir,
- activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))
+ if hasattr(node.meta["info"], "activation_checkpoint"):
+ MetaInfo(
+ comm_spec_apply_node,
+ mod_dir=node.meta["info"].mod_dir,
+ activation_checkpoint=tuple(node.meta["info"].activation_checkpoint),
+ )
return gm
-def _act_annotataion_pass(gm: torch.fx.GraphModule):
+def _act_annotation_pass(gm: torch.fx.GraphModule):
"""
     This pass is used to add the act annotation to the newly inserted nodes.
"""
@@ -227,21 +230,21 @@ def _act_annotataion_pass(gm: torch.fx.GraphModule):
nodes = tuple(mod_graph.nodes)
for node in nodes:
- if not hasattr(node.meta, 'activation_checkpoint'):
- from .runtime_preparation_pass import size_processing
+ if not hasattr(node.meta, "activation_checkpoint"):
+ pass
user_act_annotation = -1
input_act_annotation = -1
for user_node in node.users.keys():
- if 'activation_checkpoint' in user_node.meta:
- user_act_annotation = user_node.meta['activation_checkpoint']
+ if "activation_checkpoint" in user_node.meta:
+ user_act_annotation = user_node.meta["activation_checkpoint"]
break
for input_node in node._input_nodes.keys():
- if 'activation_checkpoint' in input_node.meta:
- input_act_annotation = input_node.meta['activation_checkpoint']
+ if "activation_checkpoint" in input_node.meta:
+ input_act_annotation = input_node.meta["activation_checkpoint"]
break
if user_act_annotation == input_act_annotation and user_act_annotation != -1:
- node.meta['activation_checkpoint'] = user_act_annotation
+ node.meta["activation_checkpoint"] = user_act_annotation
return gm
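Both `_shape_consistency_apply` and `_comm_spec_apply` follow the same graph-surgery recipe: create a `call_function` node next to an existing node, then rewrite the consumer's `args`/`kwargs` so the new node sits on the edge between them. A runnable miniature of that recipe; `convert` below is a stand-in for `runtime_apply`, not the real helper:

```python
import torch
import torch.fx as fx

def convert(x):
    # placeholder for a shape-consistency conversion such as runtime_apply
    return x * 2

class Tiny(torch.nn.Module):
    def forward(self, x):
        y = torch.relu(x)
        return y + 1

gm = fx.symbolic_trace(Tiny())
graph = gm.graph
for node in list(graph.nodes):
    if node.target == torch.relu:
        user = next(iter(node.users))  # the `add` node consuming relu's output
        with graph.inserting_before(user):
            conv = graph.create_node("call_function", convert, args=(node,))
        # reroute the user's args from the relu node to the inserted node
        user.args = tuple(conv if a is node else a for a in user.args)

graph.lint()
gm.recompile()
print(gm(torch.ones(2)))  # tensor([3., 3.]): relu output doubled, then +1
```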
diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
index 08af846b221db60b3950a0cf285238f616b17711..65c3d8e0cbeb0cca8b0100d407aba7d3b3443333 100644
--- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -1,19 +1,12 @@
import operator
-from copy import deepcopy
from typing import Dict, List, Union
import torch
-from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai._analyzer.fx.node_util import MetaInfo
from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
- CommType,
- OperationDataType,
- ShardingStrategy,
-)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.comm_spec import _all_reduce
@@ -25,11 +18,13 @@ from .constants import SHAPE_ARGUMENT_OPS
shape_consistency_manager = ShapeConsistencyManager()
-def size_processing(size: Union[int, torch.Size],
- dim_partition_dict: Dict[int, List[int]],
- device_mesh_info: Dict[int, int],
- target_dim: int = None,
- node_name: str = None):
+def size_processing(
+ size: Union[int, torch.Size],
+ dim_partition_dict: Dict[int, List[int]],
+ device_mesh_info: Dict[int, int],
+ target_dim: int = None,
+ node_name: str = None,
+):
"""
     This method will be invoked during runtime to convert the size node value depending on the distributed information.
"""
@@ -54,8 +49,9 @@ def size_processing(size: Union[int, torch.Size],
return size
-def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
- strategies_constructor: StrategiesConstructor):
+def solution_annotation_pass(
+ gm: torch.fx.GraphModule, solution: List[int], strategies_constructor: StrategiesConstructor
+):
"""
This method is used to stick the solution strategy to the nodes and add the information
     required at runtime into the graph as placeholder nodes.
@@ -70,14 +66,15 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
strategies_vector = node.strategies_vector
# stick the solution strategy to the corresponding node
- setattr(node, 'best_strategy', strategies_vector[strategy_index])
- setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
+ setattr(node, "best_strategy", strategies_vector[strategy_index])
+ setattr(node, "sharding_spec", strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
- str(node))
+ str(node)
+ )
# attach the corresponding metainfo if node has the attribute `strategies_info`
- if hasattr(node, 'strategies_info'):
- setattr(node, 'best_strategy_info', node.strategies_info[strategy_index])
+ if hasattr(node, "strategies_info"):
+ setattr(node, "best_strategy_info", node.strategies_info[strategy_index])
# the dict to get input sharding specs of user node
sharding_spec_convert_dict = {}
@@ -92,15 +89,15 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
target_sharding_specs.append(target_sharding_spec)
sharding_spec_convert_dict[index] = target_sharding_specs
- setattr(node, 'target_sharding_specs', target_sharding_specs)
+ setattr(node, "target_sharding_specs", target_sharding_specs)
             # the get_attr node strategy is a kind of pending strategy, which means we will change it
             # to the same strategy as the user node.
- if node.op == 'get_attr':
- assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
+ if node.op == "get_attr":
+ assert len(target_sharding_specs) == 1, f"sharing weight is not supported in current version."
target_node = node.strategies_vector.successor_nodes[0]
node_name = str(node)
- if target_node.op == 'call_function' and target_node.target in RESHAPE_FUNC_OP:
+ if target_node.op == "call_function" and target_node.target in RESHAPE_FUNC_OP:
node_name = str(target_node)
target_node = target_node.strategies_vector.successor_nodes[0]
user_strategy = target_node.best_strategy
@@ -122,11 +119,11 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
# add above dicts into graph
for node in nodes:
- if node.op != 'placeholder':
+ if node.op != "placeholder":
with mod_graph.inserting_before(node):
- input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
- origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
- comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict')
+ input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict")
+ origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict")
+ comm_actions_dict_node = mod_graph.create_node("placeholder", target="comm_actions_dict")
break
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
@@ -144,11 +141,11 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
# DeviceMesh information instructs the scaling of the size value
device_mesh_info = {}
- for dim, dim_size in enumerate(device_mesh.mesh_shape):
+ for dim, dim_size in enumerate(device_mesh.shape):
device_mesh_info[dim] = dim_size
def _extract_target_dim(node):
- '''
+ """
A helper function to extract the target dimension from size node.
There are two usages of torch.Tensor.size:
1. tensor.size()
@@ -156,7 +153,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
     If a target_dim is assigned, then the output will be of type int instead of torch.Size.
     Otherwise, the output will be of type torch.Size and this function will return None.
- '''
+ """
target_dim = None
if len(node.args) > 1:
target_dim = node.args[1]
@@ -165,19 +162,21 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
return target_dim
def _post_processing(node, size_processing_node):
- '''
+ """
This function is used to process the dependency between the size node and its users after
     inserting the size_processing_node.
- '''
- # store original node and processing node pair in node_pairs dictioanry
+ """
+ # store original node and processing node pair in node_pairs dictionary
# It will be used to replace the original node with processing node in slice object
node_pairs[node] = size_processing_node
size_processing_node._meta_data = node._meta_data
- if hasattr(node.meta['info'], 'activation_checkpoint'):
- MetaInfo(size_processing_node,
- mod_dir=node.meta['info'].mod_dir,
- activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))
+ if hasattr(node.meta["info"], "activation_checkpoint"):
+ MetaInfo(
+ size_processing_node,
+ mod_dir=node.meta["info"].mod_dir,
+ activation_checkpoint=tuple(node.meta["info"].activation_checkpoint),
+ )
user_list = list(node.users.keys())
for user in user_list:
@@ -196,10 +195,10 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
user.kwargs = new_kwargs
def _update_slice_object_args(slice_object):
- '''
+ """
This function is used to update the slice object argument list.
     If the slice object contains a Node argument, then the size node will be replaced with
     the corresponding size processing node.
- '''
+ """
if isinstance(slice_object, slice):
start = slice_object.start
stop = slice_object.stop
@@ -220,8 +219,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}")
for node in nodes:
-
- if node.op == 'call_method' and node.target == 'size':
+ if node.op == "call_method" and node.target == "size":
             # extract useful information from the size node
             # dim_partition_dict instructs on which dimension
             # the size value should be enlarged.
@@ -232,14 +230,14 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
# insert size_processing node
with mod_graph.inserting_after(node):
- size_processing_node = mod_graph.create_node('call_function',
- size_processing,
- args=(node, dim_partition_dict, device_mesh_info,
- target_dim, node.name))
+ size_processing_node = mod_graph.create_node(
+ "call_function",
+ size_processing,
+ args=(node, dim_partition_dict, device_mesh_info, target_dim, node.name),
+ )
_post_processing(node, size_processing_node)
- if node.op == 'call_function' and node.target == operator.getitem:
-
+ if node.op == "call_function" and node.target == operator.getitem:
getitem_index = node.args[1]
             # slice object is quite special in torch.fx graph,
             # On one side, we treat the slice object the same as an int,
@@ -287,18 +285,19 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
nodes = tuple(mod_graph.nodes)
def _extract_info_from_sharding_spec(sharding_spec):
- '''
+ """
This function is used to extract the dim_partition_dict and device_mesh from
sharding spec instance or a list of sharding spec.
- '''
+ """
if isinstance(sharding_spec, ShardingSpec):
dim_partition_dict = sharding_spec.dim_partition_dict
device_mesh = sharding_spec.device_mesh
return dim_partition_dict, device_mesh
if sharding_spec is None:
return None, None
- assert isinstance(sharding_spec,
- (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
+ assert isinstance(
+ sharding_spec, (tuple, list)
+ ), "sharding_spec should be type of ShardingSpec, tuple, list or None"
device_mesh = sharding_spec[0].device_mesh
dim_partition_dict = []
@@ -322,8 +321,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
else:
new_args.append(arg)
else:
- assert isinstance(arg,
- (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
+ assert isinstance(
+ arg, (int, tuple, list)
+ ), "The argument in view node should be either type of Node or int."
if isinstance(arg, (tuple, list)):
new_args.extend(arg)
else:
@@ -332,7 +332,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node):
new_args = _process_node_arguments(node)
- if node.op == 'call_method':
+ if node.op == "call_method":
args_to_process = list(new_args[1:])
else:
args_to_process = list(new_args)
@@ -350,7 +350,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
args_to_process = tuple(args_to_process)
- if node.op == 'call_method':
+ if node.op == "call_method":
new_args = (new_args[0],) + args_to_process
else:
new_args = args_to_process
@@ -358,9 +358,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
node.args = new_args
def _filter_node_with_shape_args(node):
- if node.op == 'call_method':
+ if node.op == "call_method":
target = getattr(node.args[0]._meta_data.__class__, node.target)
- elif node.op == 'call_function':
+ elif node.op == "call_function":
target = node.target
else:
target = None
@@ -371,7 +371,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
for node in nodes:
         # skip the placeholder nodes added in the solution_annotation_pass
- if not hasattr(node, 'sharding_spec'):
+ if not hasattr(node, "sharding_spec"):
continue
output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec)
@@ -388,19 +388,25 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
"""
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
- # This stream is created for overlaping the communication and computation.
+ # This stream is created for overlapping the communication and computation.
reduction_stream = torch.cuda.Stream()
def _add_hook_for_grad_communication(node, param, name=None):
-
comm_actions = node.best_strategy.communication_actions
def _filter_param_to_hook(node, op_data, comm_action, name):
-
- if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == name and comm_action.comm_type == CommType.HOOK:
+ if (
+ node.op == "call_module"
+ and op_data.type == OperationDataType.PARAM
+ and op_data.name == name
+ and comm_action.comm_type == CommType.HOOK
+ ):
return True
- if node.op == 'get_attr' and isinstance(
- node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
+ if (
+ node.op == "get_attr"
+ and isinstance(node._meta_data, torch.nn.parameter.Parameter)
+ and comm_action.comm_type == CommType.HOOK
+ ):
return True
return False
@@ -410,7 +416,6 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
if _filter_param_to_hook(node, operation_data, comm_action, name=name):
def wrapper(param, comm_spec, stream, overlap):
-
def hook_fn(grad):
if overlap:
with torch.cuda.stream(stream):
@@ -426,22 +431,26 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
# apply the sharding spec of parameters
if target_sharding_spec.dim_partition_dict != {}:
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
- setattr(param, 'sharding_spec', origin_sharding_spec)
+ setattr(param, "sharding_spec", origin_sharding_spec)
             # TODO: build a ColoParameter class to manage the distributed parameters
# we could use .data here, because all the operations just happen before the real training
# loop, so we don't need to track these operations in the autograd graph.
param = torch.nn.Parameter(
- shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
- target_sharding_spec).detach().clone())
+ shape_consistency_manager.apply_for_autoparallel_runtime(
+ param.data, param.sharding_spec, target_sharding_spec
+ )
+ .detach()
+ .clone()
+ )
return param
for node in nodes:
- if node.op == 'call_module':
+ if node.op == "call_module":
target_module = node.graph.owning_module.get_submodule(node.target)
# TODO: we need to do more actions to take care of the shared parameters.
- if hasattr(target_module, 'processed') and target_module.processed:
+ if hasattr(target_module, "processed") and target_module.processed:
continue
- setattr(target_module, 'processed', True)
+ setattr(target_module, "processed", True)
for name, param in target_module.named_parameters():
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
param = _shard_param(param, target_sharding_spec)
@@ -453,7 +462,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
# apply the sharding spec of buffers
for name, buffer in target_module.named_buffers():
origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
- setattr(buffer, 'sharding_spec', origin_sharding_spec)
+ setattr(buffer, "sharding_spec", origin_sharding_spec)
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec)
sharded_buffer_dict[name] = buffer_sharded
@@ -461,7 +470,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
for name, buffer_sharded in sharded_buffer_dict.items():
setattr(target_module, name, buffer_sharded.detach().clone())
- if node.op == 'get_attr':
+ if node.op == "get_attr":
root = node.graph.owning_module
atoms = node.target.split(".")
attr_len = len(atoms)
@@ -488,16 +497,18 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
"""
     Replace the original kernel with a kernel that performs implicit communication internally.
"""
- pass
-def runtime_preparation_pass(gm: torch.fx.GraphModule,
- solution: List[int],
- device_mesh: DeviceMesh,
- strategies_constructor: StrategiesConstructor,
- overlap=False):
- gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotatation_pass(
- gm, solution, strategies_constructor)
+def runtime_preparation_pass(
+ gm: torch.fx.GraphModule,
+ solution: List[int],
+ device_mesh: DeviceMesh,
+ strategies_constructor: StrategiesConstructor,
+ overlap=False,
+):
+ gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotation_pass(
+ gm, solution, strategies_constructor
+ )
gm = size_value_converting_pass(gm, device_mesh)
gm = node_args_converting_pass(gm, device_mesh)
     # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass is completed.
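`size_processing` above exists because, at runtime, `tensor.size()` on a shard returns the *local* shape, so every dimension recorded in the `dim_partition_dict` has to be multiplied by the sizes of the mesh dimensions it is sharded over to recover the global size. A hedged sketch of that arithmetic; the shapes and mesh sizes below are made up for illustration:

```python
import torch
from typing import Dict, List, Union

def scale_size(size: Union[int, torch.Size],
               dim_partition_dict: Dict[int, List[int]],
               device_mesh_info: Dict[int, int],
               target_dim: int = None):
    if target_dim is not None:
        # tensor.size(dim) form: `size` is a plain int
        factor = 1
        for mesh_dim in dim_partition_dict.get(target_dim, []):
            factor *= device_mesh_info[mesh_dim]
        return size * factor
    # tensor.size() form: `size` is a torch.Size
    global_size = list(size)
    for dim, mesh_dims in dim_partition_dict.items():
        for mesh_dim in mesh_dims:
            global_size[dim] *= device_mesh_info[mesh_dim]
    return torch.Size(global_size)

# local shard is (8, 16); dim 0 is sharded across mesh dims 0 and 1 of a 2x4 mesh
print(scale_size(torch.Size((8, 16)), {0: [0, 1]}, {0: 2, 1: 4}))  # (64, 16)
print(scale_size(8, {0: [0, 1]}, {0: 2, 1: 4}, target_dim=0))      # 64
```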
diff --git a/colossalai/auto_parallel/tensor_shard/constants.py b/colossalai/auto_parallel/tensor_shard/constants.py
index 99c1249340602daee1a1314f102bc600eae6667d..e9c2c8664a61d04a25dd76da398c900a863bd828 100644
--- a/colossalai/auto_parallel/tensor_shard/constants.py
+++ b/colossalai/auto_parallel/tensor_shard/constants.py
@@ -3,9 +3,22 @@ import operator
import torch
__all__ = [
- 'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
- 'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP',
- 'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP', 'INFINITY_COST'
+ "ELEMENTWISE_MODULE_OP",
+ "ELEMENTWISE_FUNC_OP",
+ "RESHAPE_FUNC_OP",
+ "CONV_MODULE_OP",
+ "CONV_FUNC_OP",
+ "LINEAR_MODULE_OP",
+ "LINEAR_FUNC_OP",
+ "BATCHNORM_MODULE_OP",
+ "POOL_MODULE_OP",
+ "NON_PARAM_FUNC_OP",
+ "BCAST_FUNC_OP",
+ "EMBEDDING_MODULE_OP",
+ "LAYERNORM_MODULE_OP",
+ "ELEMENTWISE_METHOD_OP",
+ "RESHAPE_METHOD_OP",
+ "INFINITY_COST",
]
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
@@ -18,13 +31,13 @@ ELEMENTWISE_FUNC_OP = [
torch.nn.functional.relu,
torch.nn.functional.dropout,
# softmax should not be here
- torch.nn.functional.softmax
+ torch.nn.functional.softmax,
]
ELEMENTWISE_METHOD_OP = [
torch.Tensor.to,
torch.Tensor.type,
     # TODO: contiguous may need some extra processing.
- torch.Tensor.contiguous
+ torch.Tensor.contiguous,
]
RESHAPE_FUNC_OP = [
torch.flatten,
@@ -42,15 +55,36 @@ RESHAPE_METHOD_OP = [
torch.Tensor.transpose,
]
BCAST_FUNC_OP = [
- torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
- operator.mul, operator.floordiv, operator.truediv, torch.matmul, operator.pow, torch.pow
+ torch.add,
+ torch.sub,
+ torch.mul,
+ torch.div,
+ torch.floor_divide,
+ torch.true_divide,
+ operator.add,
+ operator.sub,
+ operator.mul,
+ operator.floordiv,
+ operator.truediv,
+ torch.matmul,
+ operator.pow,
+ torch.pow,
]
CONV_MODULE_OP = [
- torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
- torch.nn.ConvTranspose3d
+ torch.nn.Conv1d,
+ torch.nn.Conv2d,
+ torch.nn.Conv3d,
+ torch.nn.ConvTranspose1d,
+ torch.nn.ConvTranspose2d,
+ torch.nn.ConvTranspose3d,
]
CONV_FUNC_OP = [
- torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d
+ torch.conv1d,
+ torch.conv2d,
+ torch.conv3d,
+ torch.conv_transpose1d,
+ torch.conv_transpose2d,
+ torch.conv_transpose3d,
]
EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding]
LINEAR_MODULE_OP = [torch.nn.Linear]
@@ -85,7 +119,7 @@ NON_PARAM_FUNC_OP = [
operator.floordiv,
operator.truediv,
# softmax should not be here
- torch.nn.functional.softmax
+ torch.nn.functional.softmax,
]
INFINITY_COST = 1e13
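These constant lists act as membership tables: a node's `target` is checked against them to pick the matching handler, as in the `@operator_registry.register(BCAST_FUNC_OP)` decoration further below. A toy version of that lookup, using a made-up two-entry table:

```python
import operator
import torch

BCAST_FUNC_OP = [torch.add, operator.add]
RESHAPE_FUNC_OP = [torch.flatten]

def classify(target):
    # decide which kind of handler a node's target would map to
    if target in BCAST_FUNC_OP:
        return "broadcast-elementwise"
    if target in RESHAPE_FUNC_OP:
        return "reshape"
    return "unhandled"

print(classify(torch.add))      # broadcast-elementwise
print(classify(torch.flatten))  # reshape
print(classify(torch.matmul))   # unhandled
```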
diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py
index b406ca6fb7e0fd28a9a6d3e98365b093f73f7171..d82f0ef53f6605ccc3c0b8537dfe49d5526e6cc6 100644
--- a/colossalai/auto_parallel/tensor_shard/initialize.py
+++ b/colossalai/auto_parallel/tensor_shard/initialize.py
@@ -3,7 +3,6 @@ from typing import Dict, List, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
-from torch.fx import GraphModule
from torch.fx.graph import Graph
from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
@@ -14,27 +13,32 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pas
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
-from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
class ModuleWrapper(nn.Module):
- '''
+ """
This class is used to wrap the original module, and add the sharding_spec_dict, origin_spec_dict, comm_actions_dict
into the forward function.
- '''
-
- def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
- origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]):
- '''
+ """
+
+ def __init__(
+ self,
+ module: ColoGraphModule,
+ sharding_spec_dict: Dict[int, List[ShardingSpec]],
+ origin_spec_dict: Dict[int, ShardingSpec],
+ comm_actions_dict: Dict[int, Dict[str, CommAction]],
+ ):
+ """
Args:
module: the original module
sharding_spec_dict: The sharding_spec_dict is used to record the target sharding specs of each tensor required in user node.
origin_spec_dict: The origin_spec_dict is used to record the original sharding spec of each tensor.
comm_actions_dict: The comm_actions_dict is used to record the communication actions of each tensor.
- '''
+ """
super(ModuleWrapper, self).__init__()
self.module = module
self.sharding_spec_dict = sharding_spec_dict
@@ -42,67 +46,68 @@ class ModuleWrapper(nn.Module):
self.comm_actions_dict = comm_actions_dict
def forward(self, *args, **kwargs):
- return self.module(*args,
- sharding_spec_convert_dict=self.sharding_spec_dict,
- origin_node_sharding_spec_dict=self.origin_spec_dict,
- comm_actions_dict=self.comm_actions_dict,
- **kwargs)
+ return self.module(
+ *args,
+ sharding_spec_convert_dict=self.sharding_spec_dict,
+ origin_node_sharding_spec_dict=self.origin_spec_dict,
+ comm_actions_dict=self.comm_actions_dict,
+ **kwargs,
+ )
def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader, data_process_func: callable):
- '''
+ """
     This method is used to extract the meta_args from the dataloader using the given data_process_func.
- '''
+ """
# TODO: implement this function
- pass
def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]):
- '''
+ """
This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape
from the alpha_beta_dict. These two values will be used to estimate the communication cost.
- '''
+ """
# TODO: implement this function
- pass
-def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str,
- shard_option: str):
- '''
+def build_strategy_constructor(
+ graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str, shard_option: str
+):
+ """
This method is used to build the strategy_constructor for the given graph.
After this method, each node in the graph will have a strategies_vector which
is constructed by the related node handler.
- '''
- if solver_preference == 'standard':
+ """
+ if solver_preference == "standard":
solver_preference = SolverPerference.STANDARD
- elif solver_preference == 'tp':
+ elif solver_preference == "tp":
solver_preference = SolverPerference.TP
- elif solver_preference == 'dp':
+ elif solver_preference == "dp":
solver_preference = SolverPerference.DP
else:
- raise ValueError(f'Invalid solver_preference: {solver_preference}')
+ raise ValueError(f"Invalid solver_preference: {solver_preference}")
- if dataloader_option == 'replicated':
+ if dataloader_option == "replicated":
dataloader_option = DataloaderOption.REPLICATED
- elif dataloader_option == 'distributed':
+ elif dataloader_option == "distributed":
dataloader_option = DataloaderOption.DISTRIBUTED
else:
- raise ValueError(f'Invalid dataloader_option: {dataloader_option}')
+ raise ValueError(f"Invalid dataloader_option: {dataloader_option}")
- if shard_option == 'standard':
+ if shard_option == "standard":
shard_option = ShardOption.STANDARD
- elif shard_option == 'shard':
+ elif shard_option == "shard":
shard_option = ShardOption.SHARD
- elif shard_option == 'shard_last_axis':
+ elif shard_option == "shard_last_axis":
shard_option = ShardOption.SHARD_LAST_AXIS
- elif shard_option == 'full_shard':
+ elif shard_option == "full_shard":
shard_option = ShardOption.FULL_SHARD
else:
- raise ValueError(f'Invalid shard_option: {shard_option}')
+ raise ValueError(f"Invalid shard_option: {shard_option}")
- solver_options = SolverOptions(solver_perference=solver_preference,
- dataloader_option=dataloader_option,
- shard_option=shard_option)
+ solver_options = SolverOptions(
+ solver_perference=solver_preference, dataloader_option=dataloader_option, shard_option=shard_option
+ )
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
@@ -110,10 +115,10 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_pre
def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
- '''
+ """
     This method is used to find the best solution for the given graph.
     The solution is a list of integers; each integer represents the best strategy index of the corresponding node.
- '''
+ """
     # temporarily we use all nodes as the liveness list; we count the backward memory cost together with
     # the forward memory cost in the node memory cost, and no activation checkpointing is used in this phase.
# graph_analyser = GraphAnalyser(gm)
@@ -127,23 +132,23 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc
return solution
-def transform_to_sharded_model(gm: ColoGraphModule,
- meta_args: Dict,
- solution: List[int],
- device_mesh: DeviceMesh,
- strategies_constructor: StrategiesConstructor,
- overlap: bool = False):
- '''
+def transform_to_sharded_model(
+ gm: ColoGraphModule,
+ meta_args: Dict,
+ solution: List[int],
+ device_mesh: DeviceMesh,
+ strategies_constructor: StrategiesConstructor,
+ overlap: bool = False,
+):
+ """
     This method is used to transform the original graph into the sharded graph.
     The model parameters will be sharded according to the solution, and the grad hooks
     will be added to the sharded graph using the runtime_preparation_pass.
     The communication nodes will be added into the graph using the runtime_apply_pass.
- '''
- gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm,
- solution,
- device_mesh,
- strategies_constructor,
- overlap=overlap)
+ """
+ gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
+ gm, solution, device_mesh, strategies_constructor, overlap=overlap
+ )
gm = runtime_apply_pass(gm)
shape_prop_pass(gm, *meta_args.values(), sharding_spec_dict, origin_spec_dict, comm_actions_dict)
gm.recompile()
@@ -152,12 +157,14 @@ def transform_to_sharded_model(gm: ColoGraphModule,
return gm, sharding_spec_dicts
-def initialize_device_mesh(world_size: int = -1,
- physical_devices: List[int] = None,
- alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
- logical_mesh_shape: Tuple[int] = None,
- logical_mesh_id: torch.Tensor = None):
- '''
+def initialize_device_mesh(
+ world_size: int = -1,
+ physical_devices: List[int] = None,
+ alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
+ logical_mesh_shape: Tuple[int] = None,
+ logical_mesh_id: torch.Tensor = None,
+):
+ """
This method is used to initialize the device mesh.
Args:
@@ -170,7 +177,7 @@ def initialize_device_mesh(world_size: int = -1,
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
mesh shape.
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
- '''
+ """
# if world_size is not set, use the world size from torch.distributed
if world_size == -1:
world_size = dist.get_world_size()
@@ -201,27 +208,31 @@ def initialize_device_mesh(world_size: int = -1,
# extract alpha and beta values for the chosen logical mesh shape
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id)
- device_mesh = DeviceMesh(physical_mesh_id=physical_mesh,
- logical_mesh_id=logical_mesh_id,
- mesh_alpha=mesh_alpha,
- mesh_beta=mesh_beta,
- init_process_group=True)
+ device_mesh = DeviceMesh(
+ physical_mesh_id=physical_mesh,
+ logical_mesh_id=logical_mesh_id,
+ mesh_alpha=mesh_alpha,
+ mesh_beta=mesh_beta,
+ init_process_group=True,
+ )
return device_mesh
-def initialize_model(model: nn.Module,
- meta_args: Dict[str, torch.Tensor],
- device_mesh: DeviceMesh,
- memory_budget: float = -1.0,
- overlap: bool = False,
- solver_preference: str = 'standard',
- dataloader_option: str = 'replicated',
- shard_option: str = 'standard',
- save_solver_solution: bool = False,
- load_solver_solution: bool = False,
- solution_path: str = None,
- return_solution: bool = False):
- '''
+def initialize_model(
+ model: nn.Module,
+ meta_args: Dict[str, torch.Tensor],
+ device_mesh: DeviceMesh,
+ memory_budget: float = -1.0,
+ overlap: bool = False,
+ solver_preference: str = "standard",
+ dataloader_option: str = "replicated",
+ shard_option: str = "standard",
+ save_solver_solution: bool = False,
+ load_solver_solution: bool = False,
+ solution_path: str = None,
+ return_solution: bool = False,
+):
+ """
     This method is used to initialize the sharded model, which can be used as a normal PyTorch model.
Args:
@@ -246,7 +257,7 @@ def initialize_model(model: nn.Module,
return_solution(optional): if the return_solution is True, the solution will be returned. The returned
             solution can be used to debug or analyze the sharding result. Therefore, we will not just
return a series of integers, but return the best strategies.
- '''
+ """
tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True)
graph = tracer.trace(root=model, meta_args=meta_args)
@@ -256,11 +267,13 @@ def initialize_model(model: nn.Module,
shape_prop_pass(gm, *meta_args.values())
gm.recompile()
- strategies_constructor = build_strategy_constructor(graph,
- device_mesh,
- solver_preference=solver_preference,
- dataloader_option=dataloader_option,
- shard_option=shard_option)
+ strategies_constructor = build_strategy_constructor(
+ graph,
+ device_mesh,
+ solver_preference=solver_preference,
+ dataloader_option=dataloader_option,
+ shard_option=shard_option,
+ )
if load_solver_solution:
solution = torch.load(solution_path)
else:
@@ -268,8 +281,9 @@ def initialize_model(model: nn.Module,
if save_solver_solution:
torch.save(solution, solution_path)
- gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_args, solution, device_mesh, strategies_constructor,
- overlap)
+ gm, sharding_spec_dicts = transform_to_sharded_model(
+ gm, meta_args, solution, device_mesh, strategies_constructor, overlap
+ )
model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)
@@ -277,28 +291,30 @@ def initialize_model(model: nn.Module,
solution_to_return = []
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
for index, node in enumerate(nodes):
- solution_to_return.append(f'{node.name} {node.strategies_vector[solution[index]].name}')
+ solution_to_return.append(f"{node.name} {node.strategies_vector[solution[index]].name}")
return model_to_return, solution_to_return
else:
return model_to_return
-def autoparallelize(model: nn.Module,
- meta_args: Dict[str, torch.Tensor] = None,
- data_loader: torch.utils.data.DataLoader = None,
- data_process_func: callable = None,
- alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
- logical_mesh_shape: Tuple[int] = None,
- logical_mesh_id: torch.Tensor = None,
- solver_preference: str = 'standard',
- dataloader_option: str = 'replicated',
- shard_option: str = 'standard',
- save_solver_solution: bool = False,
- load_solver_solution: bool = False,
- solver_solution_path: str = None,
- return_solution: bool = False,
- memory_budget: float = -1.0):
- '''
+def autoparallelize(
+ model: nn.Module,
+ meta_args: Dict[str, torch.Tensor] = None,
+ data_loader: torch.utils.data.DataLoader = None,
+ data_process_func: callable = None,
+ alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
+ logical_mesh_shape: Tuple[int] = None,
+ logical_mesh_id: torch.Tensor = None,
+ solver_preference: str = "standard",
+ dataloader_option: str = "replicated",
+ shard_option: str = "standard",
+ save_solver_solution: bool = False,
+ load_solver_solution: bool = False,
+ solver_solution_path: str = None,
+ return_solution: bool = False,
+ memory_budget: float = -1.0,
+):
+ """
This method is used to initialize the device mesh, extract the meta_args, and
use them to create a sharded model.
@@ -329,24 +345,26 @@ def autoparallelize(model: nn.Module,
return_solution(optional): if the return_solution is True, the solution will be returned.
         memory_budget(optional): the max CUDA memory that can be used. If the memory budget is -1.0,
             the memory budget is treated as infinite.
- '''
- device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict,
- logical_mesh_shape=logical_mesh_shape,
- logical_mesh_id=logical_mesh_id)
+ """
+ device_mesh = initialize_device_mesh(
+ alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape, logical_mesh_id=logical_mesh_id
+ )
if meta_args is None:
meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
- rst_to_unpack = initialize_model(model,
- meta_args,
- device_mesh,
- solver_preference=solver_preference,
- dataloader_option=dataloader_option,
- shard_option=shard_option,
- save_solver_solution=save_solver_solution,
- load_solver_solution=load_solver_solution,
- solution_path=solver_solution_path,
- return_solution=return_solution,
- memory_budget=memory_budget)
+ rst_to_unpack = initialize_model(
+ model,
+ meta_args,
+ device_mesh,
+ solver_preference=solver_preference,
+ dataloader_option=dataloader_option,
+ shard_option=shard_option,
+ save_solver_solution=save_solver_solution,
+ load_solver_solution=load_solver_solution,
+ solution_path=solver_solution_path,
+ return_solution=return_solution,
+ memory_budget=memory_budget,
+ )
if return_solution:
model, solution = rst_to_unpack
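`ModuleWrapper` above simply closes over the three runtime dicts and re-injects them on every forward call. The same pattern in miniature; the inner module and the `scale` kwarg are illustrative only, not part of the API:

```python
import torch
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, module: nn.Module, **extra_kwargs):
        super().__init__()
        self.module = module
        self.extra_kwargs = extra_kwargs  # fixed kwargs injected on every call

    def forward(self, *args, **kwargs):
        return self.module(*args, **self.extra_kwargs, **kwargs)

class Inner(nn.Module):
    def forward(self, x, scale=1.0):
        return x * scale

wrapped = Wrapper(Inner(), scale=3.0)
print(wrapped(torch.ones(2)))  # tensor([3., 3.])
```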
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
index 9903ca54e52cb70559cce2c68169c84ca08bef9c..aa2e5e9c40c0aa9dbe0b360f6907744ed348fa31 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
@@ -25,11 +25,33 @@ from .view_handler import ViewHandler
from .where_handler import WhereHandler
__all__ = [
- 'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
- 'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
- 'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
- 'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
- 'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
- 'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'TransposeHandler',
- 'SplitHandler'
+ "LinearFunctionHandler",
+ "LinearModuleHandler",
+ "BMMFunctionHandler",
+ "AddBMMFunctionHandler",
+ "LayerNormModuleHandler",
+ "BatchNormModuleHandler",
+ "ConvModuleHandler",
+ "ConvFunctionHandler",
+ "UnaryElementwiseHandler",
+ "DefaultReshapeHandler",
+ "PlaceholderHandler",
+ "OutputHandler",
+ "WhereHandler",
+ "NormPoolingHandler",
+ "BinaryElementwiseHandler",
+ "MatMulHandler",
+ "operator_registry",
+ "ADDMMFunctionHandler",
+ "GetItemHandler",
+ "GetattrHandler",
+ "ViewHandler",
+ "PermuteHandler",
+ "TensorConstructorHandler",
+ "EmbeddingModuleHandler",
+ "EmbeddingFunctionHandler",
+ "SumHandler",
+ "SoftmaxHandler",
+ "TransposeHandler",
+ "SplitHandler",
]
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py
index da0d199c5e05b37340cb4ddcfee0b52a9102fadf..47c654d6aa436fb145420aaa666e8bba2eeaeaeb 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py
@@ -2,15 +2,13 @@ from typing import Dict, List, Union
import torch
-from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
-
-from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
-__all__ = ['ADDMMFunctionHandler']
+__all__ = ["ADDMMFunctionHandler"]
@operator_registry.register(torch.addmm)
@@ -30,25 +28,26 @@ class ADDMMFunctionHandler(NodeHandler):
return data_type
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
-
# input operand
input_data = self.node.args[1]._meta_data
- physical_input_operand = OperationData(name=str(self.node.args[1]),
- type=self._infer_op_data_type(input_data),
- data=input_data)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[1]), type=self._infer_op_data_type(input_data), data=input_data
+ )
# other operand
other_data = self.node.args[2]._meta_data
- physical_other_operand = OperationData(name=str(self.node.args[2]),
- type=self._infer_op_data_type(other_data),
- data=other_data)
+ physical_other_operand = OperationData(
+ name=str(self.node.args[2]), type=self._infer_op_data_type(other_data), data=other_data
+ )
         # bias logical shape
bias_logical_shape = self.node._meta_data.shape
bias_data = self.node.args[0]._meta_data
- physical_bias_operand = OperationData(name=str(self.node.args[0]),
- type=self._infer_op_data_type(bias_data),
- data=bias_data,
- logical_shape=bias_logical_shape)
+ physical_bias_operand = OperationData(
+ name=str(self.node.args[0]),
+ type=self._infer_op_data_type(bias_data),
+ data=bias_data,
+ logical_shape=bias_logical_shape,
+ )
# output
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
@@ -57,7 +56,7 @@ class ADDMMFunctionHandler(NodeHandler):
"input": physical_input_operand,
"other": physical_other_operand,
"output": physical_output,
- 'bias': physical_bias_operand
+ "bias": physical_bias_operand,
}
return mapping
@@ -66,26 +65,27 @@ class ADDMMFunctionHandler(NodeHandler):
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
- LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='addmm'))
+ LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="addmm")
+ )
return generators
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
# convert bias from its logical sharding spec to its physical sharding spec
op_data_mapping = self.get_operation_data_mapping()
- bias_op_data = op_data_mapping['bias']
+ bias_op_data = op_data_mapping["bias"]
bias_physical_shape = bias_op_data.data.shape
bias_logical_shape = bias_op_data.logical_shape
bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
- bias_sharding_spec, bias_logical_shape, bias_physical_shape)
+ bias_sharding_spec, bias_logical_shape, bias_physical_shape
+ )
strategy.sharding_specs[bias_op_data] = bias_sharding_spec
if len(removed_dims) > 0:
- comm_action = comm_actions_for_oprands(node=self.node,
- removed_dims=removed_dims,
- op_data=bias_op_data,
- sharding_spec=bias_sharding_spec)
+ comm_action = comm_actions_for_oprands(
+ node=self.node, removed_dims=removed_dims, op_data=bias_op_data, sharding_spec=bias_sharding_spec
+ )
strategy.communication_actions[bias_op_data] = comm_action
return strategy
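The handler above records the matmul *output* shape as the bias's logical shape because `torch.addmm(bias, mat1, mat2)` broadcasts the bias up to that shape; `recover_sharding_spec_for_broadcast_shape` then maps the sharding back onto the smaller physical bias. A quick standalone check of the shape relationship (sizes are arbitrary):

```python
import torch

bias = torch.zeros(4)          # physical shape: (4,)
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 4)
out = torch.addmm(bias, mat1, mat2)
print(out.shape)               # torch.Size([2, 4]) -> the bias's logical shape
```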
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py
index cb1bb36b78796db3d5656213518376f8f365dce0..df4b1d6cef3fcc3473d692e208728e21f256588b 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py
@@ -2,12 +2,12 @@ from typing import Dict, List
import torch
-from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
-from .node_handler import MetaInfoModuleHandler, ModuleHandler
+from ..sharding_strategy import OperationData, OperationDataType
+from .node_handler import MetaInfoModuleHandler
from .registry import operator_registry
from .strategy import BatchNormStrategyGenerator, StrategyGenerator
-__all__ = ['BatchNormModuleHandler']
+__all__ = ["BatchNormModuleHandler"]
@operator_registry.register(torch.nn.BatchNorm1d)
@@ -27,30 +27,37 @@ class BatchNormModuleHandler(MetaInfoModuleHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
         # the strategies will be transformed back to their original shapes in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
- physical_other_operand = OperationData(name="weight",
- type=OperationDataType.PARAM,
- data=self.named_parameters['weight'],
- logical_shape=self.named_parameters['weight'].shape)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
+ physical_other_operand = OperationData(
+ name="weight",
+ type=OperationDataType.PARAM,
+ data=self.named_parameters["weight"],
+ logical_shape=self.named_parameters["weight"].shape,
+ )
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
- physical_running_mean_operand = OperationData(name="running_mean",
- type=OperationDataType.BUFFER,
- data=self.named_buffers['running_mean'],
- logical_shape=self.named_buffers['running_mean'].shape)
+ physical_running_mean_operand = OperationData(
+ name="running_mean",
+ type=OperationDataType.BUFFER,
+ data=self.named_buffers["running_mean"],
+ logical_shape=self.named_buffers["running_mean"].shape,
+ )
- physical_running_var_operand = OperationData(name="running_var",
- type=OperationDataType.BUFFER,
- data=self.named_buffers['running_var'],
- logical_shape=self.named_buffers['running_var'].shape)
+ physical_running_var_operand = OperationData(
+ name="running_var",
+ type=OperationDataType.BUFFER,
+ data=self.named_buffers["running_var"],
+ logical_shape=self.named_buffers["running_var"].shape,
+ )
physical_num_batches_tracked_operand = OperationData(
name="num_batches_tracked",
type=OperationDataType.BUFFER,
- data=self.named_buffers['num_batches_tracked'],
- logical_shape=self.named_buffers['num_batches_tracked'].shape)
+ data=self.named_buffers["num_batches_tracked"],
+ logical_shape=self.named_buffers["num_batches_tracked"].shape,
+ )
mapping = {
"input": physical_input_operand,
@@ -58,12 +65,12 @@ class BatchNormModuleHandler(MetaInfoModuleHandler):
"output": physical_output,
"running_mean": physical_running_mean_operand,
"running_var": physical_running_var_operand,
- "num_batches_tracked": physical_num_batches_tracked_operand
+ "num_batches_tracked": physical_num_batches_tracked_operand,
}
- if self.named_parameters['bias'] is not None:
- physical_bias_operand = OperationData(name="bias",
- type=OperationDataType.PARAM,
- data=self.named_parameters['bias'])
- mapping['bias'] = physical_bias_operand
+ if self.named_parameters["bias"] is not None:
+ physical_bias_operand = OperationData(
+ name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"]
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
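The operand mapping above relies on the exact parameter and buffer names a BatchNorm module exposes; they can be verified directly:

```python
import torch

bn = torch.nn.BatchNorm2d(8)
print(sorted(name for name, _ in bn.named_parameters()))
# ['bias', 'weight']
print(sorted(name for name, _ in bn.named_buffers()))
# ['num_batches_tracked', 'running_mean', 'running_var']
```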
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
index db8f0b54ddeeb1c5250951f0c9e8bfef364eb16d..f8c137348353f760aee338b6a94e5df97dd90699 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
@@ -4,15 +4,14 @@ import torch
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy
-from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from ..constants import BCAST_FUNC_OP
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
-from .node_handler import MetaInfoNodeHandler, NodeHandler
+from .node_handler import MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator
-__all__ = ['BinaryElementwiseHandler']
+__all__ = ["BinaryElementwiseHandler"]
@operator_registry.register(BCAST_FUNC_OP)
@@ -38,7 +37,7 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
# The meta_data of a Node-type argument may also be a non-tensor object.
if not isinstance(meta_data, torch.Tensor):
assert isinstance(meta_data, (int, float))
- meta_data = torch.Tensor([meta_data]).to('meta')
+ meta_data = torch.Tensor([meta_data]).to("meta")
non_tensor = True
else:
@@ -46,7 +45,7 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
# but we can treat it as meta data
# as it won't affect the strategy generation
assert isinstance(self.node.args[idx], (int, float))
- meta_data = torch.Tensor([self.node.args[idx]]).to('meta')
+ meta_data = torch.Tensor([self.node.args[idx]]).to("meta")
non_tensor = True
return meta_data, non_tensor
@@ -58,24 +57,27 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
# and filter the non-tensor op_data in post_process.
self.non_tensor_list = []
- input_op_data = OperationData(name=str(self.node.args[0]),
- type=_get_op_data_type(input_meta_data),
- data=input_meta_data,
- logical_shape=bcast_shape)
- other_op_data = OperationData(name=str(self.node.args[1]),
- type=_get_op_data_type(other_meta_data),
- data=other_meta_data,
- logical_shape=bcast_shape)
- output_op_data = OperationData(name=str(self.node),
- type=OperationDataType.OUTPUT,
- data=output_meta_data,
- logical_shape=bcast_shape)
+ input_op_data = OperationData(
+ name=str(self.node.args[0]),
+ type=_get_op_data_type(input_meta_data),
+ data=input_meta_data,
+ logical_shape=bcast_shape,
+ )
+ other_op_data = OperationData(
+ name=str(self.node.args[1]),
+ type=_get_op_data_type(other_meta_data),
+ data=other_meta_data,
+ logical_shape=bcast_shape,
+ )
+ output_op_data = OperationData(
+ name=str(self.node), type=OperationDataType.OUTPUT, data=output_meta_data, logical_shape=bcast_shape
+ )
if non_tensor_input:
self.non_tensor_list.append(input_op_data)
if non_tensor_other:
self.non_tensor_list.append(other_op_data)
- mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
+ mapping = {"input": input_op_data, "other": other_op_data, "output": output_op_data}
return mapping
def get_strategy_generator(self) -> List[StrategyGenerator]:
@@ -100,14 +102,14 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
logical_shape = op_data.logical_shape
sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
- sharding_spec, logical_shape, physical_shape)
+ sharding_spec, logical_shape, physical_shape
+ )
strategy.sharding_specs[op_data] = sharding_spec
if len(removed_dims) > 0:
- comm_action = comm_actions_for_oprands(node=self.node,
- removed_dims=removed_dims,
- op_data=op_data,
- sharding_spec=sharding_spec)
+ comm_action = comm_actions_for_oprands(
+ node=self.node, removed_dims=removed_dims, op_data=op_data, sharding_spec=sharding_spec
+ )
strategy.communication_actions[op_data] = comm_action
return strategy
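
# A short sketch of the scalar-wrapping rule used above: non-tensor operands
# (int/float) are lifted to 1-element meta tensors so every operand can go
# through the same strategy generation; the sample value is made up.
import torch

def as_meta_operand(value):
    if not isinstance(value, torch.Tensor):
        assert isinstance(value, (int, float))
        return torch.Tensor([value]).to("meta"), True  # non_tensor flag
    return value, False

meta_data, non_tensor = as_meta_operand(3.5)
print(meta_data.shape, meta_data.device, non_tensor)  # torch.Size([1]) meta True
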
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py
index da2b733c9f7afda075c10dc3dd17a0d4f42fbc01..5c22ac7bef117be918a1cfda03ae562e48741e19 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py
@@ -2,15 +2,13 @@ from typing import Dict, List, Union
import torch
-from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
-
-from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator
-__all__ = ['BMMFunctionHandler', 'AddBMMFunctionHandler']
+__all__ = ["BMMFunctionHandler", "AddBMMFunctionHandler"]
def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None):
@@ -19,14 +17,14 @@ def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None):
node handler to reduce code redundancy.
"""
# input operand
- physical_input_operand = OperationData(name=str(node.args[input_idx]),
- type=OperationDataType.ARG,
- data=node.args[input_idx]._meta_data)
+ physical_input_operand = OperationData(
+ name=str(node.args[input_idx]), type=OperationDataType.ARG, data=node.args[input_idx]._meta_data
+ )
# other operand
- physical_other_operand = OperationData(name=str(node.args[other_idx]),
- type=OperationDataType.ARG,
- data=node.args[other_idx]._meta_data)
+ physical_other_operand = OperationData(
+ name=str(node.args[other_idx]), type=OperationDataType.ARG, data=node.args[other_idx]._meta_data
+ )
# output
physical_output = OperationData(name=str(node), type=OperationDataType.OUTPUT, data=node._meta_data)
@@ -35,11 +33,13 @@ def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None):
if bias_idx is not None:
# bias physical shape
bias_logical_shape = node._meta_data.shape
- physical_bias_operand = OperationData(name=str(node.args[bias_idx]),
- type=OperationDataType.ARG,
- data=node.args[bias_idx]._meta_data,
- logical_shape=bias_logical_shape)
- mapping['bias'] = physical_bias_operand
+ physical_bias_operand = OperationData(
+ name=str(node.args[bias_idx]),
+ type=OperationDataType.ARG,
+ data=node.args[bias_idx]._meta_data,
+ logical_shape=bias_logical_shape,
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
@@ -91,20 +91,20 @@ class AddBMMFunctionHandler(NodeHandler):
# convert bias from its logical sharding spec to its physical sharding spec
op_data_mapping = self.get_operation_data_mapping()
- if 'bias' in op_data_mapping:
- bias_op_data = op_data_mapping['bias']
+ if "bias" in op_data_mapping:
+ bias_op_data = op_data_mapping["bias"]
bias_physical_shape = bias_op_data.data.shape
bias_logical_shape = bias_op_data.logical_shape
bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
- bias_sharding_spec, bias_logical_shape, bias_physical_shape)
+ bias_sharding_spec, bias_logical_shape, bias_physical_shape
+ )
strategy.sharding_specs[bias_op_data] = bias_sharding_spec
if len(removed_dims) > 0:
- comm_action = comm_actions_for_oprands(node=self.node,
- removed_dims=removed_dims,
- op_data=bias_op_data,
- sharding_spec=bias_sharding_spec)
+ comm_action = comm_actions_for_oprands(
+ node=self.node, removed_dims=removed_dims, op_data=bias_op_data, sharding_spec=bias_sharding_spec
+ )
strategy.communication_actions[bias_op_data] = comm_action
return strategy
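
# A hedged sketch of why the addbmm bias gets the output's shape as its
# logical shape (bias_logical_shape above): the bias is broadcast against the
# (n, p) result of the batched matmul, and sharding is generated on that
# broadcast shape before being recovered for the physical shape. Sizes are
# illustrative.
import torch

n, m, p = 3, 5, 2
bias_physical = torch.Size([p])       # bias as actually passed to addbmm
output_shape = torch.Size([n, p])     # addbmm reduces the batch dimension
bias_logical = torch.broadcast_shapes(bias_physical, output_shape)
assert bias_logical == output_shape
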
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py
index 272b1c85630a8ab15145e701740a44e20d5103b8..fd7c1f837a5a69e2bbf36330e6d0c65bbfdf059b 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py
@@ -3,13 +3,13 @@ from typing import Dict, List
import torch
import torch.nn.functional as F
-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import transpose_partition_dim
-from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
+from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import ConvStrategyGenerator, StrategyGenerator
-__all__ = ['ConvModuleHandler', 'ConvFunctionHandler']
+__all__ = ["ConvModuleHandler", "ConvFunctionHandler"]
@operator_registry.register(torch.nn.Conv1d)
@@ -29,25 +29,29 @@ class ConvModuleHandler(MetaInfoModuleHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to their original shapes in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
logical_shape_for_weight = list(self.named_parameters["weight"].shape)
- logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[
- 1], logical_shape_for_weight[0]
- physical_other_operand = OperationData(name="weight",
- type=OperationDataType.PARAM,
- data=self.named_parameters['weight'],
- logical_shape=torch.Size(logical_shape_for_weight))
+ logical_shape_for_weight[0], logical_shape_for_weight[1] = (
+ logical_shape_for_weight[1],
+ logical_shape_for_weight[0],
+ )
+ physical_other_operand = OperationData(
+ name="weight",
+ type=OperationDataType.PARAM,
+ data=self.named_parameters["weight"],
+ logical_shape=torch.Size(logical_shape_for_weight),
+ )
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if "bias" in self.named_parameters:
- physical_bias_operand = OperationData(name="bias",
- type=OperationDataType.PARAM,
- data=self.named_parameters['bias'])
- mapping['bias'] = physical_bias_operand
+ physical_bias_operand = OperationData(
+ name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"]
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy):
@@ -77,9 +81,9 @@ class ConvFunctionHandler(MetaInfoNodeHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to their original shapes in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
# check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
@@ -88,26 +92,30 @@ class ConvFunctionHandler(MetaInfoNodeHandler):
data_type = OperationDataType.ARG
logical_shape_for_weight = list(self.node.args[1]._meta_data.shape)
- logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[
- 1], logical_shape_for_weight[0]
- physical_other_operand = OperationData(name=str(self.node.args[1]),
- type=data_type,
- data=self.node.args[1]._meta_data,
- logical_shape=torch.Size(logical_shape_for_weight))
+ logical_shape_for_weight[0], logical_shape_for_weight[1] = (
+ logical_shape_for_weight[1],
+ logical_shape_for_weight[0],
+ )
+ physical_other_operand = OperationData(
+ name=str(self.node.args[1]),
+ type=data_type,
+ data=self.node.args[1]._meta_data,
+ logical_shape=torch.Size(logical_shape_for_weight),
+ )
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
- if "bias" in self.node.kwargs and self.node.kwargs['bias'] is not None:
+ if "bias" in self.node.kwargs and self.node.kwargs["bias"] is not None:
# check if the other operand is a parameter
if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
- physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]),
- type=data_type,
- data=self.node.kwargs["bias"]._meta_data)
- mapping['bias'] = physical_bias_operand
+ physical_bias_operand = OperationData(
+ name=str(self.node.kwargs["bias"]), type=data_type, data=self.node.kwargs["bias"]._meta_data
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy):
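
# A minimal sketch of the weight transposition performed above: strategies
# are generated against a logical (in_ch, out_ch, *kernel) shape and mapped
# back to the physical (out_ch, in_ch, *kernel) layout in post_process. The
# Conv2d sizes are assumptions.
import torch

weight = torch.nn.Conv2d(3, 16, kernel_size=3).weight  # physical (16, 3, 3, 3)
logical = list(weight.shape)
logical[0], logical[1] = logical[1], logical[0]
print(weight.shape, torch.Size(logical))  # (16, 3, 3, 3) -> (3, 16, 3, 3)
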
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py
index 0c5b9f39e1fba75b44308d57569c3b8c0b5087c0..feb1032a6c0f2077f6a525523cde765b94bc8d1c 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py
@@ -3,11 +3,11 @@ from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import MetaInfoNodeHandler, NodeHandler
+from .node_handler import MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import DefaultReshapeGenerator, StrategyGenerator
-__all__ = ['DefaultReshapeHandler']
+__all__ = ["DefaultReshapeHandler"]
@operator_registry.register(torch.flatten)
@@ -54,17 +54,15 @@ class DefaultReshapeHandler(MetaInfoNodeHandler):
input_data = self.node.args[0]._meta_data
input_logical_shape = self.infer_logical_shape(input_data)
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=data_type,
- data=input_data,
- logical_shape=input_logical_shape)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=data_type, data=input_data, logical_shape=input_logical_shape
+ )
output_data = self.node._meta_data
output_logical_shape = self.infer_logical_shape(output_data)
- physical_output = OperationData(name=str(self.node),
- type=OperationDataType.OUTPUT,
- data=output_data,
- logical_shape=output_logical_shape)
+ physical_output = OperationData(
+ name=str(self.node), type=OperationDataType.OUTPUT, data=output_data, logical_shape=output_logical_shape
+ )
mapping = {"input": physical_input_operand, "output": physical_output}
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py
index e154105b672de30b5675ca56147fb6b68205a469..f29c3a0b7d5d28dfa9e96db36562702c25b613ea 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py
@@ -12,11 +12,12 @@ from .node_handler import ModuleHandler, NodeHandler
from .registry import operator_registry
from .strategy import EmbeddingStrategyGenerator, StrategyGenerator
-__all__ = ['EmbeddingModuleHandler', 'EmbeddingFunctionHandler']
+__all__ = ["EmbeddingModuleHandler", "EmbeddingFunctionHandler"]
-def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: ShardingStrategy, input_name: str,
- output_name: str) -> List[ShardingStrategy]:
+def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(
+ strategy: ShardingStrategy, input_name: str, output_name: str
+) -> List[ShardingStrategy]:
"""
This function converts the logical sharding spec to the physical sharding spec for both the input and output
of the embedding operation.
@@ -56,27 +57,31 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy:
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
try:
# replace the 0th dimension in the logical sharding with the i-th dimension in the physical sharding
- update_partition_dim(sharding_spec=input_sharding_spec,
- dim_mapping={0: i},
- physical_shape=input_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=input_sharding_spec,
+ dim_mapping={0: i},
+ physical_shape=input_op_data.data.shape,
+ inplace=True,
+ )
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
dim_mapping = {0: i, last_logical_output_dims: last_physical_output_dims}
else:
dim_mapping = {0: i}
- update_partition_dim(sharding_spec=output_sharding_spec,
- dim_mapping=dim_mapping,
- physical_shape=output_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=output_sharding_spec,
+ dim_mapping=dim_mapping,
+ physical_shape=output_op_data.data.shape,
+ inplace=True,
+ )
- strategy_copy.name = f'{strategy.name}_{i}'
+ strategy_copy.name = f"{strategy.name}_{i}"
sharding_strategies.append(strategy_copy)
except ShardingNotDivisibleError as e:
logger.debug(
- f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}'
+ f"Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}"
)
else:
# the generated sharding strategy does not shard the non-matrix dimension,
@@ -87,20 +92,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy:
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
# after updating, the logical shape will be replaced by the physical shape
- update_partition_dim(sharding_spec=input_sharding_spec,
- dim_mapping={},
- physical_shape=input_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=input_sharding_spec, dim_mapping={}, physical_shape=input_op_data.data.shape, inplace=True
+ )
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
dim_mapping = {last_logical_output_dims: last_physical_output_dims}
else:
dim_mapping = {}
- update_partition_dim(sharding_spec=output_sharding_spec,
- dim_mapping=dim_mapping,
- physical_shape=output_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=output_sharding_spec,
+ dim_mapping=dim_mapping,
+ physical_shape=output_op_data.data.shape,
+ inplace=True,
+ )
sharding_strategies.append(strategy_copy)
return sharding_strategies
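
# A sketch of the enumeration above: the embedding input is flattened to 1D
# logically, so a shard on logical dim 0 can be mapped onto any leading
# physical dim i, producing one candidate strategy per dimension. The
# (batch, seq) index shape is an assumption.
import torch

physical_input = torch.Size([8, 128])        # (batch, seq) token indices
logical_input = torch.Size([8 * 128])        # view(-1), as in the handler
dim_mappings = [{0: i} for i in range(len(physical_input))]
print(logical_input, dim_mappings)           # torch.Size([1024]) [{0: 0}, {0: 1}]
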
@@ -125,14 +131,16 @@ class EmbeddingModuleHandler(ModuleHandler):
# Finally, the input will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1).shape
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=input_meta_data,
- logical_shape=input_logical_shape)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]),
+ type=OperationDataType.ARG,
+ data=input_meta_data,
+ logical_shape=input_logical_shape,
+ )
- physical_other_operand = OperationData(name="weight",
- type=OperationDataType.PARAM,
- data=self.named_parameters['weight'])
+ physical_other_operand = OperationData(
+ name="weight", type=OperationDataType.PARAM, data=self.named_parameters["weight"]
+ )
# Same as input, in nn.Embedding operation, all the dimensions of output will be treated as
# (batch dimension, embedding dimension), and then the sharding spec will be generated based
@@ -141,10 +149,12 @@ class EmbeddingModuleHandler(ModuleHandler):
# Finally, the output will be transformed back to its original shape in self.post_process
output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
- physical_output = OperationData(name=str(self.node),
- type=OperationDataType.OUTPUT,
- data=output_meta_data,
- logical_shape=output_logical_shape)
+ physical_output = OperationData(
+ name=str(self.node),
+ type=OperationDataType.OUTPUT,
+ data=output_meta_data,
+ logical_shape=output_logical_shape,
+ )
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
@@ -155,12 +165,11 @@ class EmbeddingModuleHandler(ModuleHandler):
Convert the sharding spec from the logical shape to the physical shape.
"""
# create multiple sharding strategies for the inputs
- # as input can be multi-dimensinal and the partition dim is only 2D,
+ # as input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
- strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy,
- input_name=str(
- self.node.args[0]),
- output_name=str(self.node))
+ strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(
+ strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node)
+ )
return strategies
@@ -183,10 +192,12 @@ class EmbeddingFunctionHandler(NodeHandler):
# Finally, the input will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1).shape
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data,
- logical_shape=input_logical_shape)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]),
+ type=OperationDataType.ARG,
+ data=self.node.args[0]._meta_data,
+ logical_shape=input_logical_shape,
+ )
# check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
@@ -194,9 +205,9 @@ class EmbeddingFunctionHandler(NodeHandler):
else:
data_type = OperationDataType.ARG
- physical_other_operand = OperationData(name=str(self.node.args[1]),
- type=data_type,
- data=self.node.args[1]._meta_data)
+ physical_other_operand = OperationData(
+ name=str(self.node.args[1]), type=data_type, data=self.node.args[1]._meta_data
+ )
# Same as input, in F.embedding operation, all the dimensions of output will be treated as
# (batch dimension, embedding dimension), and then the sharding spec will be generated based
@@ -221,10 +232,9 @@ class EmbeddingFunctionHandler(NodeHandler):
Convert the sharding spec from the logical shape to the physical shape.
"""
# create multiple sharding strategies for the inputs
- # as input can be multi-dimensinal and the partition dim is only 2D,
+ # as input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
- strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy,
- input_name=str(
- self.node.args[0]),
- output_name=str(self.node))
+ strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(
+ strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node)
+ )
return strategies
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py
index 53addb873d1d1a014352058f8ec127f6bf7c4d91..dcf0a1760a2cdbf3791bf13103dd8940b0c1f24f 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py
@@ -4,7 +4,7 @@ from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .strategy import GetattrGenerator, StrategyGenerator
-__all__ = ['GetattrHandler']
+__all__ = ["GetattrHandler"]
class GetattrHandler(NodeHandler):
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py
index 3466e9dd9940e748da4bc8abb3488aacf98cd8ff..bd342c12eda97d83217b6f4d8b6b846974b67657 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py
@@ -8,7 +8,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
-__all__ = ['GetItemHandler']
+__all__ = ["GetItemHandler"]
@operator_registry.register(operator.getitem)
@@ -30,9 +30,9 @@ class GetItemHandler(NodeHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to their original shapes in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
physical_other_operand = OperationData(name="index", type=OperationDataType.ARG, data=self.node.args[1])
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py
index 452381169b74d093188e0f8d7775037f8bf5019c..ce6b20fa1d2407531c6461880de866161eac85d4 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py
@@ -3,11 +3,11 @@ from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import MetaInfoModuleHandler, ModuleHandler
+from .node_handler import MetaInfoModuleHandler
from .registry import operator_registry
from .strategy import LayerNormGenerator, StrategyGenerator
-__all__ = ['LayerNormModuleHandler']
+__all__ = ["LayerNormModuleHandler"]
@operator_registry.register(torch.nn.LayerNorm)
@@ -25,20 +25,22 @@ class LayerNormModuleHandler(MetaInfoModuleHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to their original shapes in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
- physical_other_operand = OperationData(name="weight",
- type=OperationDataType.PARAM,
- data=self.named_parameters['weight'],
- logical_shape=self.named_parameters['weight'].shape)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
+ physical_other_operand = OperationData(
+ name="weight",
+ type=OperationDataType.PARAM,
+ data=self.named_parameters["weight"],
+ logical_shape=self.named_parameters["weight"].shape,
+ )
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
- if self.named_parameters['bias'] is not None:
- physical_bias_operand = OperationData(name="bias",
- type=OperationDataType.PARAM,
- data=self.named_parameters['bias'])
- mapping['bias'] = physical_bias_operand
+ if self.named_parameters["bias"] is not None:
+ physical_bias_operand = OperationData(
+ name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"]
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
index 59091dab519f4e4458461b84e444b9a034f4df98..4177af4eaf71b21bd0ec9ae8f386d95b212e41a1 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
@@ -3,27 +3,24 @@ from typing import Dict, List, Union
import torch
import torch.nn.functional as F
-from colossalai.auto_parallel.tensor_shard.utils import (
- check_sharding_spec_validity,
- transpose_partition_dim,
- update_partition_dim,
-)
+from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
-from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
+from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
-__all__ = ['LinearModuleHandler', 'LinearFunctionHandler']
+__all__ = ["LinearModuleHandler", "LinearFunctionHandler"]
-def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStrategy,
- weight_name: str) -> ShardingStrategy:
+def _update_sharding_spec_for_transposed_weight_for_linear(
+ strategy: ShardingStrategy, weight_name: str
+) -> ShardingStrategy:
"""
This function is a helper function used by both module node handler and function node handler. This function will
- convert the sharding spec for the transposed weight to the correct partititon spec.
+ convert the sharding spec for the transposed weight to the correct partition spec.
Args:
strategy (ShardingStrategy): the strategy generated by the strategy generator.
@@ -32,16 +29,17 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
# switch the dimensions of the transposed weight
sharding_spec = strategy.get_sharding_spec_by_name(weight_name)
op_data = strategy.get_op_data_by_name(weight_name)
- assert op_data.logical_shape[0] == op_data.data.shape[1] and \
- op_data.logical_shape[1] == op_data.data.shape[0], \
- "Expected the logical shape of the linear operator's weight is equal to transposed physical shape"
+ assert (
+ op_data.logical_shape[0] == op_data.data.shape[1] and op_data.logical_shape[1] == op_data.data.shape[0]
+ ), "Expected the logical shape of the linear operator's weight is equal to transposed physical shape"
dim_size = len(op_data.logical_shape)
transpose_partition_dim(sharding_spec, 0, dim_size - 1)
return strategy
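
# A hedged, dict-based sketch of transpose_partition_dim as used above: the
# nn.Linear weight is stored as (out, in) but strategies are generated on the
# logical (in, out) shape, so partitions on dim 0 and the last dim must be
# swapped. The real ShardingSpec is richer than this stand-in.
def swap_partition_dims(dim_partition_dict, dim0, dim1):
    swapped = dict(dim_partition_dict)
    p0, p1 = swapped.pop(dim0, None), swapped.pop(dim1, None)
    if p0 is not None:
        swapped[dim1] = p0
    if p1 is not None:
        swapped[dim0] = p1
    return swapped

# logical spec shards dim 0 over mesh axis 0 -> physical spec shards dim 1
print(swap_partition_dims({0: [0]}, 0, 1))  # {1: [0]}
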
-def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: ShardingStrategy, input_name: str,
- output_name: str) -> List[ShardingStrategy]:
+def _convert_logical_sharding_to_physical_sharding_spec_for_linear(
+ strategy: ShardingStrategy, input_name: str, output_name: str
+) -> List[ShardingStrategy]:
"""
This function converts the logical sharding spec to the physical sharding spec for both the input and output of the linear operation. The input and output
should have the same sharding spec.
@@ -99,22 +97,26 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
input_dim_mapping = {0: i}
input_dim_mapping.update(input_last_dim_mapping)
- update_partition_dim(sharding_spec=input_sharding_spec,
- dim_mapping=input_dim_mapping,
- physical_shape=input_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=input_sharding_spec,
+ dim_mapping=input_dim_mapping,
+ physical_shape=input_op_data.data.shape,
+ inplace=True,
+ )
output_dim_mapping = {0: i}
output_dim_mapping.update(output_last_dim_mapping)
- update_partition_dim(sharding_spec=output_sharding_spec,
- dim_mapping=output_dim_mapping,
- physical_shape=output_op_data.data.shape,
- inplace=True)
- strategy_copy.name = f'{strategy.name}_{i}'
+ update_partition_dim(
+ sharding_spec=output_sharding_spec,
+ dim_mapping=output_dim_mapping,
+ physical_shape=output_op_data.data.shape,
+ inplace=True,
+ )
+ strategy_copy.name = f"{strategy.name}_{i}"
sharding_strategies.append(strategy_copy)
except ShardingNotDivisibleError as e:
logger.debug(
- f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}'
+ f"Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}"
)
else:
# the generated sharding strategy does not shard the non-matrix dimension,
@@ -127,17 +129,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
# after updating, the logical shape will be replaced by the physical shape
input_dim_mapping = {}
input_dim_mapping.update(input_last_dim_mapping)
- update_partition_dim(sharding_spec=input_sharding_spec,
- dim_mapping=input_dim_mapping,
- physical_shape=input_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=input_sharding_spec,
+ dim_mapping=input_dim_mapping,
+ physical_shape=input_op_data.data.shape,
+ inplace=True,
+ )
output_dim_mapping = {}
output_dim_mapping.update(output_last_dim_mapping)
- update_partition_dim(sharding_spec=output_sharding_spec,
- dim_mapping=output_dim_mapping,
- physical_shape=output_op_data.data.shape,
- inplace=True)
+ update_partition_dim(
+ sharding_spec=output_sharding_spec,
+ dim_mapping=output_dim_mapping,
+ physical_shape=output_op_data.data.shape,
+ inplace=True,
+ )
sharding_strategies.append(strategy_copy)
return sharding_strategies
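
# A sketch of the logical shapes the linear handlers derive: the input is
# flattened to 2D and the weight shape reversed, so every strategy is
# generated for a plain (M, K) x (K, N) projection. The tensor sizes are
# assumptions.
import torch

x = torch.empty(4, 16, 32, device="meta")      # (batch, seq, hidden)
w = torch.nn.Linear(32, 64).weight             # physical (out, in) = (64, 32)
input_logical = x.view(-1, x.shape[-1]).shape  # (64, 32), i.e. (B*S, H)
weight_logical = w.shape[::-1]                 # (32, 64), i.e. (in, out)
print(input_logical, weight_logical)
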
@@ -152,10 +158,13 @@ class LinearModuleHandler(MetaInfoModuleHandler):
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
- LinearProjectionStrategyGenerator(op_data_mapping,
- self.device_mesh,
- linear_projection_type='linear',
- solver_perference=self.solver_perference))
+ LinearProjectionStrategyGenerator(
+ op_data_mapping,
+ self.device_mesh,
+ linear_projection_type="linear",
+ solver_perference=self.solver_perference,
+ )
+ )
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@@ -163,28 +172,34 @@ class LinearModuleHandler(MetaInfoModuleHandler):
# the strategies will be transformed back to their original shapes in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=input_meta_data,
- logical_shape=input_logical_shape)
- physical_other_operand = OperationData(name="weight",
- type=OperationDataType.PARAM,
- data=self.named_parameters['weight'],
- logical_shape=self.named_parameters['weight'].shape[::-1])
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]),
+ type=OperationDataType.ARG,
+ data=input_meta_data,
+ logical_shape=input_logical_shape,
+ )
+ physical_other_operand = OperationData(
+ name="weight",
+ type=OperationDataType.PARAM,
+ data=self.named_parameters["weight"],
+ logical_shape=self.named_parameters["weight"].shape[::-1],
+ )
output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
- physical_output = OperationData(name=str(self.node),
- type=OperationDataType.OUTPUT,
- data=output_meta_data,
- logical_shape=output_logical_shape)
+ physical_output = OperationData(
+ name=str(self.node),
+ type=OperationDataType.OUTPUT,
+ data=output_meta_data,
+ logical_shape=output_logical_shape,
+ )
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
- if 'bias' in self.named_parameters is not None:
- physical_bias_operand = OperationData(name="bias",
- type=OperationDataType.PARAM,
- data=self.named_parameters['bias'])
- mapping['bias'] = physical_bias_operand
+ if "bias" in self.named_parameters is not None:
+ physical_bias_operand = OperationData(
+ name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"]
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
@@ -194,14 +209,14 @@ class LinearModuleHandler(MetaInfoModuleHandler):
2. the input and output sharding specs are updated to physical shape.
"""
# switch the dimensions of the transposed weight
- strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name='weight')
+ strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name="weight")
# create multiple sharding strategies for the inputs
- # as input can be multi-dimensinal and the partition dim is only 2D,
+ # as input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at dim 0 to one of the first few dimensions of the input
- strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy,
- input_name=str(self.node.args[0]),
- output_name=str(self.node))
+ strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(
+ strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node)
+ )
return strategies
@@ -215,7 +230,8 @@ class LinearFunctionHandler(MetaInfoNodeHandler):
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
- LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
+ LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="linear")
+ )
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@@ -223,10 +239,12 @@ class LinearFunctionHandler(MetaInfoNodeHandler):
# the strategies will be transformed back to their original shapes in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data,
- logical_shape=input_logical_shape)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]),
+ type=OperationDataType.ARG,
+ data=self.node.args[0]._meta_data,
+ logical_shape=input_logical_shape,
+ )
# check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
@@ -234,10 +252,12 @@ class LinearFunctionHandler(MetaInfoNodeHandler):
else:
data_type = OperationDataType.ARG
- physical_other_operand = OperationData(name=str(self.node.args[1]),
- type=data_type,
- data=self.node.args[1]._meta_data,
- logical_shape=self.node.args[1]._meta_data.shape[::-1])
+ physical_other_operand = OperationData(
+ name=str(self.node.args[1]),
+ type=data_type,
+ data=self.node.args[1]._meta_data,
+ logical_shape=self.node.args[1]._meta_data.shape[::-1],
+ )
output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
physical_output = OperationData(
@@ -249,27 +269,28 @@ class LinearFunctionHandler(MetaInfoNodeHandler):
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
- if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None:
+ if "bias" in self.node.kwargs and self.node.kwargs["bias"] is not None:
# check if the other operand is a parameter
if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
- physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]),
- type=data_type,
- data=self.node.kwargs["bias"]._meta_data)
- mapping['bias'] = physical_bias_operand
+ physical_bias_operand = OperationData(
+ name=str(self.node.kwargs["bias"]), type=data_type, data=self.node.kwargs["bias"]._meta_data
+ )
+ mapping["bias"] = physical_bias_operand
return mapping
def post_process(self, strategy: ShardingStrategy):
# switch the dimensions of the transposed weight
- strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy,
- weight_name=str(self.node.args[1]))
+ strategy = _update_sharding_spec_for_transposed_weight_for_linear(
+ strategy=strategy, weight_name=str(self.node.args[1])
+ )
# create multiple sharding strategies for the inputs
- # as input can be multi-dimensinal and the partition dim is only 2D,
+ # as input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at dim 0 to one of the first few dimensions of the input
- strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy,
- input_name=str(self.node.args[0]),
- output_name=str(self.node))
+ strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(
+ strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node)
+ )
return strategies
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
index f3c9d0cbf8267e2321415ae29887e308a9af35b2..4fab5f7f05eb3ea4dcfeeb49a48ae8230acef97e 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
@@ -16,7 +16,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
-from .node_handler import MetaInfoNodeHandler, NodeHandler
+from .node_handler import MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import (
BatchedMatMulStrategyGenerator,
@@ -37,6 +37,7 @@ class MatMulType(Enum):
MV: matrix-vector product: the 1st tensor is 2D and the 2nd tensor is 1D
BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D
"""
+
DOT = 0
MM = 1
MV = 2
@@ -48,8 +49,8 @@ def get_matmul_type(input_dim: int, other_dim: int):
Determine which type of matmul operation should be executed for the given tensor dimensions.
Args:
- input_dim (int): the number of dimensions for the input tenosr
- other_dim (int): the number of dimensions for the other tenosr
+ input_dim (int): the number of dimensions for the input tensor
+ other_dim (int): the number of dimensions for the other tensor
"""
if input_dim == 1 and other_dim == 1:
matmul_type = MatMulType.DOT
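
# A compact sketch of the dispatch documented above; the real
# get_matmul_type continues with the MM/MV/BMM branches that this hunk
# truncates, so the conditions below are an approximation of its behaviour.
from enum import Enum

class MatMulType(Enum):
    DOT = 0
    MM = 1
    MV = 2
    BMM = 3

def matmul_type_for(input_dim: int, other_dim: int) -> MatMulType:
    if input_dim == 1 and other_dim == 1:
        return MatMulType.DOT      # vector dot product
    if input_dim in (1, 2) and other_dim == 2:
        return MatMulType.MM       # (possibly padded) matrix-matrix product
    if input_dim == 2 and other_dim == 1:
        return MatMulType.MV       # matrix-vector product
    return MatMulType.BMM          # at least one operand is 3D or higher
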
@@ -92,26 +93,26 @@ class Padder(BmmTransform):
def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = deepcopy(shape_mapping)
- input_shape = mapping_copy['input']
- other_shape = mapping_copy['other']
+ input_shape = mapping_copy["input"]
+ other_shape = mapping_copy["other"]
if len(input_shape) == 1:
# if the input is a 1D tensor, 1 is prepended to its shape
# and it will be removed afterwards
input_shape.insert(0, 1)
- self.padded_dim_mapping['input'] = -2
- self.padded_dim_mapping['output'] = -2
+ self.padded_dim_mapping["input"] = -2
+ self.padded_dim_mapping["output"] = -2
elif len(other_shape) == 1:
# if the other is a 1D tensor, 1 is appended to its shape
# and it will be removed afterwards
other_shape.append(1)
- self.padded_dim_mapping['other'] = -1
- self.padded_dim_mapping['output'] = -1
+ self.padded_dim_mapping["other"] = -1
+ self.padded_dim_mapping["output"] = -1
return mapping_copy
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
- input_op_data = op_data_mapping['input']
- other_op_data = op_data_mapping['other']
+ op_data_mapping["input"]
+ op_data_mapping["other"]
def _remove_padded_dim(key, strategy):
op_data = op_data_mapping[key]
@@ -131,7 +132,7 @@ class Padder(BmmTransform):
# compute unpadded tensor shape
tensor_shape.pop(padded_dim)
- assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}'
+ assert tensor_shape == list(op_data.data.shape), f"{tensor_shape} vs {list(op_data.data.shape)}"
# update sharding spec
sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list)
@@ -142,15 +143,15 @@ class Padder(BmmTransform):
strategy_copy = strategy.clone()
# only one of input and other will be padded
- if 'input' in self.padded_dim_mapping:
- _remove_padded_dim('input', strategy_copy)
- _remove_padded_dim('output', strategy_copy)
- elif 'other' in self.padded_dim_mapping:
- _remove_padded_dim('other', strategy_copy)
- _remove_padded_dim('output', strategy_copy)
+ if "input" in self.padded_dim_mapping:
+ _remove_padded_dim("input", strategy_copy)
+ _remove_padded_dim("output", strategy_copy)
+ elif "other" in self.padded_dim_mapping:
+ _remove_padded_dim("other", strategy_copy)
+ _remove_padded_dim("output", strategy_copy)
strategies.append(strategy_copy)
- except ShardingSpecException as e:
+ except ShardingSpecException:
pass
return strategies
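
# A minimal sketch of the padding rule the Padder applies and later undoes:
# a 1D input becomes (1, n) and a 1D other becomes (n, 1) so both look like
# matrices during strategy generation, and the padded dim is popped again in
# recover. The shapes are toy values.
def pad_for_matmul(input_shape, other_shape):
    input_shape, other_shape = list(input_shape), list(other_shape)
    padded = {}
    if len(input_shape) == 1:
        input_shape.insert(0, 1)                  # (n,) -> (1, n)
        padded["input"] = padded["output"] = -2
    elif len(other_shape) == 1:
        other_shape.append(1)                     # (n,) -> (n, 1)
        padded["other"] = padded["output"] = -1
    return input_shape, other_shape, padded

print(pad_for_matmul([4], [4, 8]))  # ([1, 4], [4, 8], {'input': -2, 'output': -2})
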
@@ -167,8 +168,8 @@ class Broadcaster(BmmTransform):
mapping_copy = shape_mapping.copy()
# get shapes
- input_shape = mapping_copy['input']
- other_shape = mapping_copy['other']
+ input_shape = mapping_copy["input"]
+ other_shape = mapping_copy["other"]
# sanity check
assert len(input_shape) > 1 and len(other_shape) > 1
@@ -179,16 +180,16 @@ class Broadcaster(BmmTransform):
# store the broadcast dim info
input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2])
other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2])
- self.broadcast_dim_info['input'] = input_broadcast_dim_info
- self.broadcast_dim_info['other'] = other_broadcast_dim_info
+ self.broadcast_dim_info["input"] = input_broadcast_dim_info
+ self.broadcast_dim_info["other"] = other_broadcast_dim_info
# create the full logical shape
input_shape = bcast_non_matrix_dims + input_shape[-2:]
other_shape = bcast_non_matrix_dims + other_shape[-2:]
assert len(input_shape) == len(other_shape)
- mapping_copy['input'] = input_shape
- mapping_copy['other'] = other_shape
+ mapping_copy["input"] = input_shape
+ mapping_copy["other"] = other_shape
return mapping_copy
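
# A sketch of the batch-dim broadcasting above: the leading (non-matrix)
# dims of both operands are broadcast together, and each full logical shape
# is those broadcast dims plus the operand's own two matrix dims. Shapes are
# illustrative.
import torch

input_shape, other_shape = [1, 2, 4], [6, 1, 4, 8]
batch_dims = list(torch.broadcast_shapes(input_shape[:-2], other_shape[:-2]))
logical_input = batch_dims + input_shape[-2:]  # [6, 1, 2, 4]
logical_other = batch_dims + other_shape[-2:]  # [6, 1, 4, 8]
print(logical_input, logical_other)
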
@@ -206,7 +207,7 @@ class Broadcaster(BmmTransform):
# e.g. [1, 2, 4] x [4, 4, 8] -> [4, 2, 8]
# dim 0 of [1, 2, 4] is broadcast from 1 to 4
tensor_shape[dim_idx] = 1
- elif broadcast_type == BroadcastType.PADDDING:
+ elif broadcast_type == BroadcastType.PADDING:
# if the dim is padded
# we remove its sharding
tensor_shape[dim_idx] = None
@@ -216,17 +217,18 @@ class Broadcaster(BmmTransform):
physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
logical_sharding_spec=sharding_spec,
logical_shape=sharding_spec.entire_shape,
- physical_shape=tensor_shape_before_broadcast)
+ physical_shape=tensor_shape_before_broadcast,
+ )
strategy.sharding_specs[op_data] = physical_sharding_spec
# enumerate all sharding strategies
strategies = []
try:
strategy_copy = strategy.clone()
- _remove_sharding_on_broadcast_dim('input', strategy_copy)
- _remove_sharding_on_broadcast_dim('other', strategy_copy)
+ _remove_sharding_on_broadcast_dim("input", strategy_copy)
+ _remove_sharding_on_broadcast_dim("other", strategy_copy)
strategies.append(strategy_copy)
- except ShardingSpecException as e:
+ except ShardingSpecException:
pass
return strategies
@@ -241,20 +243,20 @@ class Viewer(BmmTransform):
def apply(self, shape_mapping: Dict[str, List[int]]):
mapping_copy = shape_mapping.copy()
- self.batch_dims_before_view = list(mapping_copy['input'][:-2])
+ self.batch_dims_before_view = list(mapping_copy["input"][:-2])
# get shapes
- input_shape = shape_mapping['input']
- other_shape = shape_mapping['other']
+ input_shape = shape_mapping["input"]
+ other_shape = shape_mapping["other"]
# view to 3d tensor
assert len(input_shape) >= 3 and len(other_shape) >= 3
input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:]
other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:]
output_shape = input_shape[:2] + other_shape[2:]
- mapping_copy['input'] = input_shape
- mapping_copy['other'] = other_shape
- mapping_copy['output'] = output_shape
+ mapping_copy["input"] = input_shape
+ mapping_copy["other"] = other_shape
+ mapping_copy["output"] = output_shape
return mapping_copy
def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
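
# A sketch of the 3D "view" step performed in apply above: all leading batch
# dims are folded into one, so batched matmul strategies only ever see
# (B, M, K) x (B, K, N). The shapes are assumptions.
import operator
from functools import reduce

input_shape, other_shape = [6, 2, 3, 4], [6, 2, 4, 8]
input_3d = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:]  # [12, 3, 4]
other_3d = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:]  # [12, 4, 8]
output_3d = input_3d[:2] + other_3d[2:]                                 # [12, 3, 8]
print(input_3d, other_3d, output_3d)
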
@@ -268,13 +270,13 @@ class Viewer(BmmTransform):
dim_partition_dict = sharding_spec.dim_partition_dict
entire_shape = sharding_spec.entire_shape
- # upddate the dimension index for the matrix dimensions
+ # update the dimension index for the matrix dimensions
if 2 in dim_partition_dict:
dim_partition_dict[len(self.batch_dims_before_view) + 1] = dim_partition_dict.pop(2)
if 1 in dim_partition_dict:
dim_partition_dict[len(self.batch_dims_before_view)] = dim_partition_dict.pop(1)
- # map the logical batch dim to phyiscal batch dim
+ # map the logical batch dim to physical batch dim
if 0 in dim_partition_dict:
batch_dim_shard = dim_partition_dict.pop(0)
dim_partition_dict[physical_batch_dim] = batch_dim_shard
@@ -291,11 +293,11 @@ class Viewer(BmmTransform):
# create a new strategy
strategy_copy = strategy.clone()
try:
- _update_sharding_spec('input', strategy_copy, i)
- _update_sharding_spec('other', strategy_copy, i)
- _update_sharding_spec('output', strategy_copy, i)
+ _update_sharding_spec("input", strategy_copy, i)
+ _update_sharding_spec("other", strategy_copy, i)
+ _update_sharding_spec("output", strategy_copy, i)
strategies.append(strategy_copy)
- except ShardingSpecException as e:
+ except ShardingSpecException:
continue
return strategies
@@ -312,14 +314,14 @@ def _get_bmm_logical_shape(input_shape, other_shape, transforms):
3. reshape to 3 dimensions
"""
- shape_mapping = {'input': input_shape, 'other': other_shape}
+ shape_mapping = {"input": input_shape, "other": other_shape}
for transform in transforms:
shape_mapping = transform.apply(shape_mapping)
- input_shape = shape_mapping.get('input', None)
- other_shape = shape_mapping.get('other', None)
- output_shape = shape_mapping.get('output', None)
+ input_shape = shape_mapping.get("input", None)
+ other_shape = shape_mapping.get("other", None)
+ output_shape = shape_mapping.get("output", None)
return input_shape, other_shape, output_shape
@@ -364,7 +366,8 @@ class MatMulHandler(MetaInfoNodeHandler):
generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh))
elif self.matmul_type == MatMulType.MM:
generators.append(
- LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
+ LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="linear")
+ )
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@@ -372,7 +375,7 @@ class MatMulHandler(MetaInfoNodeHandler):
MatMulType.DOT: self._get_logical_shape_for_dot,
MatMulType.MM: self._get_logical_shape_for_mm,
MatMulType.MV: self._get_logical_shape_for_mv,
- MatMulType.BMM: self._get_logical_shape_for_bmm
+ MatMulType.BMM: self._get_logical_shape_for_bmm,
}
logical_shapes = logical_shape_func[self.matmul_type]()
op_data_mapping = self._get_op_data_mapping(*logical_shapes)
@@ -390,20 +393,26 @@ class MatMulHandler(MetaInfoNodeHandler):
output_logical_shape = torch.Size(output_logical_shape)
# create op data
- input_op_data = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.input_meta_data,
- logical_shape=input_logical_shape)
- other_op_data = OperationData(name=str(self.node.args[1]),
- type=OperationDataType.ARG,
- data=self.other_meta_data,
- logical_shape=other_logical_shape)
- output_op_data = OperationData(name=str(self.node),
- type=OperationDataType.OUTPUT,
- data=self.output_meta_data,
- logical_shape=output_logical_shape)
-
- mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
+ input_op_data = OperationData(
+ name=str(self.node.args[0]),
+ type=OperationDataType.ARG,
+ data=self.input_meta_data,
+ logical_shape=input_logical_shape,
+ )
+ other_op_data = OperationData(
+ name=str(self.node.args[1]),
+ type=OperationDataType.ARG,
+ data=self.other_meta_data,
+ logical_shape=other_logical_shape,
+ )
+ output_op_data = OperationData(
+ name=str(self.node),
+ type=OperationDataType.OUTPUT,
+ data=self.output_meta_data,
+ logical_shape=output_logical_shape,
+ )
+
+ mapping = {"input": input_op_data, "other": other_op_data, "output": output_op_data}
return mapping
def _get_logical_shape_for_dot(self):
@@ -414,7 +423,7 @@ class MatMulHandler(MetaInfoNodeHandler):
def _get_logical_shape_for_mm(self):
"""
- We need to handle the input tensor for a matrix-matrix multiplcation as the input
+ We need to handle the input tensor for a matrix-matrix multiplication as the input
tensor can be a 1D or 2D tensor. If it is a 1D tensor, 1 will be prepended to its shape
(e.g. [4] -> [1, 4]).
"""
@@ -460,9 +469,11 @@ class MatMulHandler(MetaInfoNodeHandler):
dim_partition_dict[0] = shard
# re-init the sharding spec
- input_sharding_spec.__init__(input_sharding_spec.device_mesh,
- entire_shape=input_physical_shape,
- dim_partition_dict=dim_partition_dict)
+ input_sharding_spec.__init__(
+ input_sharding_spec.device_mesh,
+ entire_shape=input_physical_shape,
+ dim_partition_dict=dim_partition_dict,
+ )
return strategy
else:
return strategy
@@ -481,7 +492,8 @@ class MatMulHandler(MetaInfoNodeHandler):
recovered_stragies.extend(output)
else:
raise TypeError(
- f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
+ f"Found unexpected output type {type(output)} from the recover method of BmmTransform"
+ )
strategies = recovered_stragies
for index, strategies in enumerate(strategies):
strategies.name = f"{strategies.name}_{index}"
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
index ab391ebfaf80960ef49a4e9c4761c76f82567d25..d2bad39dcbb9164bd6c3e6e4d42323d25b88d385 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
@@ -8,7 +8,6 @@ from colossalai.auto_parallel.meta_profiler.shard_metainfo import ShardMetaInfo,
from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
- OperationDataType,
ShardingSpec,
ShardingStrategy,
StrategiesVector,
@@ -23,21 +22,23 @@ from .strategy import StrategyGenerator
class NodeHandler(ABC):
- '''
+ """
The NodeHandler is an abstract class used to generate every possible strategies for an operator node.
Args:
node (Node): the input node in node argument list.
device_mesh (DeviceMesh): A logical view of a physical mesh.
strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
- '''
-
- def __init__(self,
- node: Node,
- device_mesh: DeviceMesh,
- strategies_vector: StrategiesVector,
- shard_option: ShardOption = ShardOption.STANDARD,
- solver_perference: SolverPerference = SolverPerference.STANDARD) -> None:
+ """
+
+ def __init__(
+ self,
+ node: Node,
+ device_mesh: DeviceMesh,
+ strategies_vector: StrategiesVector,
+ shard_option: ShardOption = ShardOption.STANDARD,
+ solver_perference: SolverPerference = SolverPerference.STANDARD,
+ ) -> None:
self.node = node
self.predecessor_node = list(node._input_nodes.keys())
self.successor_node = list(node.users.keys())
@@ -68,22 +69,23 @@ class NodeHandler(ABC):
current_sharding_spec = strategy.sharding_specs[op_data]
# get the sharding specs for this node generated
# in its own node handler
- assert hasattr(node, 'strategies_vector'), \
- f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.'
+ assert hasattr(
+ node, "strategies_vector"
+ ), f"The predecessor node {node_name} has no strategy vector to compute the resharding cost."
prev_strategy_vector = node.strategies_vector
prev_sharding_specs = [
prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
]
- # create data structrure to store costs
+ # create data structure to store costs
if node not in resharding_costs:
resharding_costs[node] = []
def _compute_resharding_cost(
- prev_sharding_spec: Union[ShardingSpec,
- List[ShardingSpec]], current_sharding_spec: Union[ShardingSpec,
- List[ShardingSpec]],
- data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem:
+ prev_sharding_spec: Union[ShardingSpec, List[ShardingSpec]],
+ current_sharding_spec: Union[ShardingSpec, List[ShardingSpec]],
+ data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
+ ) -> TrainCycleItem:
"""
This is a helper function to compute the resharding cost for a specific strategy of a node.
"""
@@ -94,30 +96,35 @@ class NodeHandler(ABC):
dtype = data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
_, _, consistency_cost = shape_consistency_manager.shape_consistency(
- prev_sharding_spec, current_sharding_spec)
-
- resharding_cost = TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes,
- bwd=consistency_cost["backward"] * size_per_elem_bytes,
- total=consistency_cost["total"] * size_per_elem_bytes)
+ prev_sharding_spec, current_sharding_spec
+ )
+
+ resharding_cost = TrainCycleItem(
+ fwd=consistency_cost["forward"] * size_per_elem_bytes,
+ bwd=consistency_cost["backward"] * size_per_elem_bytes,
+ total=consistency_cost["total"] * size_per_elem_bytes,
+ )
return resharding_cost
else:
# This raise is used to check if we have missed any type of data.
# It could be merged into the Parameter branch, which means we won't handle
# non-tensor arguments.
- raise ValueError(f'Unsupported data type {type(data)}')
+ raise ValueError(f"Unsupported data type {type(data)}")
else:
- assert isinstance(prev_sharding_spec, (tuple, list)), \
- f'prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \
- or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}'
+ assert isinstance(
+ prev_sharding_spec, (tuple, list)
+ ), f"prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \
+ or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}"
fwd_cost = 0
bwd_cost = 0
total_cost = 0
- for index, (prev_sharding_spec_item,
- current_sharding_spec_item) in enumerate(zip(prev_sharding_spec,
- current_sharding_spec)):
- item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item,
- data[index])
+ for index, (prev_sharding_spec_item, current_sharding_spec_item) in enumerate(
+ zip(prev_sharding_spec, current_sharding_spec)
+ ):
+ item_cost = _compute_resharding_cost(
+ prev_sharding_spec_item, current_sharding_spec_item, data[index]
+ )
fwd_cost += item_cost.fwd
bwd_cost += item_cost.bwd
total_cost += item_cost.total
@@ -138,17 +145,17 @@ class NodeHandler(ABC):
This function is used to get the target function for the node handler.
The target function is used to analyze the costs of strategies.
"""
- if self.node.op in ('placeholder', 'get_attr', 'output'):
+ if self.node.op in ("placeholder", "get_attr", "output"):
return None
- if self.node.op == 'call_module':
+ if self.node.op == "call_module":
target = self.node.graph.owning_module.get_submodule(self.node.target)
- elif self.node.op == 'call_function':
+ elif self.node.op == "call_function":
target = self.node.target
- elif self.node.op == 'call_method':
+ elif self.node.op == "call_method":
target = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
else:
- raise ValueError(f'Unsupported node type: {self.node.op}')
+ raise ValueError(f"Unsupported node type: {self.node.op}")
return target
@@ -188,7 +195,7 @@ class NodeHandler(ABC):
remove_strategy_list = []
for strategy in self.strategies_vector:
shard_axis_list = []
- last_axis = len(self.device_mesh.mesh_shape) - 1
+ last_axis = len(self.device_mesh.shape) - 1
for op_data, sharding_spec in strategy.sharding_specs.items():
if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
for dim, shard_axes in sharding_spec.dim_partition_dict.items():
@@ -212,7 +219,7 @@ class NodeHandler(ABC):
return self.strategies_vector
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
- # tranform the strategy generated
+ # transform the strategy generated
# e.g. to process the sharding strategy for the transposed weights
return strategy
@@ -221,7 +228,6 @@ class NodeHandler(ABC):
"""
Define which generators should be used by this NodeHandler object.
"""
- pass
@abstractmethod
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@@ -244,7 +250,6 @@ class NodeHandler(ABC):
"output": Operand(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data),
}
"""
- pass
class MetaInfoNodeHandler(NodeHandler):
@@ -278,19 +283,19 @@ class MetaInfoNodeHandler(NodeHandler):
else:
logger = get_dist_logger()
- logger.warning(f'The target function {target} is not patched yet, ')
+ logger.warning(f"The target function {target} is not patched yet, ")
return self.strategies_vector
class ModuleHandler(NodeHandler):
-
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# set attributes to access module parameters for convenience
- assert self.node.graph.owning_module is not None, \
- f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.'
+ assert (
+ self.node.graph.owning_module is not None
+ ), f"The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object."
module = self.node.graph.owning_module.get_submodule(self.node.target)
named_parameters = list(module.named_parameters(recurse=False))
named_buffers = list(module.named_buffers(recurse=False))
@@ -333,6 +338,6 @@ class MetaInfoModuleHandler(ModuleHandler):
else:
logger = get_dist_logger()
- logger.warning(f'The target function {target} is not patched yet')
+ logger.warning(f"The target function {target} is not patched yet")
return self.strategies_vector
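# A minimal sketch of the byte-scaling step in _compute_resharding_cost above:
# the shape-consistency cost is counted in elements, so it is multiplied by the
# element size of the tensor dtype. The `consistency_cost` dict is a stand-in
# for the values returned by shape_consistency_manager.shape_consistency.
import torch

def scale_to_bytes(consistency_cost: dict, dtype: torch.dtype) -> dict:
    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
    return {key: val * size_per_elem_bytes for key, val in consistency_cost.items()}

cost = scale_to_bytes({"forward": 1024, "backward": 1024, "total": 2048}, torch.float32)
assert cost["total"] == 2048 * 4  # fp32 elements take 4 bytes each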
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py
index 4e71ccba95a7e6457309a455986400dc49893d18..facf19560596020138cd287d403539914ae3a98e 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py
@@ -3,11 +3,11 @@ from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import MetaInfoModuleHandler, ModuleHandler
+from .node_handler import MetaInfoModuleHandler
from .registry import operator_registry
from .strategy import NormalPoolStrategyGenerator, StrategyGenerator
-__all__ = ['NormPoolingHandler']
+__all__ = ["NormPoolingHandler"]
@operator_registry.register(torch.nn.MaxPool1d)
@@ -30,9 +30,9 @@ class NormPoolingHandler(MetaInfoModuleHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to their original shapes in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
physical_weight_operand = OperationData(name="kernel", type=OperationDataType.ARG, data=self.module.kernel_size)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py
index ed120a8c3d6df9b5d10f44f2b86be1c3cf283c10..89906a205e87f2e955712999fb89805e66354f97 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py
@@ -8,7 +8,7 @@ from ..sharding_strategy import OperationData, OperationDataType, StrategiesVect
from .node_handler import NodeHandler
from .strategy import OutputGenerator, StrategyGenerator
-__all__ = ['OutputHandler']
+__all__ = ["OutputHandler"]
class OutputHandler(NodeHandler):
@@ -16,8 +16,9 @@ class OutputHandler(NodeHandler):
An OutputHandler which deals with the sharding strategies for an Output Node.
"""
- def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
- output_option: str) -> None:
+ def __init__(
+ self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, output_option: str
+ ) -> None:
super().__init__(node, device_mesh, strategies_vector)
self.output_option = output_option
@@ -35,11 +36,11 @@ class OutputHandler(NodeHandler):
for index, input_node in enumerate(self.predecessor_node):
input_meta_data = input_node._meta_data
physical_inputs = OperationData(name=str(input_node), type=OperationDataType.ARG, data=input_meta_data)
- name_key = f'input_{index}'
+ name_key = f"input_{index}"
mapping[name_key] = physical_inputs
output_meta_data.append(input_meta_data)
- assert len(output_meta_data) > 0, f'Output node {self.node} has no input node.'
+ assert len(output_meta_data) > 0, f"Output node {self.node} has no input node."
if len(output_meta_data) == 1:
output_meta_data = output_meta_data[0]
else:
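# A minimal sketch of the mapping built above, with illustrative node names:
# every predecessor of the output node becomes an "input_{index}" operand, and
# a single-entry metadata list collapses to its only element.
predecessors = ["conv1", "fc_out"]
mapping = {f"input_{index}": node for index, node in enumerate(predecessors)}
output_meta_data = list(predecessors)
if len(output_meta_data) == 1:
    output_meta_data = output_meta_data[0]
assert set(mapping) == {"input_0", "input_1"}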
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py
index 91e4a5105a08ff7d28cebba41f4962daa951259c..75f07168e47b440843e38ab8ad0b1abf2bc136ca 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py
@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import PermuteGenerator, StrategyGenerator
-__all__ = ['PermuteHandler']
+__all__ = ["PermuteHandler"]
@operator_registry.register(torch.Tensor.permute)
@@ -34,14 +34,14 @@ class PermuteHandler(NodeHandler):
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
permute_dims = []
- if self.node.op == 'call_method':
+ if self.node.op == "call_method":
# torch.Tensor.permute (input, *dims)
for arg in self.node.args:
if isinstance(arg, torch.fx.Node):
if isinstance(arg._meta_data, int):
permute_dims.append(arg._meta_data)
else:
- assert isinstance(arg, int), 'The argument in permute node should be either type of Node or int.'
+ assert isinstance(arg, int), "The argument in permute node should be either type of Node or int."
permute_dims.append(arg)
else:
# torch.permute (input, dims)
@@ -51,8 +51,8 @@ class PermuteHandler(NodeHandler):
permute_dims.extend(arg._meta_data)
else:
assert isinstance(
- arg,
- (tuple, list)), 'The argument in permute node should be type of Node, Tuple[int] or List[int].'
+ arg, (tuple, list)
+ ), "The argument in permute node should be type of Node, Tuple[int] or List[int]."
permute_dims.extend(arg)
num_dims = self.node._meta_data.dim()
@@ -61,7 +61,7 @@ class PermuteHandler(NodeHandler):
if permute_dims[i] < 0:
permute_dims[i] += num_dims
- physical_shape_operand = OperationData(name='permute_dims', type=OperationDataType.ARG, data=list(permute_dims))
+ physical_shape_operand = OperationData(name="permute_dims", type=OperationDataType.ARG, data=list(permute_dims))
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
@@ -69,7 +69,7 @@ class PermuteHandler(NodeHandler):
mapping = {
"input": physical_input_operand,
"permute_dims": physical_shape_operand,
- "output": physical_output_operand
+ "output": physical_output_operand,
}
return mapping
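# A minimal sketch of the negative-dim recovery done above for permute: dims
# are normalized to non-negative values before building the OperationData.
import torch

x = torch.randn(2, 3, 4)
permute_dims = [0, -1, -2]
num_dims = x.dim()
permute_dims = [dim + num_dims if dim < 0 else dim for dim in permute_dims]
assert permute_dims == [0, 2, 1]
assert x.permute(*permute_dims).shape == torch.Size([2, 4, 3])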
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py
index e4f40fc935a404dd8625c82fbb4dc7511c9fc839..461bc2935780e6d497469e2569a12abf43dff97b 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py
@@ -8,7 +8,7 @@ from ..sharding_strategy import OperationData, OperationDataType, StrategiesVect
from .node_handler import NodeHandler
from .strategy import PlaceholderGenerator, StrategyGenerator
-__all__ = ['PlaceholderHandler']
+__all__ = ["PlaceholderHandler"]
class PlaceholderHandler(NodeHandler):
@@ -16,8 +16,9 @@ class PlaceholderHandler(NodeHandler):
A PlaceholderHandler which deals with the sharding strategies for a Placeholder Node.
"""
- def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
- placeholder_option: str) -> None:
+ def __init__(
+ self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, placeholder_option: str
+ ) -> None:
super().__init__(node, device_mesh, strategies_vector)
self.placeholder_option = placeholder_option
@@ -25,7 +26,8 @@ class PlaceholderHandler(NodeHandler):
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
- PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option))
+ PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option)
+ )
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
index 8e06cec4f463a8600b2abe1a7f6713ec2ffb2931..f663fc9695d3d1b37591429adc80d71f744cdf12 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
@@ -1,12 +1,9 @@
class Registry:
- # TODO: refactor the registry classes used in colossalai.registry, colossalai.fx and here
-
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
-
def wrapper(func):
if isinstance(source, (list, tuple)):
# support register a list of items for this func
@@ -19,7 +16,7 @@ class Registry:
return wrapper
def get(self, source):
- assert source in self.store, f'{source} not found in the {self.name} registry'
+ assert source in self.store, f"{source} not found in the {self.name} registry"
target = self.store[source]
return target
@@ -27,4 +24,4 @@ class Registry:
return source in self.store
-operator_registry = Registry('operator')
+operator_registry = Registry("operator")
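# A minimal usage sketch of the Registry above (assuming ColossalAI is
# importable): register() acts as a decorator factory keyed by one source or a
# list of sources, and get() looks the handler back up. `DemoHandler` is an
# illustrative name; the decorator is assumed to return the class unchanged,
# as the registered handlers elsewhere in this package suggest.
from colossalai.auto_parallel.tensor_shard.node_handler.registry import Registry

demo_registry = Registry("demo")

@demo_registry.register([int, float])
class DemoHandler:
    pass

assert demo_registry.get(int) is DemoHandler
assert demo_registry.get(float) is DemoHandler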
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py
index 743a1f90eaafa869b3a62882648cbde53f9e3166..6e883ea6473672204044b431fbe3a3265a63286b 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py
@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import SoftmaxGenerator, StrategyGenerator
-__all__ = ['SoftmaxHandler']
+__all__ = ["SoftmaxHandler"]
@operator_registry.register(torch.nn.Softmax)
@@ -34,14 +34,14 @@ class SoftmaxHandler(NodeHandler):
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
- softmax_dim = self.node.kwargs['dim']
+ softmax_dim = self.node.kwargs["dim"]
num_dims = self.node.args[0]._meta_data.dim()
# recover negative value to positive
if softmax_dim < 0:
softmax_dim += num_dims
- physical_dim_operand = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim)
+ physical_dim_operand = OperationData(name="softmax_dim", type=OperationDataType.ARG, data=softmax_dim)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
@@ -49,7 +49,7 @@ class SoftmaxHandler(NodeHandler):
mapping = {
"input": physical_input_operand,
"softmax_dim": physical_dim_operand,
- "output": physical_output_operand
+ "output": physical_output_operand,
}
return mapping
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py
index 653d158b7c36ee1ff27791add2edad4093ce8675..4c32529a5d5b822dc89dbf97c5e19889153515af 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py
@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import SplitGenerator, StrategyGenerator
-__all__ = ['SplitHandler']
+__all__ = ["SplitHandler"]
@operator_registry.register(torch.Tensor.split)
@@ -38,7 +38,7 @@ class SplitHandler(NodeHandler):
split_dim = self.node.args[2]
else:
if self.node.kwargs:
- split_dim = self.node.kwargs['dim']
+ split_dim = self.node.kwargs["dim"]
else:
split_dim = 0
@@ -48,7 +48,7 @@ class SplitHandler(NodeHandler):
split_dim += num_dims
split_info = (split_size, split_dim)
- physical_shape_operand = OperationData(name='split_info', type=OperationDataType.ARG, data=split_info)
+ physical_shape_operand = OperationData(name="split_info", type=OperationDataType.ARG, data=split_info)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
@@ -56,7 +56,7 @@ class SplitHandler(NodeHandler):
mapping = {
"input": physical_input_operand,
"split_info": physical_shape_operand,
- "output": physical_output_operand
+ "output": physical_output_operand,
}
return mapping
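# A minimal sketch of the (split_size, split_dim) pair gathered above:
# torch.Tensor.split takes a chunk size and a dim that defaults to 0 when it
# is passed neither positionally nor as a keyword.
import torch

x = torch.randn(6, 4)
split_size, split_dim = 2, 0
chunks = x.split(split_size, dim=split_dim)
assert len(chunks) == 3
assert all(chunk.shape == torch.Size([2, 4]) for chunk in chunks)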
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
index db1f31521c86ef1842e93d9bbdbc58953e11934d..1fc7f613716b4cc895f10b61e9a5d0c91932a22e 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
@@ -29,11 +29,31 @@ from .unary_elementwise_generator import UnaryElementwiseGenerator
from .where_generator import WhereGenerator
__all__ = [
- 'StrategyGenerator', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator',
- 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator',
- 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator',
- 'LayerNormGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator', 'NormalPoolStrategyGenerator',
- 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator', 'TensorConstructorGenerator',
- 'EmbeddingStrategyGenerator', 'SumGenerator', 'SoftmaxGenerator', 'ViewGenerator', 'PermuteGenerator',
- 'TransposeGenerator', 'SplitGenerator', 'DefaultReshapeGenerator'
+ "StrategyGenerator",
+ "DotProductStrategyGenerator",
+ "MatVecStrategyGenerator",
+ "LinearProjectionStrategyGenerator",
+ "BatchedMatMulStrategyGenerator",
+ "ConvStrategyGenerator",
+ "UnaryElementwiseGenerator",
+ "BatchNormStrategyGenerator",
+ "GetItemStrategyGenerator",
+ "TensorStrategyGenerator",
+ "TensorTupleStrategyGenerator",
+ "LayerNormGenerator",
+ "PlaceholderGenerator",
+ "OutputGenerator",
+ "WhereGenerator",
+ "NormalPoolStrategyGenerator",
+ "BinaryElementwiseStrategyGenerator",
+ "GetattrGenerator",
+ "TensorConstructorGenerator",
+ "EmbeddingStrategyGenerator",
+ "SumGenerator",
+ "SoftmaxGenerator",
+ "ViewGenerator",
+ "PermuteGenerator",
+ "TransposeGenerator",
+ "SplitGenerator",
+ "DefaultReshapeGenerator",
]
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
index 1f3812429fc274064163f5859d0fedb04f8115fb..9c766b1014c80454abe0123fc9a1464cf92b63af 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
@@ -14,7 +14,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
-__all__ = ['BatchNormStrategyGenerator']
+__all__ = ["BatchNormStrategyGenerator"]
class BatchNormStrategyGenerator(StrategyGenerator):
@@ -24,34 +24,37 @@ class BatchNormStrategyGenerator(StrategyGenerator):
To keep the math consistent, there are two ways to do BatchNorm if the input
is sharded along the batch dimension:
1. We gather the input partitions along the batch dimension, then do the normal BatchNorm.
- 2. We do the SyncBatchNorm on the each input partition seperately, the SyncBN op will help
+ 2. We do the SyncBatchNorm on each input partition separately; the SyncBN op will help
us to keep the computation correct.
In this generator, both methods will be considered.
"""
def validate(self) -> bool:
- '''
+ """
In the sanity check, we need to make sure the input data has the correct dimension size.
For BatchNorm1d, the dim of input data should be 3 ([N, C, L]).
For BatchNorm2d, the dim of input data should be 4 ([N, C, H, W]).
For BatchNorm3d, the dim of input data should be 5 ([N, C, H, W, D]).
- '''
- input_op_data = self.op_data['input']
+ """
+ input_op_data = self.op_data["input"]
assert input_op_data.data.dim() in (
- 3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
+ 3,
+ 4,
+ 5,
+ ), f"We suppose the dim of input fed into conv op should in range of [3, 5]."
def update_compute_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the computation cost per device with this specific strategy.
- Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
- '''
+ Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
+ """
# TODO: a constant coefficient needs to be added.
# 1D: (L) * N * Cin
# 2D: (H * W) * N * Cin
# 3D: (H * W * D) * N * Cin
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
if self.has_bias:
# bias add is an elementwise operation, so the cost equals the product of the output shape.
bias_compute_cost = reduce(operator.mul, sharded_output_shape)
@@ -69,23 +72,24 @@ class BatchNormStrategyGenerator(StrategyGenerator):
def update_memory_cost(self, strategy: ShardingStrategy):
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'other': self._compute_size_in_bytes(strategy, "other"),
- 'output': self._compute_size_in_bytes(strategy, "output"),
- 'running_mean': self._compute_size_in_bytes(strategy, "running_mean"),
- 'running_var': self._compute_size_in_bytes(strategy, "running_var"),
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "other": self._compute_size_in_bytes(strategy, "other"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
+ "running_mean": self._compute_size_in_bytes(strategy, "running_mean"),
+ "running_var": self._compute_size_in_bytes(strategy, "running_var"),
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
- forward_size_mapping['bias'] = bias_size
+ forward_size_mapping["bias"] = bias_size
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + other + bias + output
fwd_activation_cost = sum(
- [v for k, v in forward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)])
+ [v for k, v in forward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)]
+ )
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_buffer_cost = sum([v for k, v in forward_size_mapping.items() if self.is_buffer(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost, buffer=fwd_buffer_cost)
@@ -93,36 +97,29 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# compute bwd cost incurred
# bwd_cost = input_grad + other_grad + bias_grad
bwd_activation_cost = sum(
- [v for k, v in backward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)])
+ [v for k, v in backward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)]
+ )
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost,
- buffer=fwd_buffer_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost,
+ parameter=fwd_parameter_cost + bwd_parameter_cost,
+ buffer=fwd_buffer_cost,
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def split_input_channel(self, mesh_dim_0):
- name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
+ name = f"RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}"
dim_partition_dict_mapping = {
- "input": {
- 1: [mesh_dim_0]
- },
- "other": {
- 0: [mesh_dim_0]
- },
- "output": {
- 1: [mesh_dim_0]
- },
- "running_mean": {
- 0: [mesh_dim_0]
- },
- "running_var": {
- 0: [mesh_dim_0]
- },
+ "input": {1: [mesh_dim_0]},
+ "other": {0: [mesh_dim_0]},
+ "output": {1: [mesh_dim_0]},
+ "running_mean": {0: [mesh_dim_0]},
+ "running_var": {0: [mesh_dim_0]},
"num_batches_tracked": {},
}
if self.has_bias:
@@ -132,29 +129,21 @@ class BatchNormStrategyGenerator(StrategyGenerator):
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
+ name = f"RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}"
dim_partition_dict_mapping = {
- "input": {
- 1: [mesh_dim_0, mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
- "output": {
- 1: [mesh_dim_0, mesh_dim_1]
- },
- "running_mean": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
- "running_var": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "input": {1: [mesh_dim_0, mesh_dim_1]},
+ "other": {0: [mesh_dim_0, mesh_dim_1]},
+ "output": {1: [mesh_dim_0, mesh_dim_1]},
+ "running_mean": {0: [mesh_dim_0, mesh_dim_1]},
+ "running_var": {0: [mesh_dim_0, mesh_dim_1]},
"num_batches_tracked": {},
}
if self.has_bias:
@@ -164,13 +153,15 @@ class BatchNormStrategyGenerator(StrategyGenerator):
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def non_split(self):
- name = f'RR = RR x R'
+ name = f"RR = RR x R"
dim_partition_dict_mapping = {
"input": {},
"other": {},
@@ -186,21 +177,19 @@ class BatchNormStrategyGenerator(StrategyGenerator):
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
- name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
+ name = f"S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0]
- },
+ "input": {0: [mesh_dim_0]},
"other": {},
- "output": {
- 0: [mesh_dim_0]
- },
+ "output": {0: [mesh_dim_0]},
"running_mean": {},
"running_var": {},
"num_batches_tracked": {},
@@ -212,33 +201,32 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# set communication action
# For the SyncBN case, we don't need to do communication for weight and bias.
- # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
+ # TODO: the communication happens internally in the SyncBN operation. We need to replace the BN operation
# with a SyncBN operation instead of inserting a communication node.
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.IMPLICIT)
+ comm_type=CommType.IMPLICIT,
+ )
# TODO: the temporary solution has no communication cost;
# the action above should be added back once the SyncBN replacement pass is completed.
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
+ name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0, mesh_dim_1]},
"other": {},
- "output": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "output": {0: [mesh_dim_0, mesh_dim_1]},
"running_mean": {},
"running_var": {},
"num_batches_tracked": {},
@@ -250,25 +238,28 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# set communication action
# For the SyncBN case, we don't need to do communication for gradients of weight and bias.
- # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
+ # TODO: the communication happens internally in the SyncBN operation. We need to replace the BN operation
# with a SyncBN operation instead of inserting a communication node.
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.IMPLICIT)
+ comm_type=CommType.IMPLICIT,
+ )
# TODO: the temporary solution has no communication cost;
# the action above should be added back once the SyncBN replacement pass is completed.
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
+ name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0],
@@ -298,26 +289,29 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# set communication action
# For the SyncBN case, we don't need to do communication for gradients of weight and bias.
- # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
+ # TODO: the communication happens internally in the SyncBN operation. We need to replace the BN operation
# with a SyncBN operation instead of inserting a communication node.
output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0],
- comm_type=CommType.IMPLICIT)
+ comm_type=CommType.IMPLICIT,
+ )
# TODO: the temporary solution has no communication cost;
# the action above should be added back once the SyncBN replacement pass is completed.
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def collate_strategies(self) -> List[ShardingStrategy]:
- '''
+ """
Generate all possible strategies for a BatchNorm node and record them in the strategies_vector.
- '''
+ """
strategy_list = []
# RS = RS x S
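# A minimal sketch of the BatchNorm2d compute-cost estimate described above:
# the forward cost is (H * W) * N * Cin on the sharded input shape, and the
# bias add costs the product of the sharded output shape. The constant
# coefficient is intentionally omitted, matching the TODO in the generator.
import operator
from functools import reduce

sharded_input_shape = (8, 16, 32, 32)  # (N, Cin, H, W) after sharding
forward_compute_cost = reduce(operator.mul, sharded_input_shape)
bias_compute_cost = reduce(operator.mul, sharded_input_shape)  # BN output shape == input shape
assert forward_compute_cost == 8 * 16 * 32 * 32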
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py
index fd7f811c8972412eaec88bb1dcfc639cdf1fe630..c7da0034ec3bf504e52dd4d15f934e5b0c242e05 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py
@@ -14,7 +14,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException
from .strategy_generator import StrategyGenerator
-__all__ = ['BinaryElementwiseStrategyGenerator']
+__all__ = ["BinaryElementwiseStrategyGenerator"]
class BinaryElementwiseStrategyGenerator(StrategyGenerator):
@@ -26,36 +26,37 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator):
"""
def validate(self) -> bool:
- assert len(self.op_data) == 3, \
- f'BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}'
+ assert (
+ len(self.op_data) == 3
+ ), f"BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}"
for name, op_data in self.op_data.items():
if not isinstance(op_data.data, (torch.Tensor, int, float)):
- raise TypeError(f'The operation data {name} is not a torch.Tensor/int/float.')
+ raise TypeError(f"The operation data {name} is not a torch.Tensor/int/float.")
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
- shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+ shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
# since elementwise ops are not compute-intensive,
# we approximate the backward compute cost
# to be twice the fwd compute cost
fwd_compute_cost = reduce(operator.mul, shape)
bwd_compute_cost = fwd_compute_cost * 2
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
# the input, other and output all have the same shape
- shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+ strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
# compute fwd memory cost in bytes
# as the elementwise ops are not memory-intensive
- # we approximate the fwd memroy cost to be the output
+ # we approximate the fwd memory cost to be the output
# and the backward memory cost to be the grads of input and other
- input_bytes = self._compute_size_in_bytes(strategy, 'input')
- other_bytes = self._compute_size_in_bytes(strategy, 'other')
- output_bytes = self._compute_size_in_bytes(strategy, 'output')
+ input_bytes = self._compute_size_in_bytes(strategy, "input")
+ other_bytes = self._compute_size_in_bytes(strategy, "other")
+ output_bytes = self._compute_size_in_bytes(strategy, "output")
fwd_memory_cost = MemoryCost(activation=output_bytes)
bwd_memory_cost = MemoryCost(activation=input_bytes + other_bytes)
total_memory_cost = MemoryCost(activation=input_bytes + other_bytes + output_bytes)
@@ -66,7 +67,7 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator):
def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# we check for the output logical shape to get the number of dimensions
dim_partition_list = []
- dim_size = len(self.op_data['output'].logical_shape)
+ dim_size = len(self.op_data["output"].logical_shape)
# enumerate all the 2D sharding cases
sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
@@ -86,21 +87,22 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator):
# convert these dim partition dict to sharding strategy
for dim_partition_dict in dim_partition_list:
- dim_partition_dict_mapping = dict(input=dim_partition_dict,
- other=dim_partition_dict,
- output=dim_partition_dict)
+ dim_partition_dict_mapping = dict(
+ input=dim_partition_dict, other=dim_partition_dict, output=dim_partition_dict
+ )
try:
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
communication_action_mapping = {}
# get name
- sharding_seq = sharding_spec_mapping['input'].sharding_sequence
- name = f'{sharding_seq} = {sharding_seq} {sharding_seq}'
+ sharding_seq = sharding_spec_mapping["input"].sharding_sequence
+ name = f"{sharding_seq} = {sharding_seq} {sharding_seq}"
sharding_strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(sharding_strategy)
except ShardingSpecException:
continue
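# A minimal sketch of the elementwise cost model above: the forward cost is
# the number of sharded elements, and the backward cost is approximated as
# twice the forward cost because elementwise ops are not compute-intensive.
import operator
from functools import reduce

sharded_shape = (8, 128, 128)
fwd_compute_cost = reduce(operator.mul, sharded_shape)
bwd_compute_cost = fwd_compute_cost * 2
assert fwd_compute_cost + bwd_compute_cost == 3 * 8 * 128 * 128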
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
index c2154b3104d3d52e994a2add25ddc796792e1c66..5208f61543bb38ebd4f151cad32bf86db581b2f7 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
@@ -1,11 +1,9 @@
import copy
import operator
-import warnings
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
CommType,
MemoryCost,
ShardingStrategy,
@@ -24,29 +22,32 @@ class ConvStrategyGenerator(StrategyGenerator):
"""
def validate(self) -> bool:
- '''
+ """
In the sanity check, we need to make sure the input data has the correct dimension size.
For Conv1d, the dim of input data should be 3 ([N, C, L]).
For Conv2d, the dim of input data should be 4 ([N, C, H, W]).
For Conv3d, the dim of input data should be 5 ([N, C, H, W, D]).
- '''
- input_op_data = self.op_data['input']
+ """
+ input_op_data = self.op_data["input"]
assert input_op_data.data.dim() in (
- 3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
+ 3,
+ 4,
+ 5,
+ ), f"We suppose the dim of input fed into conv op should in range of [3, 5]."
def update_compute_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the computation cost per device with this specific strategy.
- Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
- '''
- # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
+ Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
+ """
+ # TODO: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
# 1D: (L) * N * Cout * Cin * kernel
# 2D: (H * W) * N * Cout * Cin * kernel
# 3D: (H * W * D) * N * Cout * Cin * kernel
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
- sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device()
+ sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
if self.has_bias:
# bias add is an elementwise operation, so the cost equals the product of the output shape.
bias_compute_cost = reduce(operator.mul, sharded_output_shape)
@@ -76,14 +77,14 @@ class ConvStrategyGenerator(StrategyGenerator):
def update_memory_cost(self, strategy: ShardingStrategy):
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'other': self._compute_size_in_bytes(strategy, "other"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "other": self._compute_size_in_bytes(strategy, "other"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
- forward_size_mapping['bias'] = bias_size
+ forward_size_mapping["bias"] = bias_size
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
@@ -100,26 +101,20 @@ class ConvStrategyGenerator(StrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
+ name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0]
- },
- "other": {
- 1: [mesh_dim_1]
- },
- "output": {
- 0: [mesh_dim_0],
- 1: [mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0]},
+ "other": {1: [mesh_dim_1]},
+ "output": {0: [mesh_dim_0], 1: [mesh_dim_1]},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {0: [mesh_dim_1]}
@@ -132,7 +127,8 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"input": input_comm_action}
if self.is_param("other"):
@@ -140,7 +136,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
@@ -148,38 +145,41 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
- if self.is_param('bias'):
+ if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
+ key_for_kwarg="bias",
+ )
communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
- name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
+ name = f"S{mesh_dim_0}R = S{mesh_dim_0}R x RR"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0]
- },
+ "input": {0: [mesh_dim_0]},
"other": {},
"output": {
0: [mesh_dim_0],
@@ -196,7 +196,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
@@ -204,42 +205,45 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
- if self.is_param('bias'):
+ if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
+ key_for_kwarg="bias",
+ )
communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
+ name = f"S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R"
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0],
1: [mesh_dim_1],
},
- "other": {
- 0: [mesh_dim_1]
- },
+ "other": {0: [mesh_dim_1]},
"output": {
0: [mesh_dim_0],
},
@@ -254,7 +258,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
communication_action_mapping = {"output": output_comm_action}
@@ -263,7 +268,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
@@ -271,7 +277,8 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
if self.is_param("bias"):
@@ -279,23 +286,27 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
+ key_for_kwarg="bias",
+ )
communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
+ name = f"RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {
@@ -322,23 +333,27 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
input_comm_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"output": output_comm_action, "input": input_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
- name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
+ name = f"RR = RS{mesh_dim_0} x S{mesh_dim_0}R"
dim_partition_dict_mapping = {
"input": {
@@ -360,17 +375,20 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
communication_action_mapping = {"output": output_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_weight_out_channel(self, mesh_dim_0):
- name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
+ name = f"RS{mesh_dim_0} = RR x RS{mesh_dim_0}"
dim_partition_dict_mapping = {
"input": {},
@@ -395,17 +413,20 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"input": input_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def non_split(self):
- name = f'RR = RR x RR'
+ name = f"RR = RR x RR"
dim_partition_dict_mapping = {
"input": {},
@@ -418,13 +439,13 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping={})
+ return self.get_sharding_strategy(
+ name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={}
+ )
@ignore_sharding_exception
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
+ name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR"
dim_partition_dict_mapping = {
"input": {
@@ -447,14 +468,16 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
@@ -464,23 +487,27 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
+ key_for_kwarg="bias",
+ )
communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
- name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
+ name = f"RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R"
dim_partition_dict_mapping = {
"input": {
1: [mesh_dim_0, mesh_dim_1],
@@ -501,17 +528,20 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
communication_action_mapping = {"output": output_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
+ name = f"RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {},
"other": {
@@ -535,13 +565,16 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"input": input_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def collate_strategies(self) -> List[ShardingStrategy]:
strategies = []
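# A minimal sketch of the Conv2d compute-cost estimate described above:
# the forward cost is (H * W) * N * Cout * Cin * kernel on the sharded shapes,
# again without the constant coefficient mentioned in the TODO.
H, W = 32, 32          # sharded spatial dims of the input
N, Cin = 8, 16         # sharded batch and in-channels
Cout, kernel = 64, 9   # sharded out-channels and a 3x3 kernel
fwd_compute_cost = (H * W) * N * Cout * Cin * kernel
assert fwd_compute_cost == 32 * 32 * 8 * 64 * 16 * 9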
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py
index 82a04ab52e739ae3db29efde2a66f30ff24cb8d0..385a8886f2318220530bdc2448bcee8e83a5b8f0 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py
@@ -1,11 +1,9 @@
import copy
import operator
-import warnings
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
CommType,
MemoryCost,
ShardingStrategy,
@@ -27,16 +25,16 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the computation cost per device with this specific strategy.
Note: the computation cost for the embedding handler is currently estimated as a dense computation.
It may not be accurate.
- '''
+ """
# TODO: estimate the embedding computation cost as sparse operation
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
- sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device()
+ sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
other_size_product = reduce(operator.mul, sharded_other_shape)
@@ -55,9 +53,9 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
def update_memory_cost(self, strategy: ShardingStrategy):
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'other': self._compute_size_in_bytes(strategy, "other"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "other": self._compute_size_in_bytes(strategy, "other"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -75,14 +73,15 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def non_split(self):
- name = f'RR = R x RR'
+ name = f"RR = R x RR"
dim_partition_dict_mapping = {
"input": {},
@@ -92,18 +91,16 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping={})
+ return self.get_sharding_strategy(
+ name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={}
+ )
@ignore_sharding_exception
def split_input(self, mesh_dim_0):
- name = f'S{mesh_dim_0}R = S{mesh_dim_0} x RR'
+ name = f"S{mesh_dim_0}R = S{mesh_dim_0} x RR"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0]
- },
+ "input": {0: [mesh_dim_0]},
"other": {},
"output": {
0: [mesh_dim_0],
@@ -118,7 +115,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
@@ -126,17 +124,20 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}'
+ name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {
@@ -159,7 +160,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"input": input_comm_action}
if self.is_param("other"):
@@ -167,7 +169,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
@@ -175,22 +178,23 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR'
+ name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0, mesh_dim_1]},
"other": {},
"output": {
0: [mesh_dim_0, mesh_dim_1],
@@ -207,7 +211,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
@@ -215,17 +220,20 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
communication_action_mapping["other"] = other_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_embedding_dim(self, mesh_dim_0):
- name = f'RS{mesh_dim_0} = R x RS{mesh_dim_0}'
+ name = f"RS{mesh_dim_0} = R x RS{mesh_dim_0}"
dim_partition_dict_mapping = {
"input": {},
@@ -245,17 +253,20 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"input": input_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_1d_parallel_on_embedding_dim(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}'
+ name = f"RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}"
dim_partition_dict_mapping = {
"input": {},
@@ -275,13 +286,16 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping = {"input": input_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def collate_strategies(self) -> List[ShardingStrategy]:
strategies = []
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py
index bbeb9a639c835869634417e5ceff1a2cad082339..cc8d5771f28e9cb012672713e81eb449e96c944e 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py
@@ -10,7 +10,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException
from .strategy_generator import StrategyGenerator
-__all__ = ['GetattrGenerator']
+__all__ = ["GetattrGenerator"]
class GetattrGenerator(StrategyGenerator):
@@ -26,10 +26,10 @@ class GetattrGenerator(StrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
- forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
+ """
+ forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")}
# compute fwd cost incurred
# fwd_cost = output
@@ -47,7 +47,7 @@ class GetattrGenerator(StrategyGenerator):
def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# we check for the output logical shape to get the number of dimensions
dim_partition_list = []
- dim_size = len(self.op_data['output'].logical_shape)
+ dim_size = len(self.op_data["output"].logical_shape)
# enumerate all the 2D sharding cases
sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
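+        # each enumerated candidate assigns mesh_dim_0, mesh_dim_1, or both to some subset of output dims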
@@ -78,7 +78,8 @@ class GetattrGenerator(StrategyGenerator):
sharding_strategy = self.get_sharding_strategy(
name=name,
sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(sharding_strategy)
except ShardingSpecException:
continue
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py
index 0aeb2e0d4079ea1b302d580554ed7ca24ab7096d..6f01d9cc7f8ef06a0e4b1672a8f98e7db68364ca 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py
@@ -1,19 +1,13 @@
import copy
from typing import List
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommType,
- MemoryCost,
- ShardingStrategy,
- TrainCycleItem,
-)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.logging import get_dist_logger
-from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpecException
from .strategy_generator import FollowingStrategyGenerator
-__all__ = ['GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator']
+__all__ = ["GetItemStrategyGenerator", "TensorStrategyGenerator", "TensorTupleStrategyGenerator"]
class GetItemStrategyGenerator(FollowingStrategyGenerator):
@@ -35,12 +29,12 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -58,27 +52,29 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
class TensorStrategyGenerator(GetItemStrategyGenerator):
- '''
+ """
Deal with case 1 and 2.
- '''
+ """
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
- getitem_index = self.op_data['index'].data
+ getitem_index = self.op_data["index"].data
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
try:
logger = get_dist_logger()
dim_partition_dict_mapping = {}
communication_action_mapping = {}
dim_partition_dict_for_input = copy.deepcopy(
- strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict)
+ strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict
+ )
int_index = False
if isinstance(getitem_index, int):
@@ -120,9 +116,11 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}'
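+                # the _{index} suffix records which predecessor strategy this getitem strategy follows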
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
except ShardingSpecException as e:
logger.debug(e)
continue
@@ -137,9 +135,9 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
- '''
+ """
Deal with case 3.
- '''
+ """
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
@@ -158,13 +156,15 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
sharding_spec_mapping["input"] = sharding_spec_for_input
input_sharding_info = f"get the {index} element from ("
for sharding_spec in sharding_spec_for_input:
- input_sharding_info += f'{sharding_spec.sharding_sequence}, '
+ input_sharding_info += f"{sharding_spec.sharding_sequence}, "
input_sharding_info += ")"
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {input_sharding_info}_{strategy_index}'
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py
index fbb6070f7e82c9a41848c626c6271d1a7b9d73ee..e5b7e6f25d4d9007c7e7900632345190336dc160 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py
@@ -18,7 +18,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
-__all__ = ['LayerNormGenerator']
+__all__ = ["LayerNormGenerator"]
class LayerNormGenerator(StrategyGenerator):
@@ -31,21 +31,21 @@ class LayerNormGenerator(StrategyGenerator):
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the computation cost per device with this specific strategy.
- Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
- '''
- # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
+        Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
+        """
+        # TODO: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
        # TODO: a constant coefficient needs to be added.
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_weight_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_weight_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device()
if self.has_bias:
            # bias add is an element-wise operation, so the cost is equal to the product of the output shape.
bias_compute_cost = reduce(operator.mul, sharded_weight_shape)
        # in LayerNorm context, batch dimensions mean all the dimensions that do not join the normalization.
- input_batch_shape = sharded_input_shape[:-len(sharded_weight_shape)]
+ input_batch_shape = sharded_input_shape[: -len(sharded_weight_shape)]
input_batch_product = reduce(operator.mul, input_batch_shape, 1)
norm_kernel_product = reduce(operator.mul, sharded_weight_shape, 1)
forward_compute_cost = input_batch_product * norm_kernel_product
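+        # e.g. for a sharded input (B, S, H) with weight (H,): input_batch_product = B * S and
+        # norm_kernel_product = H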
@@ -62,18 +62,18 @@ class LayerNormGenerator(StrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'other': self._compute_size_in_bytes(strategy, "other"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "other": self._compute_size_in_bytes(strategy, "other"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
- forward_size_mapping['bias'] = bias_size
+ forward_size_mapping["bias"] = bias_size
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
@@ -90,8 +90,9 @@ class LayerNormGenerator(StrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@@ -120,7 +121,8 @@ class LayerNormGenerator(StrategyGenerator):
sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=total_mesh_dim_list,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
@@ -128,12 +130,15 @@ class LayerNormGenerator(StrategyGenerator):
sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=total_mesh_dim_list,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
communication_action_mapping["bias"] = bias_comm_action
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
@@ -155,7 +160,7 @@ class LayerNormGenerator(StrategyGenerator):
@ignore_sharding_exception
def non_split(self):
- name = f'RR = RR x R'
+ name = f"RR = RR x R"
dim_partition_dict_mapping = {
"input": {},
"other": {},
@@ -168,14 +173,16 @@ class LayerNormGenerator(StrategyGenerator):
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def collate_strategies(self) -> List[ShardingStrategy]:
- '''
+ """
        Generate every possible strategy for a LayerNorm node, and record all strategies into the strategies_vector.
- '''
+ """
strategy_list = []
input_data_dim = len(self.op_data["input"].logical_shape)
weight_data_dim = len(self.op_data["other"].logical_shape)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
index 1ce5a08f2d6b70d20f10476309034ab1a26b75d1..fb182afb917561b6836a18c7abc483d499ab8464 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
@@ -1,5 +1,4 @@
import operator
-from ast import arg
from functools import reduce
from typing import List
@@ -24,14 +23,14 @@ class MatMulStrategyGenerator(StrategyGenerator):
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'other': self._compute_size_in_bytes(strategy, "other"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "other": self._compute_size_in_bytes(strategy, "other"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
if self.has_bias:
bias_size = self._compute_size_in_bytes(strategy, "bias")
- size_mapping['bias'] = bias_size
+ size_mapping["bias"] = bias_size
# compute fwd cost incurred
# fwd_cost = input + other + bias + output
@@ -41,45 +40,47 @@ class MatMulStrategyGenerator(StrategyGenerator):
# compute bwd cost incurred
        # bwd_cost = input_grad + other_grad + bias_grad
- bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ['input', 'other', 'bias']])
+ bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ["input", "other", "bias"]])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + 0)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + 0
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
class DotProductStrategyGenerator(MatMulStrategyGenerator):
-
def validate(self) -> bool:
- input_op_data = self.op_data['input']
- other_op_data = self.op_data['other']
+ input_op_data = self.op_data["input"]
+ other_op_data = self.op_data["other"]
assert input_op_data.data.dim() == 1 and other_op_data.data.dim() == 1
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
fwd_compute_cost = sharded_input_shape[0]
bwd_compute_cost = fwd_compute_cost * 2
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
return compute_cost
@ignore_sharding_exception
def no_split(self):
- name = f'R = R dot R'
- dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}}
+ name = f"R = R dot R"
+ dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_one_dim(self, mesh_dim):
- name = f'R = S{mesh_dim} dot S{mesh_dim}'
+ name = f"R = S{mesh_dim} dot S{mesh_dim}"
# get sharding spec
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}, "bias": {0: [mesh_dim]}}
@@ -87,14 +88,17 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
# get communication action
output_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['output'],
+ sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
communication_action_mapping = {"output": output_comm_action}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
@@ -112,19 +116,18 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
class MatVecStrategyGenerator(MatMulStrategyGenerator):
-
def validate(self) -> bool:
- input_op_data = self.op_data['input']
- other_op_data = self.op_data['other']
+ input_op_data = self.op_data["input"]
+ other_op_data = self.op_data["other"]
assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
fwd_compute_cost = sharded_input_shape[0]
bwd_compute_cost = fwd_compute_cost * 2
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
return compute_cost
@ignore_sharding_exception
@@ -133,67 +136,69 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
dim_partition_dict = {"input": {}, "other": {}, "output": {}}
if self.has_bias:
- dim_partition_dict['bias'] = {}
+ dim_partition_dict["bias"] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping={})
+ return self.get_sharding_strategy(
+ name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={}
+ )
@ignore_sharding_exception
def split_input_batch(self, mesh_dim):
- name = f'S{mesh_dim}R = S{mesh_dim}R x R'
+ name = f"S{mesh_dim}R = S{mesh_dim}R x R"
# get sharding spec
dim_partition_dict = {
- "input": {
- 0: [mesh_dim]
- },
+ "input": {0: [mesh_dim]},
"other": {},
- "output": {
- 0: [mesh_dim]
- },
+ "output": {0: [mesh_dim]},
}
if self.has_bias:
- dim_partition_dict['bias'] = {}
+ dim_partition_dict["bias"] = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
# get communication action
communication_action_mapping = {}
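+        # a parameter gets its backward all-reduce registered as a gradient hook (CommType.HOOK),
+        # while a non-parameter argument gets the comm op inserted before this node (CommType.BEFORE)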
- if self.is_param('other'):
+ if self.is_param("other"):
other_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['other'],
+ sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['other'],
+ sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
- arg_index=1)
- communication_action_mapping['other'] = other_comm_action
+ arg_index=1,
+ )
+ communication_action_mapping["other"] = other_comm_action
if self.has_bias:
- if self.is_param('bias'):
+ if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
- arg_index=2)
- communication_action_mapping['bias'] = bias_comm_action
+ arg_index=2,
+ )
+ communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
@@ -209,12 +214,13 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
-
- def __init__(self,
- operation_data_mapping,
- device_mesh,
- linear_projection_type='linear',
- solver_perference=SolverPerference.STANDARD):
+ def __init__(
+ self,
+ operation_data_mapping,
+ device_mesh,
+ linear_projection_type="linear",
+ solver_perference=SolverPerference.STANDARD,
+ ):
super().__init__(operation_data_mapping, device_mesh)
self.linear_projection_type = linear_projection_type
self.solver_perference = solver_perference
@@ -224,17 +230,17 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# C: [M, N], A: [M, P], B: [P, N]
# fwd cost = MNP (only count mul)
# bwd: 2 x fwd_cost
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device()
dim_m_val = reduce(operator.mul, sharded_input_shape[:-1])
dim_n_val = sharded_other_shape[-1]
dim_p_val = sharded_other_shape[0]
fwd_compute_cost = dim_m_val * dim_n_val * dim_p_val
bwd_compute_cost = fwd_compute_cost * 2
- compute_cost = TrainCycleItem(fwd=bwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+            fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
strategy.compute_cost = compute_cost
def dp_strategies(self) -> List[ShardingStrategy]:
@@ -301,28 +307,21 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
@ignore_sharding_exception
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS
- name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
+ name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}"
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0]
- },
- "other": {
- -1: [mesh_dim_1]
- },
- "output": {
- 0: [mesh_dim_0],
- -1: [mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0]},
+ "other": {-1: [mesh_dim_1]},
+ "output": {0: [mesh_dim_0], -1: [mesh_dim_1]},
}
        # linear bias only has one dimension, but addmm bias has the same dimensions
# as the output logically.
- if self.linear_projection_type == 'linear':
- dim_partition_dict_mapping['bias'] = {-1: [mesh_dim_1]}
- elif self.linear_projection_type == 'addmm':
- dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0], -1: [mesh_dim_1]}
+ if self.linear_projection_type == "linear":
+ dim_partition_dict_mapping["bias"] = {-1: [mesh_dim_1]}
+ elif self.linear_projection_type == "addmm":
+ dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0], -1: [mesh_dim_1]}
else:
- raise ('Unsupported linear projection type')
+ raise ("Unsupported linear projection type")
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
@@ -333,75 +332,75 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
- if self.is_param('other'):
+ if self.is_param("other"):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
- communication_action_mapping['input'] = input_comm_action
- communication_action_mapping['other'] = other_comm_action
+ communication_action_mapping["input"] = input_comm_action
+ communication_action_mapping["other"] = other_comm_action
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
- if self.has_bias and self.linear_projection_type == 'linear':
- if self.is_param('bias'):
+ if self.has_bias and self.linear_projection_type == "linear":
+ if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
- communication_action_mapping['bias'] = bias_comm_action
+ key_for_kwarg="bias",
+ )
+ communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR
- name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
+ name = f"S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R"
# get sharding spec mapping
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0],
- -1: [mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0], -1: [mesh_dim_1]},
+ "other": {0: [mesh_dim_1]},
"bias": {},
- "output": {
- 0: [mesh_dim_0]
- },
+ "output": {0: [mesh_dim_0]},
}
        # linear bias only has one dimension, but addmm bias has the same dimensions
# as the output logically.
- if self.linear_projection_type == 'linear':
- dim_partition_dict_mapping['bias'] = {}
- elif self.linear_projection_type == 'addmm':
- dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0]}
+ if self.linear_projection_type == "linear":
+ dim_partition_dict_mapping["bias"] = {}
+ elif self.linear_projection_type == "addmm":
+ dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0]}
else:
- raise ('Unsupported linear projection type')
+ raise ("Unsupported linear projection type")
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
@@ -412,66 +411,64 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
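+        # the all-reduce above is required because contracting the dimension sharded along
+        # mesh_dim_1 leaves each device with partial sums after the forward matmul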
- if self.is_param('other'):
+ if self.is_param("other"):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=1)
+ arg_index=1,
+ )
- communication_action_mapping['other'] = other_comm_action
- communication_action_mapping['output'] = output_comm_action
+ communication_action_mapping["other"] = other_comm_action
+ communication_action_mapping["output"] = output_comm_action
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
- if self.has_bias and self.linear_projection_type == 'linear':
- if self.is_param('bias'):
+ if self.has_bias and self.linear_projection_type == "linear":
+ if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
- communication_action_mapping['bias'] = bias_comm_action
+ key_for_kwarg="bias",
+ )
+ communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
+ name = f"RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}"
# get sharding specs
dim_partition_dict_mapping = {
- "input": {
- -1: [mesh_dim_0]
- },
- "other": {
- 0: [mesh_dim_0],
- -1: [mesh_dim_1]
- },
- "bias": {
- -1: [mesh_dim_1]
- },
- "output": {
- -1: [mesh_dim_1]
- },
+ "input": {-1: [mesh_dim_0]},
+ "other": {0: [mesh_dim_0], -1: [mesh_dim_1]},
+ "bias": {-1: [mesh_dim_1]},
+ "output": {-1: [mesh_dim_1]},
}
# We don't have to do anything special for bias here, because
@@ -482,34 +479,34 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['output'],
+ sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
input_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['input'],
+ sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
communication_action_mapping["input"] = input_comm_action
- communication_action_mapping['output'] = output_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ communication_action_mapping["output"] = output_comm_action
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def recompute_split_both_contract(self, mesh_dim):
- name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
+ name = f"RR = RS{mesh_dim} x S{mesh_dim}R"
# get sharding spec
dim_partition_dict_mapping = {
- "input": {
- -1: [mesh_dim]
- },
- "other": {
- 0: [mesh_dim]
- },
+ "input": {-1: [mesh_dim]},
+ "other": {0: [mesh_dim]},
"bias": {},
"output": {},
}
@@ -520,32 +517,29 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['output'],
+ sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim,
- comm_type=CommType.AFTER)
+ comm_type=CommType.AFTER,
+ )
- communication_action_mapping['output'] = output_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ communication_action_mapping["output"] = output_comm_action
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_rhs_space_only(self, mesh_dim):
- name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
+ name = f"RS{mesh_dim} = RR x RS{mesh_dim}"
# get sharding spec
dim_partition_dict_mapping = {
"input": {},
- "other": {
- -1: [mesh_dim]
- },
- "bias": {
- -1: [mesh_dim]
- },
- "output": {
- -1: [mesh_dim]
- },
+ "other": {-1: [mesh_dim]},
+ "bias": {-1: [mesh_dim]},
+ "output": {-1: [mesh_dim]},
}
# We don't have to do anything special for bias here, because
# the bias is already the same sharding spec as the output.
@@ -554,93 +548,94 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
input_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['input'],
+ sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
- communication_action_mapping['input'] = input_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ communication_action_mapping["input"] = input_comm_action
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
+ name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR"
# get sharding spec
dim_partition_dict_mapping = {
- "input": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0, mesh_dim_1]},
"other": {},
"bias": {},
- "output": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "output": {0: [mesh_dim_0, mesh_dim_1]},
}
        # linear bias only has one dimension, but addmm bias has the same dimensions
# as the output logically.
- if self.linear_projection_type == 'linear':
- dim_partition_dict_mapping['bias'] = {}
- elif self.linear_projection_type == 'addmm':
- dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0, mesh_dim_1]}
+ if self.linear_projection_type == "linear":
+ dim_partition_dict_mapping["bias"] = {}
+ elif self.linear_projection_type == "addmm":
+ dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0, mesh_dim_1]}
else:
- raise ('Unsupported linear projection type')
+ raise ("Unsupported linear projection type")
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
communication_action_mapping = {}
- if self.is_param('other'):
+ if self.is_param("other"):
other_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['other'],
+ sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
other_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['other'],
+ sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=1)
- communication_action_mapping['other'] = other_comm_action
+ arg_index=1,
+ )
+ communication_action_mapping["other"] = other_comm_action
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
- if self.has_bias and self.linear_projection_type == 'linear':
- if self.is_param('bias'):
+ if self.has_bias and self.linear_projection_type == "linear":
+ if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.HOOK)
+ comm_type=CommType.HOOK,
+ )
else:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- key_for_kwarg='bias')
- communication_action_mapping['bias'] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ key_for_kwarg="bias",
+ )
+ communication_action_mapping["bias"] = bias_comm_action
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
+ name = f"RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R"
# get sharding spec
dim_partition_dict_mapping = {
- "input": {
- -1: [mesh_dim_0, mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "input": {-1: [mesh_dim_0, mesh_dim_1]},
+ "other": {0: [mesh_dim_0, mesh_dim_1]},
"bias": {},
"output": {},
}
@@ -652,32 +647,29 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['output'],
+ sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
- comm_type=CommType.AFTER)
- communication_action_mapping['output'] = output_comm_action
+ comm_type=CommType.AFTER,
+ )
+ communication_action_mapping["output"] = output_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
+ name = f"RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}"
# get sharding spec
dim_partition_dict_mapping = {
"input": {},
- "other": {
- -1: [mesh_dim_0, mesh_dim_1]
- },
- "bias": {
- -1: [mesh_dim_0, mesh_dim_1]
- },
- "output": {
- -1: [mesh_dim_0, mesh_dim_1]
- },
+ "other": {-1: [mesh_dim_0, mesh_dim_1]},
+ "bias": {-1: [mesh_dim_0, mesh_dim_1]},
+ "output": {-1: [mesh_dim_0, mesh_dim_1]},
}
# We don't have to do anything special for bias here, because
@@ -687,20 +679,23 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action
communication_action_mapping = {}
input_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['input'],
+ sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=0)
- communication_action_mapping['input'] = input_comm_action
+ arg_index=0,
+ )
+ communication_action_mapping["input"] = input_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def non_split(self):
- name = f'RR = RR x RR'
+ name = f"RR = RR x RR"
# get sharding spec
dim_partition_dict_mapping = {
@@ -717,22 +712,24 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action
communication_action_mapping = {}
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def validate(self) -> bool:
assert "input" in self.op_data
assert "other" in self.op_data
        # make sure the other operand has 2 dims
- input_data = self.op_data['input']
- other_data = self.op_data['other']
+ input_data = self.op_data["input"]
+ other_data = self.op_data["other"]
assert input_data.data.dim() > 0 and other_data.data.dim() == 2
assert other_data.logical_shape[0] == input_data.logical_shape[-1]
if self.has_bias:
- bias_data = self.op_data['bias']
+ bias_data = self.op_data["bias"]
assert bias_data.logical_shape[-1] == other_data.logical_shape[-1]
@@ -757,37 +754,38 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
def _pop_batch_dim_sharding_for_output(self, dim_partition_dict):
# remove partition dict for dim 0
- dim_partition_dict['output'].pop(0, None)
+ dim_partition_dict["output"].pop(0, None)
# decrease the remaining dim index by 1
temp_dim_partition = {}
- keys = list(dim_partition_dict['output'].keys())
+ keys = list(dim_partition_dict["output"].keys())
for key in keys:
- val = dim_partition_dict['output'].pop(key)
+ val = dim_partition_dict["output"].pop(key)
temp_dim_partition[key - 1] = val
- dim_partition_dict['output'].update(temp_dim_partition)
+ dim_partition_dict["output"].update(temp_dim_partition)
def validate(self) -> bool:
- input_op_data = self.op_data['input']
- other_op_data = self.op_data['other']
+ input_op_data = self.op_data["input"]
+ other_op_data = self.op_data["other"]
assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3
- if 'bias' in self.op_data:
- bias_op_data = self.op_data['bias']
+ if "bias" in self.op_data:
+ bias_op_data = self.op_data["bias"]
assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
- fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul,
- self.op_data['output'].data.shape)
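+        # one multiply per output element per contracted element: cost = k * prod(output.shape),
+        # where k = input.shape[-1]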
+ fwd_compute_cost = self.op_data["input"].data.shape[-1] * reduce(
+ operator.mul, self.op_data["output"].data.shape
+ )
bwd_compute_cost = fwd_compute_cost * 2
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
- bwd=bwd_compute_cost,
- total=fwd_compute_cost + bwd_compute_cost)
+ compute_cost = TrainCycleItem(
+ fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
+ )
strategy.compute_cost = compute_cost
@ignore_sharding_exception
def split_one_batch_dim(self, mesh_dim):
- name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
+ name = f"Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}"
# get sharding_spec
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
@@ -799,30 +797,27 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
communication_action_mapping = {}
if self.has_bias:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim,
comm_type=CommType.BEFORE,
- arg_index=0)
- communication_action_mapping['bias'] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ arg_index=0,
+ )
+ communication_action_mapping["bias"] = bias_comm_action
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
- name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
+ name = f"Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}"
dim_partition_dict = {
- "input": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0, mesh_dim_1]},
+ "other": {0: [mesh_dim_0, mesh_dim_1]},
"bias": {},
- "output": {
- 0: [mesh_dim_0, mesh_dim_1]
- }
+ "output": {0: [mesh_dim_0, mesh_dim_1]},
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
@@ -832,35 +827,28 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
communication_action_mapping = {}
if self.has_bias:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=0)
- communication_action_mapping['bias'] = bias_comm_action
+ arg_index=0,
+ )
+ communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
- name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
+ name = f"Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}"
dim_partition_dict = {
- "input": {
- 0: [mesh_dim_0],
- 1: [mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_0]
- },
- "bias": {
- 0: [mesh_dim_1]
- },
- "output": {
- 0: [mesh_dim_0],
- 1: [mesh_dim_1]
- }
+ "input": {0: [mesh_dim_0], 1: [mesh_dim_1]},
+ "other": {0: [mesh_dim_0]},
+ "bias": {0: [mesh_dim_1]},
+ "output": {0: [mesh_dim_0], 1: [mesh_dim_1]},
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
@@ -869,46 +857,40 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
other_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['other'],
+ sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=1)
- communication_action_mapping['other'] = other_comm_action
+ arg_index=1,
+ )
+ communication_action_mapping["other"] = other_comm_action
if self.has_bias:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
- arg_index=0)
- communication_action_mapping['bias'] = bias_comm_action
+ arg_index=0,
+ )
+ communication_action_mapping["bias"] = bias_comm_action
# for addbmm case, other is the third argument instead of second.
- communication_action_mapping['other'].arg_index += 1
+ communication_action_mapping["other"].arg_index += 1
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
- name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
+ name = f"Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}"
dim_partition_dict = {
- "input": {
- 0: [mesh_dim_0]
- },
- "other": {
- 0: [mesh_dim_0],
- 2: [mesh_dim_1]
- },
- "bias": {
- 1: [mesh_dim_1]
- },
- "output": {
- 0: [mesh_dim_0],
- 2: [mesh_dim_1]
- }
+ "input": {0: [mesh_dim_0]},
+ "other": {0: [mesh_dim_0], 2: [mesh_dim_1]},
+ "bias": {1: [mesh_dim_1]},
+ "output": {0: [mesh_dim_0], 2: [mesh_dim_1]},
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
@@ -917,43 +899,41 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
input_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['input'],
+ sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
- arg_index=0)
- communication_action_mapping['input'] = input_comm_action
+ arg_index=0,
+ )
+ communication_action_mapping["input"] = input_comm_action
if self.has_bias:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
- comm_type=CommType.BEFORE)
- communication_action_mapping['bias'] = bias_comm_action
+ comm_type=CommType.BEFORE,
+ )
+ communication_action_mapping["bias"] = bias_comm_action
# for addbmm case, other is the second argument instead of first.
- communication_action_mapping['input'].arg_index += 1
+ communication_action_mapping["input"].arg_index += 1
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
@ignore_sharding_exception
def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
- name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
+ name = f"Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}"
dim_partition_dict = {
- "input": {
- 0: [mesh_dim_0],
- 2: [mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_0],
- 1: [mesh_dim_1]
- },
+ "input": {0: [mesh_dim_0], 2: [mesh_dim_1]},
+ "other": {0: [mesh_dim_0], 1: [mesh_dim_1]},
"bias": {},
"output": {
0: [mesh_dim_0],
- }
+ },
}
if self.squeeze_batch_dim:
self._pop_batch_dim_sharding_for_output(dim_partition_dict)
@@ -962,29 +942,33 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
communication_action_mapping = {}
output_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['output'],
+ sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1,
- comm_type=CommType.AFTER)
- communication_action_mapping['output'] = output_comm_action
+ comm_type=CommType.AFTER,
+ )
+ communication_action_mapping["output"] = output_comm_action
if self.has_bias:
bias_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping['bias'],
+ sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
- arg_index=0)
- communication_action_mapping['bias'] = bias_comm_action
+ arg_index=0,
+ )
+ communication_action_mapping["bias"] = bias_comm_action
- return self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ return self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
device_mesh_is_1d = True
- if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
+ if len(self.device_mesh.shape) == 2 and 1 not in self.device_mesh.shape:
device_mesh_is_1d = False
if device_mesh_is_1d:
@@ -992,10 +976,10 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# Sb = Sb x Sb
# mesh_dim can be None as it is only needed for a 1D device mesh
- if len(self.device_mesh.mesh_shape) == 1:
+ if len(self.device_mesh.shape) == 1:
mesh_dim = 0
else:
- mesh_dim = self.device_mesh.mesh_shape.index(1)
+ mesh_dim = self.device_mesh.shape.index(1)
strategy_list.append(self.split_one_batch_dim(mesh_dim))
else:
# for 2D device mesh
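
For reference, the 1D-mesh check in `collate_strategies` above reduces to the following standalone sketch; `pick_batch_mesh_dim` is a hypothetical helper name, not part of the generator:

```python
from typing import Optional, Tuple

def pick_batch_mesh_dim(mesh_shape: Tuple[int, ...]) -> Optional[int]:
    """Return the mesh axis used for split_one_batch_dim, or None for a true 2D mesh."""
    # A mesh only counts as 2D when it has two axes and neither is trivial.
    if len(mesh_shape) == 2 and 1 not in mesh_shape:
        return None  # 2D case: split_one_batch_dim is skipped
    # 1D case: a single axis, or a 2D mesh with a degenerate axis of size 1.
    if len(mesh_shape) == 1:
        return 0
    return mesh_shape.index(1)  # index of the size-1 axis, as in the source

assert pick_batch_mesh_dim((4,)) == 0       # plain 1D mesh
assert pick_batch_mesh_dim((1, 4)) == 0     # degenerate 2D mesh, axis 0 is trivial
assert pick_batch_mesh_dim((2, 2)) is None  # genuine 2D mesh
```
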
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py
index 9df6d2fbfa127b71eba66256fb27f204fb1da5fe..b307e38b5b6d15bd760b9fc03d0bbe010106b6df 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py
@@ -17,32 +17,35 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
"""
NormalPoolStrategyGenerator is a generic class to generate strategies for pool operation like MaxPoolxd.
The reason we call this a normal pool is that AvgPoolxd and MaxPoolxd take the kernel-size elements from the image,
- and reduce them depening on the operation type.
+ and reduce them depending on the operation type.
"""
def validate(self) -> bool:
- '''
+ """
In the sanity check, we need to make sure the input data has the correct dimension size.
For Pool1d, the dim of the input data should be 3 ([N, C, L]).
For Pool2d, the dim of the input data should be 4 ([N, C, H, W]).
For Pool3d, the dim of the input data should be 5 ([N, C, H, W, D]).
- '''
- input_op_data = self.op_data['input']
+ """
+ input_op_data = self.op_data["input"]
assert input_op_data.data.dim() in (
- 3, 4, 5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].'
+ 3,
+ 4,
+ 5,
+ ), f"We suppose the dim of input fed into Pool op should in range of [3, 5]."
def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem:
- '''
+ """
Compute the computation cost per device with this specific strategy.
- Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
- '''
- # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
+ Note: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
+ """
+ # TODO: compute_cost needs to be divided by TFLOPS; for now it just shows the computation size.
# 1D: (Lout) * N * C * kernel
# 2D: (H * W) * N * C * kernel
# 3D: (H * W * D) * N * C * kernel
- sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+ sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
kernel_size = self.op_data["other"].data
if isinstance(kernel_size, int):
@@ -61,8 +64,8 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -88,12 +91,16 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
+ name = (
+ f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
+ )
communication_action_mapping = {}
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
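
As a rough reading of the 1D/2D/3D cost comments in `update_compute_cost` above, every sharded output element reduces one kernel window. A minimal sketch under that assumption, mirroring the `isinstance(kernel_size, int)` normalization in the source; `pool_compute_cost` is a hypothetical name and models only the output term, while the real generator also uses the sharded input shape:

```python
import operator
from functools import reduce

def pool_compute_cost(sharded_output_shape, kernel_size):
    """Rough per-device computation size: numel(sharded output) * prod(kernel)."""
    if isinstance(kernel_size, int):
        # e.g. MaxPool2d(kernel_size=3) means a 3x3 window
        spatial_dims = len(sharded_output_shape) - 2  # drop N and C
        kernel_size = [kernel_size] * spatial_dims
    kernel_product = reduce(operator.mul, kernel_size)
    output_product = reduce(operator.mul, sharded_output_shape)
    return output_product * kernel_product

# [N, C, H, W] output shard of shape [2, 16, 8, 8] with a 3x3 kernel:
assert pool_compute_cost([2, 16, 8, 8], 3) == 2 * 16 * 8 * 8 * 9
```
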
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py
index 69d1642d4f808038d0eeb58547a7b1c0604c85eb..33fb1ac5c5be779a5f65b44431e4ba84c878fab9 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py
@@ -12,7 +12,7 @@ from colossalai.device.device_mesh import DeviceMesh
from .strategy_generator import OutputStrategyGenerator
-__all__ = ['OutputGenerator']
+__all__ = ["OutputGenerator"]
class OutputGenerator(OutputStrategyGenerator):
@@ -20,8 +20,13 @@ class OutputGenerator(OutputStrategyGenerator):
OutputGenerator is a generic class to generate strategies for Output Node.
"""
- def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
- predecessor_nodes: List[Node], output_option: str):
+ def __init__(
+ self,
+ operation_data_mapping: Dict[str, OperationData],
+ device_mesh: DeviceMesh,
+ predecessor_nodes: List[Node],
+ output_option: str,
+ ):
super().__init__(operation_data_mapping, device_mesh, predecessor_nodes)
self.output_option = output_option
@@ -33,9 +38,9 @@ class OutputGenerator(OutputStrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
fwd_mem_cost = MemoryCost(activation=0, parameter=0)
bwd_mem_cost = MemoryCost(activation=0, parameter=0)
@@ -65,16 +70,18 @@ class OutputGenerator(OutputStrategyGenerator):
else:
dim_partition_dict_for_output = tuple(dim_partition_dict_for_output)
- dim_partition_dict_mapping['output'] = dim_partition_dict_for_output
+ dim_partition_dict_mapping["output"] = dim_partition_dict_for_output
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- name = 'Replica Output'
+ name = "Replica Output"
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[ShardingStrategy]:
@@ -82,19 +89,15 @@ class OutputGenerator(OutputStrategyGenerator):
Generate distributed strategy for output node.
"""
# TODO: need to take care of the case when only the first element of the output needs to be sharded.
- output_op_data = self.op_data['output']
+ output_op_data = self.op_data["output"]
if isinstance(output_op_data.data, tuple):
length = len(output_op_data.data)
dim_partition_dict_mapping = {
- "output": [{
- 0: mesh_list
- }] * length,
+ "output": [{0: mesh_list}] * length,
}
else:
dim_partition_dict_mapping = {
- "output": {
- 0: mesh_list
- },
+ "output": {0: mesh_list},
}
for index, _ in enumerate(self.predecessor_nodes):
mapping_name = f"input_{index}"
@@ -103,19 +106,21 @@ class OutputGenerator(OutputStrategyGenerator):
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- name = 'Distributed Output'
+ name = "Distributed Output"
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
mesh_list = [0, 1]
- if self.output_option == 'replicated':
+ if self.output_option == "replicated":
strategy_list.append(self.replica_strategy())
- elif self.output_option == 'distributed':
+ elif self.output_option == "distributed":
strategy_list.append(self.distributed_strategy(mesh_list))
return strategy_list
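
The two `dim_partition_dict` shapes built above are easier to read with a concrete tensor. A hedged sketch, assuming a 2 x 2 device mesh flattened over `mesh_list = [0, 1]`; the `sharded_shape` helper is illustrative only:

```python
# 'replicated': an empty partition dict, so every device holds the full tensor.
# 'distributed': dim 0 sharded over both mesh axes, so dim 0 is split 4 ways.
global_shape = (8, 16)
replicated_partition = {}
distributed_partition = {0: [0, 1]}

def sharded_shape(shape, partition, mesh_shape=(2, 2)):
    shape = list(shape)
    for dim, axes in partition.items():
        for axis in axes:
            shape[dim] //= mesh_shape[axis]  # each mesh axis splits the dim further
    return tuple(shape)

assert sharded_shape(global_shape, replicated_partition) == (8, 16)
assert sharded_shape(global_shape, distributed_partition) == (2, 16)
```
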
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py
index 779a7ced93bb503c390bd89382d087230e48d2f0..df0862a396d2553ed7939dd09bea45d2317df211 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py
@@ -10,7 +10,7 @@ from colossalai.device.device_mesh import DeviceMesh
from .strategy_generator import StrategyGenerator
-__all__ = ['PlaceholderGenerator']
+__all__ = ["PlaceholderGenerator"]
class PlaceholderGenerator(StrategyGenerator):
@@ -18,8 +18,9 @@ class PlaceholderGenerator(StrategyGenerator):
PlaceholderGenerator is a generic class to generate strategies for placeholder node.
"""
- def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
- placeholder_option: str):
+ def __init__(
+ self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, placeholder_option: str
+ ):
super().__init__(operation_data_mapping, device_mesh)
self.placeholder_option = placeholder_option
@@ -31,10 +32,10 @@ class PlaceholderGenerator(StrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
- forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
+ """
+ forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")}
# compute fwd cost incurred
# fwd_cost = output
@@ -58,11 +59,13 @@ class PlaceholderGenerator(StrategyGenerator):
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- name = 'Replica Placeholder'
+ name = "Replica Placeholder"
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
@@ -71,29 +74,31 @@ class PlaceholderGenerator(StrategyGenerator):
Generate distributed strategy for placeholder node.
"""
dim_partition_dict_mapping = {
- "output": {
- 0: mesh_list
- },
+ "output": {0: mesh_list},
}
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- name = 'Distributed Placeholder'
+ name = "Distributed Placeholder"
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
- if self.placeholder_option == 'distributed':
+ if self.placeholder_option == "distributed":
mesh_list = [0, 1]
distributed_strategy = self.distributed_placeholder(mesh_list)
strategy_list.append(distributed_strategy)
else:
- assert self.placeholder_option == 'replicated', f'placeholder_option {self.placeholder_option} is not supported'
+ assert (
+ self.placeholder_option == "replicated"
+ ), f"placeholder_option {self.placeholder_option} is not supported"
replicated_strategy = self.replica_placeholder()
strategy_list.append(replicated_strategy)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
index 24f75e352935f149e02c399b2c8e90c0f3ddc2f7..48f454553ac745fc5bfc84a59def4db7790897f0 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
@@ -17,7 +17,7 @@ from colossalai.auto_parallel.tensor_shard.utils import (
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
-__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator']
+__all__ = ["ReshapeGenerator", "ViewGenerator", "PermuteGenerator", "TransposeGenerator", "SplitGenerator"]
class ReshapeGenerator(FollowingStrategyGenerator):
@@ -33,12 +33,12 @@ class ReshapeGenerator(FollowingStrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -56,8 +56,9 @@ class ReshapeGenerator(FollowingStrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@@ -77,8 +78,8 @@ class ViewGenerator(ReshapeGenerator):
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
- origin_shape = self.op_data['input'].data.shape
- tgt_shape = self.op_data['tgt_shape'].data
+ origin_shape = self.op_data["input"].data.shape
+ tgt_shape = self.op_data["tgt_shape"].data
reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
@@ -86,8 +87,9 @@ class ViewGenerator(ReshapeGenerator):
keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict)
if keep_sharding_status:
- dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input,
- reshape_mapping_dict)
+ dim_partition_dict_for_output = infer_output_dim_partition_dict(
+ dim_partition_dict_for_input, reshape_mapping_dict
+ )
else:
dim_partition_dict_for_output = {}
@@ -119,7 +121,8 @@ class ViewGenerator(ReshapeGenerator):
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
# it will gather the input through gather_dim during forward phase.
input_comm_action.comm_spec.gather_dim = shard_dim
# it will split the input activation grad through shard_dim during backward phase.
@@ -127,10 +130,10 @@ class ViewGenerator(ReshapeGenerator):
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
- target_spec = ShardingSpec(device_mesh=self.device_mesh,
- entire_shape=source_spec.entire_shape,
- dim_partition_dict={})
- comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
+ target_spec = ShardingSpec(
+ device_mesh=self.device_mesh, entire_shape=source_spec.entire_shape, dim_partition_dict={}
+ )
+ comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
@@ -139,9 +142,11 @@ class ViewGenerator(ReshapeGenerator):
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
@@ -159,7 +164,7 @@ class PermuteGenerator(ReshapeGenerator):
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
- permute_dims = self.op_data['permute_dims'].data
+ permute_dims = self.op_data["permute_dims"].data
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
for dim_index, permute_dim in enumerate(permute_dims):
@@ -177,9 +182,11 @@ class PermuteGenerator(ReshapeGenerator):
# because in the solver this node will be merged into other nodes, and the solver will not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
@@ -199,7 +206,7 @@ class TransposeGenerator(ReshapeGenerator):
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
- transpose_dims = self.op_data['transpose_dims'].data
+ transpose_dims = self.op_data["transpose_dims"].data
dim_0 = transpose_dims[0]
dim_1 = transpose_dims[1]
for dim, sharded_dims in dim_partition_dict_for_input.items():
@@ -221,9 +228,11 @@ class TransposeGenerator(ReshapeGenerator):
# because in the solver this node will be merged into other nodes, and the solver will not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
@@ -242,7 +251,7 @@ class SplitGenerator(ReshapeGenerator):
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
- split_size, split_dim = self.op_data['split_info'].data
+ split_size, split_dim = self.op_data["split_info"].data
if split_dim in dim_partition_dict_for_input:
recover_dims = dim_partition_dict_for_input.pop(split_dim)
@@ -271,7 +280,8 @@ class SplitGenerator(ReshapeGenerator):
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=recover_dims,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
# it will gather the input through gather_dim during forward phase.
input_comm_action.comm_spec.gather_dim = split_dim
# it will split the input activation grad through split_dim during backward phase.
@@ -282,7 +292,7 @@ class SplitGenerator(ReshapeGenerator):
source_spec = input_sharding_spec
# target sharding spec
target_spec = sharding_spec_mapping["input"]
- comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
+ comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
@@ -291,9 +301,11 @@ class SplitGenerator(ReshapeGenerator):
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
@@ -341,16 +353,17 @@ class DefaultReshapeGenerator(ReshapeGenerator):
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.BEFORE,
- arg_index=0)
+ arg_index=0,
+ )
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
input_comm_action.comm_spec.shard_dim = total_mesh_dim_list
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
- target_spec = ShardingSpec(device_mesh=self.device_mesh,
- entire_shape=source_spec.entire_shape,
- dim_partition_dict={})
- comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
+ target_spec = ShardingSpec(
+ device_mesh=self.device_mesh, entire_shape=source_spec.entire_shape, dim_partition_dict={}
+ )
+ comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
@@ -358,9 +371,11 @@ class DefaultReshapeGenerator(ReshapeGenerator):
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
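
The `GATHER_FWD_SPLIT_BWD` pattern recurring in this file pairs a forward gather along `gather_dim` with a backward split of the activation grad along `shard_dim`, per the comments above. A single-process toy sketch of that shape bookkeeping, not the actual `CommSpec` machinery (the all-gather is emulated by repeating the local shard):

```python
import torch

def gather_fwd(shard: torch.Tensor, world_size: int, gather_dim: int) -> torch.Tensor:
    # Forward: all-gather the shards along gather_dim (emulated on one process).
    return torch.cat([shard] * world_size, dim=gather_dim)

def split_bwd(grad: torch.Tensor, world_size: int, shard_dim: int, rank: int) -> torch.Tensor:
    # Backward: each rank keeps only its slice of the gathered gradient.
    return torch.chunk(grad, world_size, dim=shard_dim)[rank]

shard = torch.randn(4, 8)                  # local shard, dim 0 sharded over 2 devices
full = gather_fwd(shard, world_size=2, gather_dim=0)
assert full.shape == (8, 8)                # the input is whole again before the reshape
grad = torch.ones_like(full)
assert split_bwd(grad, 2, shard_dim=0, rank=0).shape == (4, 8)
```
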
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py
index a1ebadd043e2c2e563fcdc611567bca3ededfa51..d4382f9941d2b21475ef7d077a0af92054aae9b0 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py
@@ -4,21 +4,9 @@ from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
- CommType,
- MemoryCost,
- ShardingStrategy,
- TrainCycleItem,
-)
-from colossalai.auto_parallel.tensor_shard.utils import (
- check_keep_sharding_status,
- detect_reshape_mapping,
- infer_output_dim_partition_dict,
-)
-from colossalai.tensor.shape_consistency import CollectiveCommPattern
-
-__all__ = ['SoftmaxGenerator']
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+
+__all__ = ["SoftmaxGenerator"]
class SoftmaxGenerator(FollowingStrategyGenerator):
@@ -30,11 +18,11 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the computation cost per device with this specific strategy.
- '''
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
+ """
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
output_size_product = reduce(operator.mul, sharded_output_shape)
@@ -45,12 +33,12 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -68,8 +56,9 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@@ -80,10 +69,10 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
- softmax_dim = self.op_data['softmax_dim'].data
+ softmax_dim = self.op_data["softmax_dim"].data
if softmax_dim in dim_partition_dict_for_input:
- recover_dims = dim_partition_dict_for_input.pop(softmax_dim)
+ dim_partition_dict_for_input.pop(softmax_dim)
dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input)
dim_partition_dict_mapping = {
@@ -96,9 +85,11 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
# because in the solver this node will be merged into other nodes, and the solver will not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
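
One step above deserves a gloss: when the softmax dim is sharded, its sharding is dropped, since softmax normalizes along that dim and each device needs the whole axis locally. A minimal sketch of that recovery:

```python
import copy

dim_partition_dict_for_input = {0: [0], 2: [1]}  # dim 2 sharded on mesh axis 1
softmax_dim = 2

dim_partition_dict = copy.deepcopy(dim_partition_dict_for_input)
if softmax_dim in dim_partition_dict:
    dim_partition_dict.pop(softmax_dim)          # replicate the softmax dim

assert dim_partition_dict == {0: [0]}
```
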
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
index 6d68521aaea7989c085c24f32a8dc92f4b1b71fc..7bf2c8cc12a39730ea79dccb2ef810b32e0b95a9 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
@@ -39,7 +39,7 @@ class StrategyGenerator(ABC):
"""
A utility method to check for the existence of bias operand for convenience.
"""
- return 'bias' in self.op_data
+ return "bias" in self.op_data
def is_param(self, op_data_name):
other_data = self.op_data[op_data_name]
@@ -49,8 +49,12 @@ class StrategyGenerator(ABC):
other_data = self.op_data[op_data_name]
return other_data.type == OperationDataType.BUFFER
- def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec],
- communication_action_mapping: Dict[str, CommSpec]):
+ def get_sharding_strategy(
+ self,
+ name: str,
+ sharding_spec_mapping: Dict[str, ShardingSpec],
+ communication_action_mapping: Dict[str, CommSpec],
+ ):
"""
A factory method to produce a ShardingStrategy object.
@@ -80,24 +84,28 @@ class StrategyGenerator(ABC):
op_data = self.op_data[op_data_name]
def _to_sharding_spec(
- data: any, logical_shape: any,
- dim_partition_dict: Dict[int, List[int]]) -> Union[ShardingSpec, List[ShardingSpec], None]:
+ data: any, logical_shape: any, dim_partition_dict: Dict[int, List[int]]
+ ) -> Union[ShardingSpec, List[ShardingSpec], None]:
"""
This is a recursive function to convert the dim partition dict to a ShardingSpec object.
"""
if isinstance(data, torch.Tensor):
dim_size = len(logical_shape)
dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
- sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
- entire_shape=logical_shape,
- dim_partition_dict=dim_partition_dict)
+ sharding_spec = ShardingSpec(
+ device_mesh=self.device_mesh,
+ entire_shape=logical_shape,
+ dim_partition_dict=dim_partition_dict,
+ )
return sharding_spec
elif isinstance(data, (list, tuple)):
sharding_spec = []
for data_element, logical_shape_element, dim_partition_dict_element in zip(
- data, logical_shape, dim_partition_dict):
+ data, logical_shape, dim_partition_dict
+ ):
sharding_spec.append(
- _to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element))
+ _to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element)
+ )
return sharding_spec
else:
return None
@@ -116,31 +124,41 @@ class StrategyGenerator(ABC):
results[op_data] = v
return results
- def get_communication_spec(self, sharding_spec: ShardingSpec, communication_pattern: CollectiveCommPattern,
- logical_process_axis: Union[int, List[int]]):
+ def get_communication_spec(
+ self,
+ sharding_spec: ShardingSpec,
+ communication_pattern: CollectiveCommPattern,
+ logical_process_axis: Union[int, List[int]],
+ ):
"""
A factory method to produce a CommSpec object.
"""
- return CommSpec(comm_pattern=communication_pattern,
- sharding_spec=sharding_spec,
- logical_process_axis=logical_process_axis)
-
- def get_communication_action(self,
- sharding_spec: ShardingSpec,
- communication_pattern: CollectiveCommPattern,
- logical_process_axis: Union[int, List[int]],
- comm_type: CommType,
- arg_index: int = -1,
- key_for_kwarg: any = None) -> CommAction:
+ return CommSpec(
+ comm_pattern=communication_pattern, sharding_spec=sharding_spec, logical_process_axis=logical_process_axis
+ )
+
+ def get_communication_action(
+ self,
+ sharding_spec: ShardingSpec,
+ communication_pattern: CollectiveCommPattern,
+ logical_process_axis: Union[int, List[int]],
+ comm_type: CommType,
+ arg_index: int = -1,
+ key_for_kwarg: any = None,
+ ) -> CommAction:
"""
A factory method to produce a CommAction object.
"""
- return CommAction(comm_spec=self.get_communication_spec(sharding_spec=sharding_spec,
- communication_pattern=communication_pattern,
- logical_process_axis=logical_process_axis),
- comm_type=comm_type,
- arg_index=arg_index,
- key_for_kwarg=key_for_kwarg)
+ return CommAction(
+ comm_spec=self.get_communication_spec(
+ sharding_spec=sharding_spec,
+ communication_pattern=communication_pattern,
+ logical_process_axis=logical_process_axis,
+ ),
+ comm_type=comm_type,
+ arg_index=arg_index,
+ key_for_kwarg=key_for_kwarg,
+ )
def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
@@ -155,9 +173,9 @@ class StrategyGenerator(ABC):
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
for phase, cost in num_ele_in_comm.items():
num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes
- comm_cost.fwd += num_ele_in_comm['forward']
- comm_cost.bwd += num_ele_in_comm['backward']
- comm_cost.total += num_ele_in_comm['total']
+ comm_cost.fwd += num_ele_in_comm["forward"]
+ comm_cost.bwd += num_ele_in_comm["backward"]
+ comm_cost.total += num_ele_in_comm["total"]
# check if communication action exists
# if so, loop over each action and compute the cost of each action
@@ -169,8 +187,8 @@ class StrategyGenerator(ABC):
# this condition branch will be removed after all the handler updated.
comm_spec = comm_action
if isinstance(comm_spec, dict):
- src_spec = comm_spec['src_spec']
- tgt_spec = comm_spec['tgt_spec']
+ src_spec = comm_spec["src_spec"]
+ tgt_spec = comm_spec["tgt_spec"]
shape_consistency_manager = ShapeConsistencyManager()
_, comm_action_sequence, _ = shape_consistency_manager.shape_consistency(src_spec, tgt_spec)
for comm_spec_ in comm_action_sequence:
@@ -187,14 +205,12 @@ class StrategyGenerator(ABC):
"""
Customize this method to compute the computation flops.
"""
- pass
@abstractmethod
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
Customize this method to compute the memory cost in bytes.
"""
- pass
def _compute_size_in_bytes(self, strategy: ShardingStrategy, key: str):
"""
@@ -212,20 +228,21 @@ class StrategyGenerator(ABC):
num_elements = 1
else:
num_elements = reduce(operator.mul, sharded_shape)
- dtype = getattr(meta_data, 'dtype')
+ dtype = getattr(meta_data, "dtype")
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
return num_elements * size_per_elem_bytes
if isinstance(op_data.data, tuple):
- assert isinstance(strategy.sharding_specs[op_data], list), \
- 'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.'
+ assert isinstance(
+ strategy.sharding_specs[op_data], list
+ ), "sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple."
total_bytes = 0
for index, sharding_spec in enumerate(strategy.sharding_specs[op_data]):
meta_data = op_data.data[index]
if isinstance(meta_data, torch.Tensor):
element_bytes = _compute_size_in_bytes_helper(sharding_spec, meta_data)
else:
- # if meta_data is not a tensor, we count the memroy as 0
+ # if meta_data is not a tensor, we count the memory as 0
element_bytes = 0
total_bytes += element_bytes
@@ -233,7 +250,7 @@ class StrategyGenerator(ABC):
if isinstance(op_data.data, torch.Tensor):
total_bytes = _compute_size_in_bytes_helper(strategy.sharding_specs[op_data], op_data.data)
else:
- # if op_data.data is not a tensor, we count the memroy as 0
+ # if op_data.data is not a tensor, we count the memory as 0
total_bytes = 0
return total_bytes
@@ -270,7 +287,6 @@ class StrategyGenerator(ABC):
Validate if the operands are of desired shape.
If True, means this generator can be used for the current operation.
"""
- pass
class FollowingStrategyGenerator(StrategyGenerator):
@@ -280,8 +296,9 @@ class FollowingStrategyGenerator(StrategyGenerator):
TODO: remove the original strategy_generator.py after refactoring
"""
- def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
- predecessor_node: Node):
+ def __init__(
+ self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, predecessor_node: Node
+ ):
self.op_data = operation_data_mapping
self.device_mesh = device_mesh
self.predecessor_node = predecessor_node
@@ -292,7 +309,8 @@ class OutputStrategyGenerator(StrategyGenerator):
OutputStrategyGenerator is used to generate the sharding strategies for Output Node.
"""
- def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
- predecessor_nodes: List[Node]):
+ def __init__(
+ self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, predecessor_nodes: List[Node]
+ ):
super().__init__(operation_data_mapping, device_mesh)
self.predecessor_nodes = predecessor_nodes
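
A condensed sketch of the tensor path in `_compute_size_in_bytes` above; tuple inputs and the zero-byte non-tensor path are omitted:

```python
import operator
from functools import reduce

import torch

def size_in_bytes(sharded_shape, dtype: torch.dtype) -> int:
    # A 0-dim (scalar) tensor still stores one element.
    num_elements = 1 if len(sharded_shape) == 0 else reduce(operator.mul, sharded_shape)
    # Same trick as the source: ask an empty tensor for its per-element size.
    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
    return num_elements * size_per_elem_bytes

assert size_in_bytes((4, 8), torch.float32) == 4 * 8 * 4  # fp32 is 4 bytes
assert size_in_bytes((), torch.float16) == 2              # scalar fp16
```
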
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py
index a0fbc58d70c0feba2c78305fb14d9bcb38a82e41..dcbf34cfd65b07cafc07f74b10bfa1d66057a3a2 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py
@@ -4,22 +4,9 @@ from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
- CommType,
- MemoryCost,
- ShardingStrategy,
- TrainCycleItem,
-)
-from colossalai.auto_parallel.tensor_shard.utils import (
- check_keep_sharding_status,
- detect_reshape_mapping,
- infer_output_dim_partition_dict,
-)
-from colossalai.tensor.shape_consistency import CollectiveCommPattern
-from colossalai.tensor.sharding_spec import ShardingSpec
-
-__all__ = ['SumGenerator']
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+
+__all__ = ["SumGenerator"]
class SumGenerator(FollowingStrategyGenerator):
@@ -31,24 +18,24 @@ class SumGenerator(FollowingStrategyGenerator):
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
- sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
- sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
+ sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device()
+ sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
output_size_product = reduce(operator.mul, sharded_output_shape)
- compute_cost = TrainCycleItem(fwd=input_size_product,
- bwd=output_size_product,
- total=input_size_product + output_size_product)
+ compute_cost = TrainCycleItem(
+ fwd=input_size_product, bwd=output_size_product, total=input_size_product + output_size_product
+ )
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -66,8 +53,9 @@ class SumGenerator(FollowingStrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@@ -78,7 +66,7 @@ class SumGenerator(FollowingStrategyGenerator):
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
- sum_dims, sum_mapping_dict = self.op_data['sum_info'].data
+ sum_dims, sum_mapping_dict = self.op_data["sum_info"].data
# TODO: a better way to handle the distributed sum is sum all the data on chip and then do all reduce
# among all the shard groups
@@ -90,7 +78,7 @@ class SumGenerator(FollowingStrategyGenerator):
elif dim in sum_mapping_dict:
dim_partition_dict_for_output[sum_mapping_dict[dim]] = dim_partition_dict_for_input[dim]
else:
- raise RuntimeError(f'dim {dim} is not in sum_mapping_dict or sum_dims')
+ raise RuntimeError(f"dim {dim} is not in sum_mapping_dict or sum_dims")
for dim in recover_dims:
dim_partition_dict_for_input.pop(dim)
@@ -105,9 +93,11 @@ class SumGenerator(FollowingStrategyGenerator):
# because in the solver this node will be merged into other nodes, and the solver will not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py
index 93cfc9eeea532ac4383f0821008deeccb13951d0..eea00c2fa064093f404f267baa3d24f3fc4c909c 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py
@@ -1,19 +1,10 @@
-import copy
from typing import List
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
- CommType,
- MemoryCost,
- ShardingStrategy,
- TrainCycleItem,
-)
-from colossalai.tensor.shape_consistency import CollectiveCommPattern
-from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from .strategy_generator import StrategyGenerator
-__all__ = ['TensorConstructorGenerator']
+__all__ = ["TensorConstructorGenerator"]
class TensorConstructorGenerator(StrategyGenerator):
@@ -30,10 +21,10 @@ class TensorConstructorGenerator(StrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
- forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
+ """
+ forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")}
# compute fwd cost incurred
# fwd_cost = input + output
@@ -57,11 +48,13 @@ class TensorConstructorGenerator(StrategyGenerator):
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- name = 'Replica Tensor Constructor'
+ name = "Replica Tensor Constructor"
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py
index b867a30686eb97a55096895d344dcc28b51f347a..943cf3f1f50db8e37af9e02539b3269565e03bfd 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py
@@ -1,11 +1,11 @@
import copy
from typing import List
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from .strategy_generator import FollowingStrategyGenerator
-__all__ = ['UnaryElementwiseGenerator']
+__all__ = ["UnaryElementwiseGenerator"]
class UnaryElementwiseGenerator(FollowingStrategyGenerator):
@@ -21,12 +21,12 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "input": self._compute_size_in_bytes(strategy, "input"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -44,8 +44,9 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
+ total_mem_cost = MemoryCost(
+ activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost
+ )
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@@ -69,9 +70,11 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
# we keep the same strategies under different names for node merging; this will not increase the search space,
# because in the solver this node will be merged into other nodes, and the solver will not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
strategy_list.append(strategy)
return strategy_list
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py
index fa941f2cc51dc4d817bfc8f49c54bbaf7a8a5407..b27b4f3d40569ccc0e3e3d2a95f0146706145976 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py
@@ -10,7 +10,7 @@ from colossalai.auto_parallel.tensor_shard.utils import (
from .strategy_generator import StrategyGenerator
-__all__ = ['WhereGenerator']
+__all__ = ["WhereGenerator"]
class WhereGenerator(StrategyGenerator):
@@ -26,14 +26,14 @@ class WhereGenerator(StrategyGenerator):
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
- '''
+ """
Compute the memory cost per device with this specific strategy.
- '''
+ """
forward_size_mapping = {
- 'condition': self._compute_size_in_bytes(strategy, "condition"),
- 'x': self._compute_size_in_bytes(strategy, "x"),
- 'y': self._compute_size_in_bytes(strategy, "y"),
- 'output': self._compute_size_in_bytes(strategy, "output")
+ "condition": self._compute_size_in_bytes(strategy, "condition"),
+ "x": self._compute_size_in_bytes(strategy, "x"),
+ "y": self._compute_size_in_bytes(strategy, "y"),
+ "output": self._compute_size_in_bytes(strategy, "output"),
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
@@ -59,7 +59,7 @@ class WhereGenerator(StrategyGenerator):
"condition": dim_partition,
"x": dim_partition,
"y": dim_partition,
- "output": dim_partition
+ "output": dim_partition,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
@@ -67,9 +67,11 @@ class WhereGenerator(StrategyGenerator):
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["condition"].sharding_sequence} x {sharding_spec_mapping["x"].sharding_sequence} x {sharding_spec_mapping["y"].sharding_sequence}'
communication_action_mapping = {}
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
+ strategy = self.get_sharding_strategy(
+ name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping,
+ )
return strategy
@@ -84,9 +86,9 @@ class WhereGenerator(StrategyGenerator):
return dim_partition_list
def collate_strategies(self) -> List[ShardingStrategy]:
- '''
+ """
Generate every possible strategy for a where node, and record all strategies in the strategies_vector.
- '''
+ """
strategy_list = []
dimension_length = len(self.op_data["output"].logical_shape)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py
index 86f90694e0604f72e9564020ccab455cfdee29a0..5b4ea0afe5f8157a5d63e7ac3bd8621c626fd6a1 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py
@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, SumGenerator
-__all__ = ['SumHandler']
+__all__ = ["SumHandler"]
@operator_registry.register(torch.Tensor.sum)
@@ -55,7 +55,7 @@ class SumHandler(NodeHandler):
# sum_mapping_dict[1] = 0 means the 0th dim of output is the 1st dim of input
# sum_mapping_dict[3] = 1 means the 1st dim of output is the 3rd dim of input
sum_mapping_dict = {}
- if 'keepdim' in self.node.kwargs and self.node.kwargs['keepdim']:
+ if "keepdim" in self.node.kwargs and self.node.kwargs["keepdim"]:
for i in range(num_dims):
sum_mapping_dict.update({i: i})
else:
@@ -67,7 +67,7 @@ class SumHandler(NodeHandler):
assert output_index == self.node._meta_data.dim()
sum_info = (sum_dims, sum_mapping_dict)
- physical_shape_operand = OperationData(name='sum_info', type=OperationDataType.ARG, data=sum_info)
+ physical_shape_operand = OperationData(name="sum_info", type=OperationDataType.ARG, data=sum_info)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
@@ -75,7 +75,7 @@ class SumHandler(NodeHandler):
mapping = {
"input": physical_input_operand,
"sum_info": physical_shape_operand,
- "output": physical_output_operand
+ "output": physical_output_operand,
}
return mapping
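
A worked example of the `sum_info` mapping built above, matching the dims in the source comments (a 4-dim input where input dims 1 and 3 survive, i.e. dims 0 and 2 are summed away):

```python
num_dims, sum_dims = 4, (0, 2)

# keepdim=True: the output keeps all dims, so the mapping is the identity.
keepdim_mapping = {i: i for i in range(num_dims)}  # {0: 0, 1: 1, 2: 2, 3: 3}

# keepdim=False: summed dims disappear and the survivors are renumbered.
mapping, output_index = {}, 0
for i in range(num_dims):
    if i not in sum_dims:
        mapping[i] = output_index  # input dim i becomes output dim output_index
        output_index += 1

# input dim 1 -> output dim 0, input dim 3 -> output dim 1, as in the comments above
assert mapping == {1: 0, 3: 1}
```
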
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py
index 855a2e7612af0cb59cae9bc8574197fad098f983..c2aa120e8a282ff346c4bf8dc6214d9f608b470c 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py
@@ -8,7 +8,7 @@ from .registry import operator_registry
from .strategy import StrategyGenerator
from .strategy.tensor_constructor_generator import TensorConstructorGenerator
-__all__ = ['TensorConstructorHandler']
+__all__ = ["TensorConstructorHandler"]
@operator_registry.register(torch.arange)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py
index 7a9d377264905a650fd991cb10e98f8c3f16f871..b72d9812f4062b5b64c04726d2970c9eb4f30f13 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py
@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, TransposeGenerator
-__all__ = ['TransposeHandler']
+__all__ = ["TransposeHandler"]
@operator_registry.register(torch.Tensor.transpose)
@@ -48,9 +48,9 @@ class TransposeHandler(NodeHandler):
if transpose_dims[i] < 0:
transpose_dims[i] += num_dims
- physical_shape_operand = OperationData(name='transpose_dims',
- type=OperationDataType.ARG,
- data=list(transpose_dims))
+ physical_shape_operand = OperationData(
+ name="transpose_dims", type=OperationDataType.ARG, data=list(transpose_dims)
+ )
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
@@ -58,7 +58,7 @@ class TransposeHandler(NodeHandler):
mapping = {
"input": physical_input_operand,
"transpose_dims": physical_shape_operand,
- "output": physical_output_operand
+ "output": physical_output_operand,
}
return mapping
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py
index 0362de780d7af0fa9569a575a122f84ffb42b0db..cbc873de822380df426c806d0ac8c25ca7b1df53 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py
@@ -3,11 +3,11 @@ from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import MetaInfoNodeHandler, NodeHandler
+from .node_handler import MetaInfoNodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, UnaryElementwiseGenerator
-__all__ = ['UnaryElementwiseHandler']
+__all__ = ["UnaryElementwiseHandler"]
@operator_registry.register(torch.Tensor.to)
@@ -33,9 +33,9 @@ class UnaryElementwiseHandler(MetaInfoNodeHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
- physical_input_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
+ physical_input_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "output": physical_output}
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py
index 7dff89d1d7a39a6e4fe73514bfb16abe2e3e7bea..56c1d10a167e6bd4e295fc29805530bc321e5933 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py
@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, ViewGenerator
-__all__ = ['ViewHandler']
+__all__ = ["ViewHandler"]
@operator_registry.register(torch.Tensor.reshape)
@@ -38,7 +38,7 @@ class ViewHandler(NodeHandler):
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
target_shape = self.node._meta_data.shape
- physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape)
+ physical_shape_operand = OperationData(name="tgt_shape", type=OperationDataType.ARG, data=target_shape)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
@@ -46,7 +46,7 @@ class ViewHandler(NodeHandler):
mapping = {
"input": physical_input_operand,
"tgt_shape": physical_shape_operand,
- "output": physical_output_operand
+ "output": physical_output_operand,
}
return mapping
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py
index 6de2aaafdd018f08195563ef882f07eb39d8d20a..1856a11100b07944d99d498e19f5cd8dda25d9d5 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py
@@ -1,16 +1,15 @@
import copy
-import operator
from typing import Dict, List
import torch
-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, WhereGenerator
-__all__ = ['WhereHandler']
+__all__ = ["WhereHandler"]
@operator_registry.register(torch.where)
@@ -28,27 +27,28 @@ class WhereHandler(NodeHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
- physical_condition_operand = OperationData(name=str(self.node.args[0]),
- type=OperationDataType.ARG,
- data=self.node.args[0]._meta_data)
- physical_x_operand = OperationData(name=str(self.node.args[1]),
- type=OperationDataType.ARG,
- data=self.node.args[1]._meta_data)
- physical_y_operand = OperationData(name=str(self.node.args[2]),
- type=OperationDataType.ARG,
- data=self.node.args[2]._meta_data)
+ physical_condition_operand = OperationData(
+ name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data
+ )
+ physical_x_operand = OperationData(
+ name=str(self.node.args[1]), type=OperationDataType.ARG, data=self.node.args[1]._meta_data
+ )
+ physical_y_operand = OperationData(
+ name=str(self.node.args[2]), type=OperationDataType.ARG, data=self.node.args[2]._meta_data
+ )
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
physical_mapping = {
"condition": physical_condition_operand,
"x": physical_x_operand,
"y": physical_y_operand,
- "output": physical_output
+ "output": physical_output,
}
logical_shape_for_all = self.node._meta_data.shape
logical_mapping = {}
for key, physical_operand in physical_mapping.items():
- logical_mapping[key] = self.convert_physical_operand_to_logical_operand(physical_operand,
- logical_shape_for_all)
+ logical_mapping[key] = self.convert_physical_operand_to_logical_operand(
+ physical_operand, logical_shape_for_all
+ )
return logical_mapping, physical_mapping
@@ -64,7 +64,8 @@ class WhereHandler(NodeHandler):
logical_shape = logical_op_data_mapping[key].logical_shape
physical_shape = physical_op_data_mapping[key].logical_shape
physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
- logical_sharding_spec, logical_shape, physical_shape)
+ logical_sharding_spec, logical_shape, physical_shape
+ )
strategy.sharding_specs.pop(logical_op_data_mapping[key])
strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec
strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}"
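The logical/physical split in `WhereHandler` hinges on broadcasting: every operand is assigned the output's broadcast shape while strategies are generated, and `post_process` then recovers a per-operand physical spec, dropping mesh axes that sharded a broadcast dim. A small illustration with assumed operand shapes:

```python
import torch

# Hypothetical shapes for the condition, x and y operands of a torch.where call.
condition_shape = torch.Size([1, 8])
x_shape = torch.Size([4, 8])
y_shape = torch.Size([8])

# All operands share the output's logical shape during strategy generation ...
logical_shape = torch.broadcast_shapes(condition_shape, x_shape, y_shape)
assert logical_shape == torch.Size([4, 8])

# ... and any mesh axis that sharded a broadcast (size-1 or padded) dim of an
# operand is collected in removed_dims and stripped from its physical spec.
```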
diff --git a/colossalai/auto_parallel/tensor_shard/options.py b/colossalai/auto_parallel/tensor_shard/options.py
index f0ea502a6f0e2c4412ac333f9465aec6873e9791..e87872f39c1044f159f474644fe02f3b822dcfdd 100644
--- a/colossalai/auto_parallel/tensor_shard/options.py
+++ b/colossalai/auto_parallel/tensor_shard/options.py
@@ -1,13 +1,14 @@
from dataclasses import dataclass
from enum import Enum
-__all__ = ['SolverOptions', 'SolverPerference', 'DataloaderOption', 'ShardOption']
+__all__ = ["SolverOptions", "SolverPerference", "DataloaderOption", "ShardOption"]
class SolverPerference(Enum):
"""
This enum class is to define the solver preference.
"""
+
STANDARD = 0
DP = 1
TP = 2
@@ -25,6 +26,7 @@ class ShardOption(Enum):
TP_SHARD: We require the node to be sharded using tensor parallel strategies on the last device mesh axis.
TP_FULL_SHARD: We require the node to be sharded using tensor parallel strategies on all device mesh axes.
"""
+
STANDARD = 0
SHARD = 1
SHARD_LAST_AXIS = 2
@@ -35,6 +37,7 @@ class DataloaderOption(Enum):
"""
This enum class is to define the dataloader option.
"""
+
REPLICATED = 0
DISTRIBUTED = 1
@@ -44,6 +47,7 @@ class SolverOptions:
"""
SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
"""
+
solver_perference: SolverPerference = SolverPerference.STANDARD
dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
shard_option: ShardOption = ShardOption.STANDARD
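A usage sketch for the options defined in this file; the chosen values are illustrative, the defaults being `STANDARD`/`REPLICATED`/`STANDARD`:

```python
from colossalai.auto_parallel.tensor_shard.options import (
    DataloaderOption,
    ShardOption,
    SolverOptions,
    SolverPerference,
)

# Bias the search toward tensor parallelism with a distributed dataloader.
options = SolverOptions(
    solver_perference=SolverPerference.TP,
    dataloader_option=DataloaderOption.DISTRIBUTED,
    shard_option=ShardOption.SHARD,
)
```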
diff --git a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
index 6af92727243759bf2d0e0e1b8f472e1e59308ca3..8e22df64d86846cc0914fe81f2dc4dca09d4269c 100644
--- a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
+++ b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
@@ -10,7 +10,6 @@ from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import (
- BCAST_FUNC_OP,
ELEMENTWISE_FUNC_OP,
ELEMENTWISE_METHOD_OP,
ELEMENTWISE_MODULE_OP,
@@ -18,13 +17,14 @@ from .constants import (
RESHAPE_METHOD_OP,
)
-__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
+__all__ = ["OperationDataType", "OperationData", "TrainCycleItem", "MemoryCost", "ShardingStrategy", "StrategiesVector"]
class OperationDataType(Enum):
"""
An operation can come from the argument list of an operator or the parameter list of a module.
"""
+
INPUT = 0
ARG = 1
PARAM = 2
@@ -43,6 +43,7 @@ class OperationData:
data (Any): the value for this data, usually it is a meta tensor.
logical_shape (Tuple[int]): the logical shape of the data, it can be different from its actual shape in memory.
"""
+
name: str
type: OperationDataType
data: Any
@@ -69,13 +70,13 @@ class OperationData:
self.logical_shape = _infer_logical_shape(self.data)
def __repr__(self) -> str:
- return f'OperationData(name={self.name}, type={self.type})'
+ return f"OperationData(name={self.name}, type={self.type})"
def __eq__(self, other) -> bool:
return other.name == self.name
def __hash__(self) -> int:
- return hash(f'{self.name}')
+ return hash(f"{self.name}")
@dataclass
@@ -88,6 +89,7 @@ class TrainCycleItem:
fwd (float): the item for the forward pass
bwd (float): the item for the backward pass
"""
+
fwd: Any
bwd: Any
total: Any
@@ -104,6 +106,7 @@ class MemoryCost:
temp (int): the memory cost incurred by the temporary tensors in bytes.
buffer (int): the memory cost incurred by the module buffer in bytes.
"""
+
activation: int = 0
parameter: int = 0
temp: int = 0
@@ -120,6 +123,7 @@ class CommType(Enum):
HOOK: the communication action is used to do the grad all reduce.
IMPLICIT: the communication action happens during the kernel execution, such as SyncBatchNorm
"""
+
BEFORE = 0
AFTER = 1
HOOK = 2
@@ -137,6 +141,7 @@ class CommAction:
arg_index: records the location of the tensor which joins the communication; we cannot use the name of the node or op_data at runtime,
because the args of node may be changed by graph transform passes.
"""
+
comm_spec: CommSpec = None
comm_type: CommType = None
arg_index: int = -1
@@ -156,6 +161,7 @@ class ShardingStrategy:
memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None)
input_sharding_specs (List[ShardingSpec]): The ShardingSpecs of the input nodes.
"""
+
name: str
sharding_specs: Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]] = None
compute_cost: TrainCycleItem = None
@@ -200,7 +206,6 @@ class ShardingStrategy:
raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}")
def clone(self):
-
def _deepcopy_dict_vals(data: Dict):
return {k: deepcopy(v) for k, v in data.items()}
@@ -209,31 +214,34 @@ class ShardingStrategy:
# Consider the examples below:
# If self.communication_actions is an empty dictionary {}, then self.communication_actions is not None, but its __bool__ value is False.
# In this case, if we set None to the new object, program will crash when we try to access the communication_actions.items.
- communication_actions = _deepcopy_dict_vals(
- self.communication_actions) if self.communication_actions is not None else None
+ communication_actions = (
+ _deepcopy_dict_vals(self.communication_actions) if self.communication_actions is not None else None
+ )
# same reason as communication_actions
resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs is not None else None
compute_cost = deepcopy(self.compute_cost)
communication_cost = deepcopy(self.communication_cost)
memory_cost = deepcopy(self.memory_cost)
- return ShardingStrategy(name=self.name,
- sharding_specs=sharding_specs,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- communication_actions=communication_actions,
- resharding_costs=resharding_costs)
+ return ShardingStrategy(
+ name=self.name,
+ sharding_specs=sharding_specs,
+ compute_cost=compute_cost,
+ communication_cost=communication_cost,
+ memory_cost=memory_cost,
+ communication_actions=communication_actions,
+ resharding_costs=resharding_costs,
+ )
class StrategiesVector(list):
- '''
+ """
Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
strategies of the node.
Argument:
node (Node): node for which the list of sharding strategies is generated.
- '''
+ """
def __init__(self, node: Node):
super().__init__()
@@ -245,7 +253,7 @@ class StrategiesVector(list):
def check_merge(self):
merge_label = False
- if self.node.op == 'call_module':
+ if self.node.op == "call_module":
target = self.node.target
root_module = self.node.graph.owning_module
submod = root_module.get_submodule(target)
@@ -255,7 +263,7 @@ class StrategiesVector(list):
if submod_type in ELEMENTWISE_MODULE_OP:
merge_label = True
- if self.node.op == 'call_function':
+ if self.node.op == "call_function":
# we could merge element-wise op, because the output sharding spec is always same as the input sharding spec.
if self.node.target in ELEMENTWISE_FUNC_OP:
merge_label = True
@@ -267,7 +275,7 @@ class StrategiesVector(list):
if self.node.target in RESHAPE_FUNC_OP:
merge_label = True
- if self.node.op == 'call_method':
+ if self.node.op == "call_method":
# we could merge reshape op, because their computation costs are negligible.
method = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
if method in RESHAPE_METHOD_OP:
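The comment inside `clone()` above deserves a two-line demonstration of why truthiness is the wrong test for `communication_actions`:

```python
# An empty dict is a legitimate value whose items() callers may still iterate.
communication_actions = {}

assert communication_actions is not None  # clone() keeps and deep-copies it
assert not communication_actions          # yet `if communication_actions:` would wrongly drop it
```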
diff --git a/colossalai/auto_parallel/tensor_shard/solver/__init__.py b/colossalai/auto_parallel/tensor_shard/solver/__init__.py
index f9e6bd9239214c4def03b1b419a2845581fa083a..b930ce80a9b932e6c54680fd7c988428b29afa0e 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/__init__.py
@@ -3,4 +3,4 @@ from .graph_analysis import GraphAnalyser
from .solver import Solver
from .strategies_constructor import StrategiesConstructor
-__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']
+__all__ = ["GraphAnalyser", "Solver", "StrategiesConstructor", "CostGraph"]
diff --git a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py
index 74290453ca0c2dd40008ec51584a134c8f278410..4415d429b0c22b82631984223df43f751a5fe1f6 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py
@@ -4,18 +4,18 @@ from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
class CostGraph:
- '''
+ """
A graph data structure to simplify the edge cost graph. It has two main functions:
1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in
CostGraph, and it stores every combination of strategies for a src-dst node pair in a 1D list.
2. To reduce the searching space, we merge computationally-trivial operators, such as
- element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will
+ element-wise operators, transpose, and reduction, into their following nodes. The merging information will
be given by the StrategiesVector depending on the type of target node and following nodes.
Argument:
leaf_strategies(List[StrategiesVector]): It stores the StrategiesVector of every node on the graph.
simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True)
- '''
+ """
def __init__(self, leaf_strategies, simplify=True, forward_only=False):
self.leaf_strategies = leaf_strategies
@@ -39,10 +39,10 @@ class CostGraph:
target_node_list.remove(element)
def _build_cost_graph(self):
- '''
+ """
This method will generate edge_cost for every adjacent node pair. Additionally, the 'parents' and 'children' attributes will be
set on each node.
- '''
+ """
self.edge_costs = {}
if self.simplify:
self.merge_pair = []
@@ -84,13 +84,13 @@ class CostGraph:
if _check_tensor_in_node(node._meta_data):
children_nodes.append(node)
- setattr(dst_node, 'parents', parent_nodes)
- setattr(dst_node, 'children', children_nodes)
+ setattr(dst_node, "parents", parent_nodes)
+ setattr(dst_node, "children", children_nodes)
if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
# we only merge node pairs which src node has a tensor element inside.
- # This is necessay because the node without a tensor element inside will not
+ # This is necessary because the node without a tensor element inside will not
# be assigned any strategy.
if _check_tensor_in_node(followed_node._meta_data):
self.merge_pair.append((followed_node, dst_node))
@@ -99,7 +99,7 @@ class CostGraph:
return self.edge_costs[(src_node, dst_node)]
def merge_node(self, src_node, dst_node):
- '''
+ """
To merge dst_node into src_node, we need to do it in the following steps:
1. For each strategy in dst_node, we need to pick an appropriate strategy
@@ -119,7 +119,7 @@ class CostGraph:
Argument:
src_node(Node): The node that will be merged into dst_node.
dst_node(Node): The node to integrate src_node.
- '''
+ """
# build merge_map
merge_map = {}
for src_index, _ in enumerate(src_node.strategies_vector):
@@ -196,7 +196,7 @@ class CostGraph:
if not self.simplify:
return
self.merge_pair.reverse()
- for (src_node, dst_node) in self.merge_pair:
+ for src_node, dst_node in self.merge_pair:
self.merge_node(src_node, dst_node)
self.merge_pair.reverse()
reindexing_following_dict = {}
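To make the linearization mentioned in the `CostGraph` docstring concrete: the quadratic resharding-cost matrix of a (src, dst) pair is flattened row-major, so the solver can later slice it back out with `strategies_len[i] * strategies_len[j]` entries per edge. A sketch with made-up costs:

```python
src_num_strategies, dst_num_strategies = 3, 2
cost_matrix = [[10.0, 0.0], [0.0, 5.0], [7.5, 7.5]]  # hypothetical resharding costs

# Row-major flattening, as stored in edge_costs for one (src, dst) pair.
edge_cost = [
    cost_matrix[i][j]
    for i in range(src_num_strategies)
    for j in range(dst_num_strategies)
]

# The entry for (src strategy i, dst strategy j) sits at i * dst_num_strategies + j.
assert edge_cost[1 * dst_num_strategies + 0] == cost_matrix[1][0]
```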
diff --git a/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py b/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py
index be39a74cb23755f9ff2b83cf1123a7a7f9708ffa..678965d663e4d243f0d63771f82455cc89f032ec 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py
@@ -7,7 +7,7 @@ from torch.fx.node import Node
from colossalai.fx.passes.utils import get_node_module
-__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']
+__all__ = ["LiveVariable", "LiveVariableVector", "LiveStage", "GraphAnalyser"]
@dataclass
@@ -15,6 +15,7 @@ class LiveVariable:
"""
LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
"""
+
name: str
node: Node
is_inplace: bool
@@ -55,6 +56,7 @@ class LiveStage:
"""
LiveStage is a data structure to record the living variables at this current node.
"""
+
name: str
node: Node
all_live_vars: LiveVariableVector
@@ -62,7 +64,6 @@ class LiveStage:
class GraphAnalyser:
-
def __init__(self, gm: GraphModule):
self._gm = gm
self._graph = gm.graph
@@ -83,7 +84,7 @@ class GraphAnalyser:
def liveness_analysis(self) -> List[LiveStage]:
"""
- Analyse the graph to obtain the variable liveness information. This function returns
+ Analyses the graph to obtain the variable liveness information. This function returns
a list of LiveStage objects ordered by compute stage.
"""
compute_nodes = self.graph.nodes
@@ -91,7 +92,7 @@ class GraphAnalyser:
# checked: record all variables created since the first stage
# all: record the live variables only exist until the current stage.
- # this can be different from the `checked list`` as some varialbes may be destroyed prior to this stage.
+ # this can be different from the `checked` list as some variables may be destroyed prior to this stage.
# unique: record the unique live variables only exist until the current stage.
# this is different from `all list` as some variables are duplicated.
checked_variables = LiveVariableVector()
@@ -103,20 +104,20 @@ class GraphAnalyser:
# find new living variables #
#############################
# detect whether the current op is an in-place op
- # if it is an in-place op, we would deem it as a duplciate var
+ # if it is an in-place op, we would deem it as a duplicate var
is_inplace = False
- if node.op == 'call_function':
+ if node.op == "call_function":
# check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True)
- if node.kwargs.get('inplace', False):
+ if node.kwargs.get("inplace", False):
is_inplace = True
- elif node.op == 'call_module':
+ elif node.op == "call_module":
# to check if this is an inplace op such as torch.nn.Relu(inplace=True)
module = get_node_module(node)
- if getattr(module, 'inplace', False):
+ if getattr(module, "inplace", False):
is_inplace = True
# add the output var
- meta = getattr(node, '_meta_data', None)
+ getattr(node, "_meta_data", None)
live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
if not is_inplace:
unique_live_vars.append(live_var)
@@ -138,10 +139,12 @@ class GraphAnalyser:
# this should be completed if we are able to trace the backward compute graph
# add this stage to liveness dict
- stage = LiveStage(name=node.name,
- node=node,
- all_live_vars=all_live_variables.copy(),
- unique_live_vars=unique_live_vars.copy())
+ stage = LiveStage(
+ name=node.name,
+ node=node,
+ all_live_vars=all_live_variables.copy(),
+ unique_live_vars=unique_live_vars.copy(),
+ )
# if a LiveStage is covered by another LiveStage, we just keep the larger one.
replace = False
for index, prev_stage in enumerate(liveness_list):
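A compact sketch of the forward liveness sweep performed here, with illustrative names rather than the `GraphAnalyser` API; each node's output becomes live at its stage, and a full implementation would also retire variables after their last use:

```python
nodes = [("x", []), ("w", []), ("y", ["x", "w"]), ("out", ["y"])]  # (name, inputs)

live = set()
stages = []
for name, inputs in nodes:
    live.add(name)                       # the node's output starts to live here
    stages.append((name, sorted(live)))  # snapshot, like LiveStage.all_live_vars

print(stages[-1])  # ('out', ['out', 'w', 'x', 'y'])
```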
diff --git a/colossalai/auto_parallel/tensor_shard/solver/solver.py b/colossalai/auto_parallel/tensor_shard/solver/solver.py
index f5c6663dce80671199dcaf235380178622813310..088d1acb51778acf4bd220ae777811f54cb1eb50 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/solver.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/solver.py
@@ -21,34 +21,35 @@ try:
import pulp
from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
except:
- warnings.warn(f'please install the pulp')
+ warnings.warn("please install the pulp library")
-__all___ = ['Solver']
+__all__ = ["Solver"]
class Solver:
-
- def __init__(self,
- graph: Graph,
- strategies_constructor: StrategiesConstructor,
- cost_graph: CostGraph,
- graph_analyser: GraphAnalyser = None,
- memory_budget: float = -1.0,
- solution_numbers: int = 1,
- forward_only: bool = False,
- memory_increasing_coefficient: float = 1.3,
- verbose=False):
- '''
+ def __init__(
+ self,
+ graph: Graph,
+ strategies_constructor: StrategiesConstructor,
+ cost_graph: CostGraph,
+ graph_analyser: GraphAnalyser = None,
+ memory_budget: float = -1.0,
+ solution_numbers: int = 1,
+ forward_only: bool = False,
+ memory_increasing_coefficient: float = 1.3,
+ verbose=False,
+ ):
+ """
The Solver class integrates the information provided by the components and uses an ILP solver to find an optimal combination of strategies for the target computing graph.
Argument:
graph: The computing graph to be optimized.
strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
cost_graph: A graph data structure to simplify the edge cost graph.
- graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
+ graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
memory_budget: Memory constraint for the solution.
solution_numbers: If solution_numbers is larger than one, the solver will use a series of solutions based on different memory budgets.
memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate a new memory budget.
- '''
+ """
self.graph = graph
self.strategies_constructor = strategies_constructor
self.cost_graph = cost_graph
@@ -75,11 +76,11 @@ class Solver:
self.verbose = verbose
def _recover_merged_node_strategy(self):
- '''
+ """
During cost graph construction, some nodes, such as unary element-wise nodes or ReshapeOp, were merged into the previous node.
Therefore, the strategy indices of those nodes are copied from the previous node. This method is used to recover the strategy index of those merged
nodes.
- '''
+ """
for node_index, node in enumerate(self.nodes):
if node.strategies_vector.check_merge():
# the merged node has only one input, and its strategies follow the input sharding strategy
@@ -98,9 +99,9 @@ class Solver:
return node_index_dict
def _prepare_data_for_solver(self):
- '''
+ """
Extract information from components for solver.
- '''
+ """
node_nums = len(self.leaf_strategies)
memory_budget = self.memory_budget
@@ -190,23 +191,40 @@ class Solver:
# omit initial value for nodes
s_init_np = None
- return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np, self.verbose
-
- def _call_solver_serialized_args(self,
- node_nums,
- memory_budget,
- strategies_len,
- following_nodes,
- edge_pairs,
- alias_set,
- liveness_set,
- compute_costs,
- communication_costs,
- memory_costs,
- resharding_costs,
- alias_convert_costs,
- s_init_np=None,
- verbose=True):
+ return (
+ node_nums,
+ memory_budget,
+ strategies_len,
+ following_nodes,
+ edge_pairs,
+ alias_set,
+ liveness_set,
+ compute_costs,
+ communication_costs,
+ memory_costs,
+ resharding_costs,
+ alias_convert_costs,
+ s_init_np,
+ self.verbose,
+ )
+
+ def _call_solver_serialized_args(
+ self,
+ node_nums,
+ memory_budget,
+ strategies_len,
+ following_nodes,
+ edge_pairs,
+ alias_set,
+ liveness_set,
+ compute_costs,
+ communication_costs,
+ memory_costs,
+ resharding_costs,
+ alias_convert_costs,
+ s_init_np=None,
+ verbose=True,
+ ):
"""
Call the solver with serialized arguments.
"""
@@ -235,18 +253,18 @@ class Solver:
s_follow = following_nodes
s_alias = alias_set
- E = edge_pairs.reshape((-1, 2)) # noqa
+ E = edge_pairs.reshape((-1, 2)) # noqa
r = []
pt = 0
edge_set = set()
- for (i, j) in E:
+ for i, j in E:
prod_length = strategies_len[i] * strategies_len[j]
if (i, j) in edge_set:
raise ValueError(f"Duplicated edges: {(i, j)}")
edge_set.add((i, j))
- r.append(resharding_costs[pt:pt + prod_length])
+ r.append(resharding_costs[pt : pt + prod_length])
pt += prod_length
assert pt == len(resharding_costs)
@@ -268,7 +286,6 @@ class Solver:
# L.append(liveness_set[pt:pt + length])
# pt += length
# assert pt == len(liveness_set)
- v = []
pt = 0
c = []
@@ -277,9 +294,9 @@ class Solver:
pt = 0
for i in range(node_nums):
length = strategies_len[i]
- c.append(compute_costs[pt:pt + length])
- d.append(communication_costs[pt:pt + length])
- m.append(memory_costs[pt:pt + length])
+ c.append(compute_costs[pt : pt + length])
+ d.append(communication_costs[pt : pt + length])
+ m.append(memory_costs[pt : pt + length])
pt += length
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
@@ -319,7 +336,7 @@ class Solver:
e = []
num_edges = 0
map_edge_to_idx = {}
- for (idx, (i, j)) in enumerate(E):
+ for idx, (i, j) in enumerate(E):
if len(s[i]) == 1:
e.append(s[j])
elif len(s[j]) == 1:
@@ -340,7 +357,7 @@ class Solver:
######################################
if s_init_np is not None:
s_init = s_init_np.reshape((-1, 3))
- for (idx, value, fix) in s_init:
+ for idx, value, fix in s_init:
for i in range(len(s[idx])):
s[idx][i].setInitialValue(i == value)
if fix:
@@ -393,7 +410,7 @@ class Solver:
# (d). specified by `cat="Binary"`
- for (idx, (i, j)) in enumerate(E):
+ for idx, (i, j) in enumerate(E):
if strategies_len[i] == 1 or strategies_len[j] == 1:
continue
@@ -402,13 +419,13 @@ class Solver:
# (f)
for row in range(len(s[i])):
- C = len(s[j]) # noqa
+ C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
# (g)
for col in range(len(s[j])):
- R = len(s[i]) # noqa
- C = len(s[j]) # noqa
+ R = len(s[i]) # noqa
+ C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
# (h)
@@ -434,7 +451,8 @@ class Solver:
msg = verbose
time_limit = 600
assert "COIN_CMD" in pulp.listSolvers(
- onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
+ onlyAvailable=True
+ ), "Please install ILP solvers by 'sudo apt install coinor-cbc'"
solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
# solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
@@ -444,13 +462,13 @@ class Solver:
objective = pulp.value(prob.objective)
objective = float(objective) if objective is not None else -1.0
if verbose:
- print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
- f"Time: {time.time() - tic}")
+ print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" f"Time: {time.time() - tic}")
print(f"#nodes: {num_nodes}, #edges: {num_edges}")
if prob.status in [pulp.LpStatusInfeasible]:
- raise RuntimeError("Cannot run the function under the given memory budget. "
- "Please increase the memory budget.")
+ raise RuntimeError(
+ "Cannot run the function under the given memory budget. " "Please increase the memory budget."
+ )
# Get and check results
s_val = np.full((node_nums,), -1, dtype=np.int32)
@@ -458,7 +476,7 @@ class Solver:
s_val[i] = get_non_zero_index(s[i])
e_val = np.full((len(E),), -1, dtype=np.int32)
- for (idx, (i, j)) in enumerate(E):
+ for idx, (i, j) in enumerate(E):
e_val[idx] = get_non_zero_index(e[idx])
i_spec_index = e_val[idx] // len(s[j])
j_spec_index = e_val[idx] % len(s[j])
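For readers unfamiliar with pulp, a scaled-down sketch of the ILP core used above: binary indicators pick exactly one strategy per node while minimizing total cost. The costs are made up, and the edge variables, memory constraints and alias constraints of the real solver are omitted:

```python
import pulp
from pulp import LpMinimize, LpProblem, LpVariable, lpDot, lpSum

costs = [[3.0, 1.0, 2.0], [4.0, 4.5]]  # per-node, per-strategy costs

prob = LpProblem("strategy_selection", LpMinimize)
s = [
    [LpVariable(f"s_{i}_{k}", cat="Binary") for k in range(len(node_costs))]
    for i, node_costs in enumerate(costs)
]

# Objective: total cost of the chosen strategies.
prob += lpSum(lpDot(s[i], costs[i]) for i in range(len(costs)))
# One-hot constraint: exactly one strategy per node.
for vars_i in s:
    prob += lpSum(vars_i) == 1

prob.solve(pulp.PULP_CBC_CMD(msg=False))
print([[int(pulp.value(v)) for v in vars_i] for vars_i in s])  # [[0, 1, 0], [1, 0]]
```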
diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
index 044a8ac847ead4b6b7d9f05c3d19a43a8fc2346c..aa87ee9bf3db1fc52d5d9b8289c9d7ae7c01ca93 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
@@ -1,11 +1,5 @@
-import builtins
-import math
-import operator
-from copy import deepcopy
-from typing import Dict, List
-
import torch
-from torch.fx import Graph, Node
+from torch.fx import Graph
from colossalai.auto_parallel.tensor_shard.node_handler import (
GetattrHandler,
@@ -14,13 +8,12 @@ from colossalai.auto_parallel.tensor_shard.node_handler import (
operator_registry,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
-from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
from colossalai.device.device_mesh import DeviceMesh
from ..options import DataloaderOption, SolverOptions
-__all__ = ['StrategiesConstructor']
+__all__ = ["StrategiesConstructor"]
class StrategiesConstructor:
@@ -35,7 +28,7 @@ class StrategiesConstructor:
def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
self.graph = graph
- assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
+ assert graph.owning_module is not None, "The given graph is not associated with an owning_module"
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.device_mesh = device_mesh
@@ -46,11 +39,11 @@ class StrategiesConstructor:
self.alias_set = None
def remove_duplicated_strategy(self, strategies_vector):
- '''
+ """
In the build_strategies_and_cost method, we may produce some duplicated strategies.
In this method, we will remove the duplicated strategies depending on the strategy name.
Note that this operation is in-place.
- '''
+ """
name_checklist = []
remove_list = []
for strategy in strategies_vector:
@@ -62,7 +55,6 @@ class StrategiesConstructor:
strategies_vector.remove(strategy)
def generate_alias_set(self):
-
node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies]
common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10)
@@ -83,7 +75,7 @@ class StrategiesConstructor:
"""
def _check_no_strategy_for_node(node):
- if node.op in ('placeholder', 'get_attr', 'output'):
+ if node.op in ("placeholder", "get_attr", "output"):
return False
def _check_no_strategy_for_data(data):
@@ -102,83 +94,93 @@ class StrategiesConstructor:
if _check_no_strategy_for_node(node):
self.no_strategy_nodes.append(node)
- pass
# placeholder node
- elif node.op == 'placeholder':
+ elif node.op == "placeholder":
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
- placeholder_option = 'distributed'
+ placeholder_option = "distributed"
else:
- assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
- placeholder_option = 'replicated'
- placeholder_handler = PlaceholderHandler(node,
- self.device_mesh,
- strategies_vector,
- placeholder_option=placeholder_option)
+ assert (
+ self.solver_options.dataloader_option == DataloaderOption.REPLICATED
+ ), f"placeholder_option {self.solver_options.dataloader_option} is not supported"
+ placeholder_option = "replicated"
+ placeholder_handler = PlaceholderHandler(
+ node, self.device_mesh, strategies_vector, placeholder_option=placeholder_option
+ )
placeholder_handler.register_strategy()
# get_attr node
- elif node.op == 'get_attr':
- getattr_handler = GetattrHandler(node,
- self.device_mesh,
- strategies_vector,
- shard_option=self.solver_options.shard_option,
- solver_perference=self.solver_options.solver_perference)
+ elif node.op == "get_attr":
+ getattr_handler = GetattrHandler(
+ node,
+ self.device_mesh,
+ strategies_vector,
+ shard_option=self.solver_options.shard_option,
+ solver_perference=self.solver_options.solver_perference,
+ )
getattr_handler.register_strategy()
# call_module node
- elif node.op == 'call_module':
+ elif node.op == "call_module":
target = node.target
submod = self.root_module.get_submodule(target)
submod_type = type(submod)
- handler = operator_registry.get(submod_type)(node,
- self.device_mesh,
- strategies_vector,
- shard_option=self.solver_options.shard_option,
- solver_perference=self.solver_options.solver_perference)
+ handler = operator_registry.get(submod_type)(
+ node,
+ self.device_mesh,
+ strategies_vector,
+ shard_option=self.solver_options.shard_option,
+ solver_perference=self.solver_options.solver_perference,
+ )
handler.register_strategy()
# attach strategies_info to node
- if hasattr(handler, 'strategies_info'):
- setattr(node, 'strategies_info', handler.strategies_info)
+ if hasattr(handler, "strategies_info"):
+ setattr(node, "strategies_info", handler.strategies_info)
# call_function node
- elif node.op == 'call_function':
+ elif node.op == "call_function":
target = node.target
- handler = operator_registry.get(target)(node,
- self.device_mesh,
- strategies_vector,
- shard_option=self.solver_options.shard_option,
- solver_perference=self.solver_options.solver_perference)
+ handler = operator_registry.get(target)(
+ node,
+ self.device_mesh,
+ strategies_vector,
+ shard_option=self.solver_options.shard_option,
+ solver_perference=self.solver_options.solver_perference,
+ )
handler.register_strategy()
# attach strategies_info to node
- if hasattr(handler, 'strategies_info'):
- setattr(node, 'strategies_info', handler.strategies_info)
+ if hasattr(handler, "strategies_info"):
+ setattr(node, "strategies_info", handler.strategies_info)
# call_method node
- elif node.op == 'call_method':
+ elif node.op == "call_method":
method = getattr(node.args[0]._meta_data.__class__, node.target)
- handler = operator_registry.get(method)(node,
- self.device_mesh,
- strategies_vector,
- shard_option=self.solver_options.shard_option,
- solver_perference=self.solver_options.solver_perference)
+ handler = operator_registry.get(method)(
+ node,
+ self.device_mesh,
+ strategies_vector,
+ shard_option=self.solver_options.shard_option,
+ solver_perference=self.solver_options.solver_perference,
+ )
handler.register_strategy()
# attach strategies_info to node
- if hasattr(handler, 'strategies_info'):
- setattr(node, 'strategies_info', handler.strategies_info)
+ if hasattr(handler, "strategies_info"):
+ setattr(node, "strategies_info", handler.strategies_info)
# output node
- elif node.op == 'output':
+ elif node.op == "output":
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
- output_option = 'distributed'
+ output_option = "distributed"
else:
- assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
- output_option = 'replicated'
+ assert (
+ self.solver_options.dataloader_option == DataloaderOption.REPLICATED
+ ), f"placeholder_option {self.solver_options.dataloader_option} is not supported"
+ output_option = "replicated"
output_handler = OutputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
output_handler.register_strategy()
self.remove_duplicated_strategy(strategies_vector)
- setattr(node, 'strategies_vector', strategies_vector)
+ setattr(node, "strategies_vector", strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector
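The dispatch in `build_strategies_and_cost` relies on the registry populated by the `@operator_registry.register(...)` decorators seen throughout this diff. A minimal sketch of that pattern, not the library's exact `Registry` class:

```python
class Registry:
    def __init__(self):
        self.store = {}

    def register(self, source):
        def wrapper(func):
            self.store[source] = func  # map a torch callable to its handler
            return func
        return wrapper

    def get(self, source):
        assert source in self.store, f"{source} not found in the registry"
        return self.store[source]

demo_registry = Registry()

@demo_registry.register("some_torch_target")  # illustrative key
class DemoHandler:
    pass

assert demo_registry.get("some_torch_target") is DemoHandler
```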
diff --git a/colossalai/auto_parallel/tensor_shard/utils/__init__.py b/colossalai/auto_parallel/tensor_shard/utils/__init__.py
index b7fe5430bf136b08d93706b33cf4dbf82e342013..d61cfd2add15ba096270723446ffe52768a31f31 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/__init__.py
@@ -17,9 +17,21 @@ from .sharding import (
)
__all__ = [
- 'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
- 'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
- 'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
- 'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map',
- 'detect_reshape_mapping', 'check_keep_sharding_status', 'infer_output_dim_partition_dict'
+ "BroadcastType",
+ "get_broadcast_shape",
+ "is_broadcastable",
+ "recover_sharding_spec_for_broadcast_shape",
+ "generate_resharding_costs",
+ "generate_sharding_spec",
+ "ignore_sharding_exception",
+ "check_sharding_spec_validity" "transpose_partition_dim",
+ "update_partition_dim",
+ "enumerate_all_possible_1d_sharding",
+ "enumerate_all_possible_2d_sharding",
+ "generate_sharding_size",
+ "comm_actions_for_oprands",
+ "pytree_map",
+ "detect_reshape_mapping",
+ "check_keep_sharding_status",
+ "infer_output_dim_partition_dict",
]
diff --git a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py
index 28aa551328d7a6d5f283338fe55c90eb102d253c..99d5a0f2a9420d50afa47485b9a4b20ed8c60a2a 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py
@@ -14,14 +14,17 @@ from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = [
- 'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape',
- 'comm_actions_for_oprands'
+ "BroadcastType",
+ "is_broadcastable",
+ "get_broadcast_shape",
+ "recover_sharding_spec_for_broadcast_shape",
+ "comm_actions_for_oprands",
]
class BroadcastType(Enum):
EQUAL = auto()
- PADDDING = auto()
+ PADDING = auto()
MULTIPLE = auto()
@@ -41,7 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
"""
Compute the broadcast shape given two shapes.
"""
- assert is_broadcastable(shape1, shape2), f'{shape1} and {shape2} are not broadcastable'
+ assert is_broadcastable(shape1, shape2), f"{shape1} and {shape2} are not broadcastable"
shape1_reverse = shape1[::-1]
shape2_reverse = shape2[::-1]
min_common_dim = min(len(shape1), len(shape2))
@@ -60,8 +63,9 @@ def get_broadcast_dim_info(logical_shape, physical_shape):
logical_num_dims = len(logical_shape)
physical_num_dims = len(physical_shape)
- assert logical_num_dims >= physical_num_dims, \
- 'The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!'
+ assert (
+ logical_num_dims >= physical_num_dims
+ ), "The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!"
# track the dim and its broadcasting type
logical_dim_broadcast_info = {}
@@ -69,24 +73,25 @@ def get_broadcast_dim_info(logical_shape, physical_shape):
for i in range(logical_num_dims):
# get the trailing dim size
logical_dim_idx = logical_num_dims - i - 1
- phyiscal_dim_idx = physical_num_dims - i - 1
+ physical_dim_idx = physical_num_dims - i - 1
logical_dim_size = logical_shape[logical_dim_idx]
- if phyiscal_dim_idx >= 0:
- physical_dim_size = physical_shape[phyiscal_dim_idx]
+ if physical_dim_idx >= 0:
+ physical_dim_size = physical_shape[physical_dim_idx]
if physical_dim_size == logical_dim_size:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.EQUAL
elif physical_dim_size == 1 and physical_dim_size != logical_dim_size:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.MULTIPLE
else:
- logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDDING
+ logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDING
return logical_dim_broadcast_info
-def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
- physical_shape: torch.Size) -> ShardingSpec:
+def recover_sharding_spec_for_broadcast_shape(
+ logical_sharding_spec: ShardingSpec, logical_shape: torch.Size, physical_shape: torch.Size
+) -> ShardingSpec:
"""
This function computes the sharding spec for the physical shape of a broadcast tensor.
@@ -117,22 +122,25 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe
for shape_dim, mesh_dim in logical_dim_partition.items():
logical_broadcast_type = logical_dim_broadcast_info[shape_dim]
- if logical_broadcast_type == BroadcastType.PADDDING or logical_broadcast_type == BroadcastType.MULTIPLE:
+ if logical_broadcast_type == BroadcastType.PADDING or logical_broadcast_type == BroadcastType.MULTIPLE:
removed_dims.extend(mesh_dim)
else:
# get the corresponding physical dim
physical_dim = physical_num_dims - (logical_num_dims - shape_dim)
physical_dim_partition[physical_dim] = mesh_dim
- physical_sharding_spec = ShardingSpec(device_mesh=logical_sharding_spec.device_mesh,
- entire_shape=physical_shape,
- dim_partition_dict=physical_dim_partition)
+ physical_sharding_spec = ShardingSpec(
+ device_mesh=logical_sharding_spec.device_mesh,
+ entire_shape=physical_shape,
+ dim_partition_dict=physical_dim_partition,
+ )
return physical_sharding_spec, removed_dims
-def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData,
- sharding_spec: ShardingSpec) -> CommAction:
+def comm_actions_for_oprands(
+ node: Node, removed_dims: List[int], op_data: OperationData, sharding_spec: ShardingSpec
+) -> CommAction:
"""
This method is used to generate communication actions for operands which lose information
during the conversion from logical shape to physical shape.
@@ -140,9 +148,11 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera
if len(removed_dims) == 1:
# if list length is 1, extract element from list to avoid using flatten device mesh
removed_dims = removed_dims[0]
- comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
- sharding_spec=sharding_spec,
- logical_process_axis=removed_dims)
+ comm_spec = CommSpec(
+ comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+ sharding_spec=sharding_spec,
+ logical_process_axis=removed_dims,
+ )
if op_data.type == OperationDataType.PARAM:
comm_type = CommType.HOOK
else:
@@ -151,7 +161,7 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera
for index, arg in enumerate(node.args):
if op_data.name == str(arg):
arg_index = index
- assert arg_index >= 0, f'op_data should be an argument of node.'
+ assert arg_index >= 0, f"op_data should be an argument of node."
comm_action = CommAction(
comm_spec=comm_spec,
comm_type=comm_type,
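A worked sketch of the dim classification that `get_broadcast_dim_info` performs for a valid broadcast (trailing dims compared pairwise), using the enum spelling fixed above:

```python
from enum import Enum, auto

class BroadcastType(Enum):
    EQUAL = auto()
    PADDING = auto()
    MULTIPLE = auto()

def classify(logical_shape, physical_shape):
    # Sketch: walk dims from the trailing end, as the real function does.
    info = {}
    for i in range(len(logical_shape)):
        l_idx = len(logical_shape) - i - 1
        p_idx = len(physical_shape) - i - 1
        if p_idx < 0:
            info[l_idx] = BroadcastType.PADDING    # dim exists only logically
        elif physical_shape[p_idx] == logical_shape[l_idx]:
            info[l_idx] = BroadcastType.EQUAL      # sharding can be kept
        else:
            info[l_idx] = BroadcastType.MULTIPLE   # size-1 dim was expanded
    return info

print(classify((4, 8, 6), (1, 6)))
# {2: BroadcastType.EQUAL, 1: BroadcastType.MULTIPLE, 0: BroadcastType.PADDING}
```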
diff --git a/colossalai/auto_parallel/tensor_shard/utils/factory.py b/colossalai/auto_parallel/tensor_shard/utils/factory.py
index 05331e56000110a982cc776a24eb81d45fceb825..aaca923a5eeee28b241b3bc7877ace321b312761 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/factory.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/factory.py
@@ -14,11 +14,12 @@ from colossalai.tensor.sharding_spec import ShardingSpec
from ..constants import INFINITY_COST
-__all__ = ['generate_sharding_spec', 'generate_resharding_costs']
+__all__ = ["generate_sharding_spec", "generate_resharding_costs"]
-def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
- dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
+def generate_sharding_spec(
+ input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh, dim_partition_dict: Dict[int, List[int]]
+) -> ShardingSpec:
"""
Generate the sharding spec of the tensor based on the given dim_partition_dict.
@@ -30,7 +31,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
"""
if isinstance(input_, Node):
- assert hasattr(input_, '_meta_data'), f'The given node has no attribte _meta_data'
+ assert hasattr(input_, "_meta_data"), "The given node has no attribute _meta_data"
meta_tensor = input_._meta_data
assert meta_tensor is not None, "The given node's _meta_data attribute is None"
shape = meta_tensor.shape
@@ -38,24 +39,27 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
shape = input_.shape
else:
raise TypeError(
- f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.'
+ f"We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected."
)
for dim_index, sharding_index_list in dim_partition_dict.items():
sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
sharding_size = reduce(operator.mul, sharding_list, 1)
- assert shape[
- dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
+ assert (
+ shape[dim_index] % sharding_size == 0
+ ), f"we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions."
sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
return sharding_spec
-def generate_resharding_costs(nodes: List[Node],
- sharding_specs: List[ShardingSpec],
- count_backward: Optional[bool] = True,
- dtype: Optional[torch.dtype] = None,
- index=None):
- '''
+def generate_resharding_costs(
+ nodes: List[Node],
+ sharding_specs: List[ShardingSpec],
+ count_backward: Optional[bool] = True,
+ dtype: Optional[torch.dtype] = None,
+ index=None,
+):
+ """
Compute the resharding costs with this specific strategy.
Argument:
@@ -63,7 +67,7 @@ def generate_resharding_costs(nodes: List[Node],
sharding_specs (List[ShardingSpec]): a list of ShardingSpec for the nodes.
count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
- '''
+ """
# The resharding_cost of weight is counted due to sharing weight cases.
resharding_costs = {}
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
@@ -76,38 +80,39 @@ def generate_resharding_costs(nodes: List[Node],
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
if not isinstance(input_sharding_spec, ShardingSpec):
- assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.'
+ assert isinstance(input_sharding_spec, list), "only ShardingSpec or List[ShardingSpec] is expected."
input_sharding_spec = input_sharding_spec[index]
- assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
+ assert isinstance(input_sharding_spec, ShardingSpec), "The input node should NOT be a tuple of tensors."
try:
# compute the resharding cost
_, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
- input_sharding_spec, input_spec)
+ input_sharding_spec, input_spec
+ )
# we need multiply the size of elem dtype to get correct communication cost
resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
except AssertionError as e:
- warnings.warn(f'{e}')
+ warnings.warn(f"{e}")
resharding_cost = INFINITY_COST
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_length_threshold: int = 20):
- '''
+ """
Find the largest repeated blocks in the graph whose length is larger than the threshold.
Args:
node_list (List[torch.fx.Node]): the nodes of the graph to be analyzed.
root_module: the root module that owns the graph.
common_length_threshold (int): the threshold of the repeat block length.
- '''
+ """
# graph = gm.graph
def _process_args(args):
new_args = []
for arg in args:
- if hasattr(arg, '_meta_data'):
+ if hasattr(arg, "_meta_data"):
meta_data = arg._meta_data
else:
meta_data = arg
@@ -145,7 +150,7 @@ def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_lengt
return False
for index, node in enumerate(node_list):
- if node.op == 'call_module':
+ if node.op == "call_module":
target = node.target
submod = root_module.get_submodule(target)
submod_type = type(submod)
@@ -155,12 +160,12 @@ def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_lengt
new_args = _process_args(node.args)
- if node.op != 'get_attr':
+ if node.op != "get_attr":
hash_key = (node.op, target, *new_args)
else:
hash_key = (node.op,)
- setattr(node, 'hash_key', hash_key)
+ setattr(node, "hash_key", hash_key)
hash_value_to_node_dict = {}
@@ -179,7 +184,7 @@ def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_lengt
# the comparison will be triggered if a common node appears
if len(hash_value_to_node_dict[hash(node.hash_key)]) >= 2:
start_index_list = hash_value_to_node_dict[hash(node.hash_key)]
- check_block_list = [node_list[start:start + max_common_length] for start in start_index_list]
+ check_block_list = [node_list[start : start + max_common_length] for start in start_index_list]
common_label = True
if not _all_equal(check_block_list, _check_node_list_equal):
@@ -201,6 +206,6 @@ def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_lengt
# recover common subgraph from the index
common_blocks = []
for start in common_blocks_index:
- common_blocks.append(node_list[start:start + max_common_length])
+ common_blocks.append(node_list[start : start + max_common_length])
return common_blocks
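A toy version of the hash-window idea behind `find_repeat_blocks`: give every element a hash key, then compare fixed-length windows that start at elements sharing the same key. The names and the fixed window length are illustrative:

```python
from collections import defaultdict

seq = ["ln", "attn", "mlp", "ln", "attn", "mlp", "head"]
window = 3

starts_by_key = defaultdict(list)
for idx, key in enumerate(seq):
    starts_by_key[hash(key)].append(idx)

repeats = []
for starts in starts_by_key.values():
    if len(starts) < 2:
        continue  # a block must occur at least twice
    blocks = [tuple(seq[s : s + window]) for s in starts if s + window <= len(seq)]
    if len(blocks) >= 2 and len(set(blocks)) == 1:
        repeats.append(blocks[0])

print(repeats)  # [('ln', 'attn', 'mlp')]
```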
diff --git a/colossalai/auto_parallel/tensor_shard/utils/misc.py b/colossalai/auto_parallel/tensor_shard/utils/misc.py
index 9e402dab757820c5d76ee6d1166de473c040784b..42ec2a8ee428880a6a86e50e2a8c0ae21a2a4b74 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/misc.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/misc.py
@@ -1,12 +1,12 @@
import functools
-from typing import Any, Callable, Dict, List, Tuple, Type, Union
+from typing import Any, Callable, Tuple, Type, Union
import torch
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
-__all__ = ['ignore_sharding_exception', 'pytree_map']
+__all__ = ["ignore_sharding_exception", "pytree_map"]
def ignore_sharding_exception(func):
@@ -46,31 +46,34 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
# make sure all dims are covered in sharding spec
sharding_len = len(sharding_spec.sharding_sequence)
tensor_num_dim = tensor.dim()
- num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
- num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
- assert sharding_len == tensor_num_dim, \
- f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'
+ num_devices_in_col = sharding_spec.device_mesh.shape[0]
+ num_devices_in_row = sharding_spec.device_mesh.shape[1]
+ assert (
+ sharding_len == tensor_num_dim
+ ), f"The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape})."
# make sure the sharding is valid for each dim
for i in range(tensor_num_dim):
dim_size = tensor.shape[i]
dim_spec = sharding_spec.sharding_sequence[i]
- if str(dim_spec).startswith('S'):
- devices_str = str(dim_spec).lstrip('S')
+ if str(dim_spec).startswith("S"):
+ devices_str = str(dim_spec).lstrip("S")
num_devices = 1
- if '0' in devices_str:
+ if "0" in devices_str:
num_devices *= num_devices_in_col
- if '1' in devices_str:
+ if "1" in devices_str:
num_devices *= num_devices_in_row
- assert dim_size >= num_devices and dim_size % num_devices == 0, \
- f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
+ assert (
+ dim_size >= num_devices and dim_size % num_devices == 0
+ ), f"The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices."
# make sure the entire shape matches the physical tensor shape
- assert sharding_spec.entire_shape == tensor.shape, \
- f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}'
+ assert (
+ sharding_spec.entire_shape == tensor.shape
+ ), f"The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}"
def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
diff --git a/colossalai/auto_parallel/tensor_shard/utils/reshape.py b/colossalai/auto_parallel/tensor_shard/utils/reshape.py
index a32a14bf7d577713ae2cb986ffbb42d87b0cabc1..329312ef797fe1afa649bc675425082faf495c62 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/reshape.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/reshape.py
@@ -6,12 +6,13 @@ import torch
class PreviousStatus(Enum):
"""
- This class shows the status of previous comparision.
+ This class shows the status of previous comparison.
"""
+
RESET = 0
- # ORIGIN means the dimension size of original tensor is larger in the previous comparision.
+ # ORIGIN means the dimension size of original tensor is larger in the previous comparison.
ORIGIN = 1
- # TGT means the dimension size of target tensor is larger in the previous comparision.
+ # TGT means the dimension size of target tensor is larger in the previous comparison.
TGT = 2
@@ -91,7 +92,7 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
tgt_index += 1
if previous_label == PreviousStatus.TGT:
- # if the target dimension size is larger in the previous comparision, which means
+ # if the target dimension size is larger in the previous comparison, which means
# the origin dimension size has already accumulated larger than target dimension size, so
# we need to offload the origin dims and tgt dims into the reshape_mapping_dict.
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
@@ -111,7 +112,7 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
origin_index += 1
if previous_label == PreviousStatus.ORIGIN:
- # if the origin element is larger in the previous comparision, which means
+ # if the origin element is larger in the previous comparison, which means
# the target element has already accumulated larger than origin element, so
# we need to offload the origin dims and tgt dims into the reshape_mapping_dict.
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
@@ -130,8 +131,9 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
return reshape_mapping_dict
-def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
- reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> bool:
+def check_keep_sharding_status(
+ input_dim_partition_dict: Dict[int, List[int]], reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]
+) -> bool:
"""
    This method is used to check whether the reshape operation can be implemented without converting
    the input to a fully replicated status.
@@ -139,7 +141,7 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
Rule:
For a sharded dimension of input tensor, if it is not the minimum element of the input tuple,
the function will return false.
- To illustrate this issue, there are two cases to analyse:
+ To illustrate this issue, there are two cases to analyze:
1. no sharded dims in the input tuple: we could do the reshape operation safely just as the normal
operation without distributed tensor.
2. sharded dims in the input tuple: the sharded dim must be the minimum element, then during shape
@@ -172,14 +174,16 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
return True
-def infer_output_dim_partition_dict(input_dim_partition_dict: Dict[int, List[int]],
- reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> Dict[Tuple[int], Tuple[int]]:
+def infer_output_dim_partition_dict(
+ input_dim_partition_dict: Dict[int, List[int]], reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]
+) -> Dict[Tuple[int], Tuple[int]]:
"""
This method is used to infer the output dim partition dict for a reshape operation,
given the input dim partition dict and reshape mapping dict.
"""
- assert check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict), \
- 'we only infer output dim partition dict for the reshape operation could keep sharding spec.'
+ assert check_keep_sharding_status(
+ input_dim_partition_dict, reshape_mapping_dict
+ ), "we only infer output dim partition dict for the reshape operation could keep sharding spec."
sharded_dims = list(input_dim_partition_dict.keys())
output_dim_partition_dict = {}
for input_dims, output_dims in reshape_mapping_dict.items():
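The keep-sharding rule quoted in the docstring above can be restated compactly. A toy sketch with assumed inputs; `keeps_sharding` is a hypothetical stand-in for `check_keep_sharding_status`, and the mapping shown is a plausible result of `detect_reshape_mapping` for reshaping (4, 6) into (2, 2, 6) under left-to-right dim numbering:

```python
# Illustrative mapping: origin dim 0 splits into target dims (0, 1), dim 1 maps to 2.
reshape_mapping = {(0,): (0, 1), (1,): (2,)}

def keeps_sharding(input_dim_partition: dict, mapping: dict) -> bool:
    sharded_dims = set(input_dim_partition)
    for input_dims in mapping:
        for dim in input_dims:
            # a sharded dim must be the minimum element of its input tuple
            if dim in sharded_dims and dim != min(input_dims):
                return False
    return True

assert keeps_sharding({0: [0]}, reshape_mapping)       # dim 0 == min((0,))
assert not keeps_sharding({1: [0]}, {(0, 1): (0,)})    # dim 1 != min((0, 1))
```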
diff --git a/colossalai/auto_parallel/tensor_shard/utils/sharding.py b/colossalai/auto_parallel/tensor_shard/utils/sharding.py
index e2ce59e0b5772679be11e960322e3110c500d6aa..b5386d599be4a62672c83c6f3daf6085e6062349 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/sharding.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/sharding.py
@@ -8,8 +8,11 @@ import torch
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = [
- 'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
- 'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
+ "transpose_partition_dim",
+ "update_partition_dim",
+ "enumerate_all_possible_1d_sharding",
+ "enumerate_all_possible_2d_sharding",
+ "generate_sharding_size",
]
@@ -22,8 +25,7 @@ def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -
dim1 (int): the tensor dimension to switch
dim2 (int): the tensor dimension to switch
"""
- assert len(sharding_spec.entire_shape) >= 2, \
- 'The entire_shape of the sharding spec must have at least 2 dimensions'
+ assert len(sharding_spec.entire_shape) >= 2, "The entire_shape of the sharding spec must have at least 2 dimensions"
dim_partition_dict = sharding_spec.dim_partition_dict
# transpose the dim partition
@@ -45,10 +47,9 @@ def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -
return sharding_spec
-def update_partition_dim(sharding_spec: ShardingSpec,
- dim_mapping: Dict[int, int],
- physical_shape: torch.Size,
- inplace: bool = False):
+def update_partition_dim(
+ sharding_spec: ShardingSpec, dim_mapping: Dict[int, int], physical_shape: torch.Size, inplace: bool = False
+):
"""
This method is used to update the partition dim dict from the logical one to the physical one.
@@ -78,9 +79,9 @@ def update_partition_dim(sharding_spec: ShardingSpec,
new_dim_partition_dict[tensor_dim] = mesh_dims
# update sharding spec
- current_sharding_spec.__init__(device_mesh=sharding_spec.device_mesh,
- entire_shape=physical_shape,
- dim_partition_dict=new_dim_partition_dict)
+ current_sharding_spec.__init__(
+ device_mesh=sharding_spec.device_mesh, entire_shape=physical_shape, dim_partition_dict=new_dim_partition_dict
+ )
return current_sharding_spec
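For intuition, `transpose_partition_dim` amounts to swapping two keys of the `dim_partition_dict`. A minimal sketch on plain dicts, not the actual `ShardingSpec` API:

```python
def transpose_partition(dim_partition: dict, dim1: int, dim2: int) -> dict:
    new_partition = dict(dim_partition)
    d1 = new_partition.pop(dim1, None)   # mesh dims sharding tensor dim1, if any
    d2 = new_partition.pop(dim2, None)
    if d2 is not None:
        new_partition[dim1] = d2
    if d1 is not None:
        new_partition[dim2] = d1
    return new_partition

# tensor dim 0 sharded over mesh dim 0; swapping dims 0 and 2 moves the sharding
assert transpose_partition({0: [0]}, 0, 2) == {2: [0]}
```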
diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py
index d0a467254d7279c37031a755fa62a53fc4e0d9b9..9571fa2c17f074adfd4bff244c8ae8bcfcadb3c0 100644
--- a/colossalai/autochunk/autochunk_codegen.py
+++ b/colossalai/autochunk/autochunk_codegen.py
@@ -9,7 +9,18 @@ from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABL
AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()
if AUTOCHUNK_AVAILABLE:
- from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods
+ from torch.fx.graph import (
+ CodeGen,
+ PythonCode,
+ _custom_builtins,
+ _CustomBuiltin,
+ _format_target,
+ _is_from_torch,
+ _Namespace,
+ _origin_type_map,
+ inplace_methods,
+ magic_methods,
+ )
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
@@ -40,7 +51,7 @@ def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) ->
return new_shape
-def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_ouput_dim: int, chunk_size=2) -> str:
+def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_output_dim: int, chunk_size=2) -> str:
"""
Generate chunk loop start
@@ -52,7 +63,7 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_oup
Args:
chunk_input (List[Node]): chunk input node
chunk_output (Node): chunk output node
- chunk_ouput_dim (int): chunk output node chunk dim
+ chunk_output_dim (int): chunk output node chunk dim
chunk_size (int): chunk size. Defaults to 2.
Returns:
@@ -64,23 +75,36 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_oup
for i in range(len(chunk_output)):
shape_str = str(list(get_node_shape(chunk_output[i])))
if get_node_name(chunk_output[i]) in ["split", "unbind"]:
- tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name,
- input_node.name)
- tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta'])
+ tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (
+ shape_str,
+ input_node.name,
+ input_node.name,
+ )
+ tensor_str = tensor_str * len(chunk_output[i].meta["tensor_meta"])
tensor_str = "[" + tensor_str[:-2] + "]"
context += "%s = %s; " % (chunk_output[i].name, tensor_str)
else:
- context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str,
- input_node.name, input_node.name)
+ context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (
+ chunk_output[i].name,
+ shape_str,
+ input_node.name,
+ input_node.name,
+ )
out_shape = get_node_shape(chunk_output[0])
- chunk_shape = out_shape[chunk_ouput_dim[0]]
+ chunk_shape = out_shape[chunk_output_dim[0]]
context += "chunk_size = %d\nfor chunk_idx in range(0, %d, chunk_size):\n" % (chunk_size, chunk_shape)
return context
-def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node],
- chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str:
+def _gen_loop_end(
+ chunk_inputs: List[Node],
+ chunk_non_compute_inputs: List[Node],
+ node_list: List[Node],
+ chunk_outputs_idx: int,
+ chunk_outputs_non_tensor: List[Node],
+ search_chunk: SearchChunk,
+) -> str:
"""
Generate chunk loop end
@@ -148,8 +172,10 @@ def _replace_new_tensor_like_shape(
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
if get_node_shape(meta_node)[chunk_dim] != 1:
source_node = meta_node.args[0].args[0]
- if (source_node not in chunk_infos[region_idx]["node_chunk_dim"]
- or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None):
+ if (
+ source_node not in chunk_infos[region_idx]["node_chunk_dim"]
+ or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None
+ ):
chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node))
body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice)
return body
@@ -203,11 +229,12 @@ def _add_node_slice(
# outputs node
else:
if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]):
- chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx",
- get_node_shape(chunk_node))
+ chunk_slice = _gen_chunk_slice_dim(
+ chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx", get_node_shape(chunk_node)
+ )
if get_node_name(chunk_node) in ["split", "unbind"]:
split_chunk_slice = ""
- for i in range(len(chunk_node.meta['tensor_meta'])):
+ for i in range(len(chunk_node.meta["tensor_meta"])):
split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice)
split_chunk_slice = split_chunk_slice[:-2]
body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice)
@@ -216,13 +243,15 @@ def _add_node_slice(
return body
-def emit_code_with_chunk(body: List[str],
- nodes: Iterable[Node],
- emit_node_func: Callable,
- delete_unused_value_func: Callable,
- search_chunk: SearchChunk,
- chunk_infos: List,
- eval_mem: bool = False):
+def emit_code_with_chunk(
+ body: List[str],
+ nodes: Iterable[Node],
+ emit_node_func: Callable,
+ delete_unused_value_func: Callable,
+ search_chunk: SearchChunk,
+ chunk_infos: List,
+ eval_mem: bool = False,
+):
"""
Emit code with chunk according to chunk_infos.
@@ -244,9 +273,9 @@ def emit_code_with_chunk(body: List[str],
chunk_ends = [i["region"][1] for i in chunk_infos]
# chunk inputs
- chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk
- chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk
- chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim
+ chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk
+ chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk
+ chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
# chunk outputs
@@ -275,7 +304,8 @@ def emit_code_with_chunk(body: List[str],
chunk_outputs[region_idx],
chunk_outputs_dim[region_idx],
chunk_infos[region_idx]["chunk_size"],
- ))
+ )
+ )
if within_chunk_region:
emit_node_func(node, body)
@@ -294,7 +324,8 @@ def emit_code_with_chunk(body: List[str],
if eval_mem:
body.append(
" if chunk_idx == 0:\n print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
- % (node.name))
+ % (node.name)
+ )
else:
emit_node_func(node, body)
if node_idx not in chunk_inputs:
@@ -302,13 +333,21 @@ def emit_code_with_chunk(body: List[str],
if eval_mem:
body.append(
"print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
- % (node.name))
+ % (node.name)
+ )
# generate chunk region end
if node_idx in chunk_ends:
body.append(
- _gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list,
- chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk))
+ _gen_loop_end(
+ chunk_inputs[region_idx],
+ chunk_inputs_non_chunk[region_idx],
+ node_list,
+ chunk_ends[region_idx],
+ chunk_outputs_non_tensor[region_idx],
+ search_chunk,
+ )
+ )
within_chunk_region = False
node_idx += 1
@@ -317,13 +356,14 @@ def emit_code_with_chunk(body: List[str],
if AUTOCHUNK_AVAILABLE:
class AutoChunkCodeGen(CodeGen):
-
- def __init__(self,
- meta_graph,
- max_memory: int = None,
- print_mem: bool = False,
- print_progress: bool = False,
- eval_mem: bool = False) -> None:
+ def __init__(
+ self,
+ meta_graph,
+ max_memory: int = None,
+ print_mem: bool = False,
+ print_progress: bool = False,
+ eval_mem: bool = False,
+ ) -> None:
super().__init__()
self.eval_mem = eval_mem
# find the chunk regions
@@ -349,7 +389,7 @@ if AUTOCHUNK_AVAILABLE:
Returns: the global name that should be used to reference 'obj' in generated source.
"""
- if (_is_from_torch(obj) and obj != torch.device): # to support registering torch.device
+ if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
@@ -402,7 +442,6 @@ if AUTOCHUNK_AVAILABLE:
return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
-
def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global.
if isinstance(arg, tuple) and hasattr(arg, "_fields"):
@@ -457,10 +496,10 @@ if AUTOCHUNK_AVAILABLE:
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
- maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}")
+ maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
if node.op == "placeholder":
assert isinstance(node.target, str)
- maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}")
+ maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
raw_name = node.target.replace("*", "")
if raw_name != repr(node):
@@ -470,42 +509,56 @@ if AUTOCHUNK_AVAILABLE:
assert isinstance(node.target, str)
body.append(
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
- f"({_format_args(node.args[1:], node.kwargs)})")
+ f"({_format_args(node.args[1:], node.kwargs)})"
+ )
return
elif node.op == "call_function":
assert callable(node.target)
# pretty print operators
- if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods):
+ if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
- body.append(f"{repr(node)}{maybe_type_annotation} = "
- f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}")
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
+ )
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
- if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods):
- body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
- f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}")
+ if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
+ body.append(
+ f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
+ f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
+ )
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
- if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str)
- and node.args[1].isidentifier() and len(node.args) == 2):
+ if (
+ global_name == "getattr"
+ and isinstance(node.args, tuple)
+ and isinstance(node.args[1], str)
+ and node.args[1].isidentifier()
+ and len(node.args) == 2
+ ):
body.append(
- f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}")
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
+ )
return
body.append(
- f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})")
+ f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
+ )
if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name)
return
elif node.op == "call_module":
assert isinstance(node.target, str)
- body.append(f"{repr(node)}{maybe_type_annotation} = "
- f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})")
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
+ )
return
elif node.op == "get_attr":
assert isinstance(node.target, str)
@@ -523,8 +576,9 @@ if AUTOCHUNK_AVAILABLE:
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
- emit_code_with_chunk(body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos,
- self.eval_mem)
+ emit_code_with_chunk(
+ body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos, self.eval_mem
+ )
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
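To make the emitted code concrete: `_gen_loop_start` pre-allocates the outputs and opens a `for chunk_idx in range(0, dim_size, chunk_size)` loop, and the sliced body writes each chunk back before `_gen_loop_end` closes the region. Roughly, the generated code has this shape (the names and the region body are illustrative, not real codegen output):

```python
import torch

x = torch.randn(4, 64, 8)          # stand-in for the chunk region's input node
out = torch.empty(4, 64, 8)        # pre-allocated output, as in _gen_loop_start
chunk_size = 2
for chunk_idx in range(0, x.shape[1], chunk_size):           # loop over chunk dim 1
    chunk_in = x[:, chunk_idx:chunk_idx + chunk_size, :]     # _gen_chunk_slice_dim
    chunk_out = chunk_in * 2                                 # stand-in region body
    out[:, chunk_idx:chunk_idx + chunk_size, :] = chunk_out  # write the chunk back
assert torch.equal(out, x * 2)
```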
diff --git a/colossalai/autochunk/estimate_memory.py b/colossalai/autochunk/estimate_memory.py
index 77bc2ef17bc3bc5faca5903ddd4dfc5a7653275a..a85ad429e261fa2d9de968ac34af53ea0a44dfe5 100644
--- a/colossalai/autochunk/estimate_memory.py
+++ b/colossalai/autochunk/estimate_memory.py
@@ -1,11 +1,8 @@
-import copy
-from typing import Any, Callable, Dict, Iterable, List, Tuple
+from typing import Dict, List
import torch
from torch.fx.node import Node
-from colossalai.fx.profiler import activation_size, parameter_size
-
from .utils import NodeMgr, get_node_shape, is_non_memory_node
@@ -62,12 +59,9 @@ class EstimateMemory(object):
delete_node_dict[node] = max(node_user_idx)
return delete_node_dict
- def _remove_deactive_node(self,
- user_idx: int,
- user: Node,
- active_nodes: List,
- delete_node_dict: List,
- kept_nodes: List = None) -> None:
+ def _remove_deactive_node(
+ self, user_idx: int, user: Node, active_nodes: List, delete_node_dict: List, kept_nodes: List = None
+ ) -> None:
"""
        remove deactivated nodes from the active node list
"""
@@ -169,7 +163,7 @@ class EstimateMemory(object):
use_chunk = True if chunk_infos is not None else False
chunk_within = False
chunk_region_idx = None
- chunk_ratio = 1 # use it to estimate chunk mem
+ chunk_ratio = 1 # use it to estimate chunk mem
chunk_inputs_all = []
if use_chunk:
@@ -184,7 +178,6 @@ class EstimateMemory(object):
chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]
for idx, node in enumerate(node_mgr.get_node_list()):
-
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
if use_chunk and idx in chunk_starts:
chunk_within = True
@@ -193,8 +186,9 @@ class EstimateMemory(object):
# determine chunk ratio for current node
if chunk_within:
- chunk_ratio = self._get_chunk_ratio(node, chunk_node_dim[chunk_region_idx],
- chunk_sizes[chunk_region_idx])
+ chunk_ratio = self._get_chunk_ratio(
+ node, chunk_node_dim[chunk_region_idx], chunk_sizes[chunk_region_idx]
+ )
# add current node as active node
self._add_active_node(node, active_nodes, chunk_ratio)
@@ -222,7 +216,7 @@ class EstimateMemory(object):
# if node in chunk end nodes, restore chunk settings
if use_chunk and idx in chunk_ends:
- self._remove_deactive_node(idx, node, active_nodes, delete_node_dict) # dont provide kept nodes now
+            self._remove_deactive_node(idx, node, active_nodes, delete_node_dict)  # don't provide kept nodes for now
chunk_within = False
chunk_ratio = 1
chunk_region_idx = None
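The `chunk_ratio` above captures the idea that a node inside a chunk region only materializes `chunk_size / dim_size` of its activation at a time. A toy sketch of that scaling with an assumed 4-byte dtype; `estimate_activation_mb` is a hypothetical stand-in for the real estimator:

```python
def estimate_activation_mb(shape, chunk_dim=None, chunk_size=1, bytes_per_elem=4):
    numel = 1
    for d in shape:
        numel *= d
    # inside a chunk region only chunk_size / dim_size of the tensor is alive
    ratio = chunk_size / shape[chunk_dim] if chunk_dim is not None else 1.0
    return numel * bytes_per_elem * ratio / 1024**2

full = estimate_activation_mb((4, 1024, 1024))                       # 16 MB
chunked = estimate_activation_mb((4, 1024, 1024), chunk_dim=1, chunk_size=64)
assert chunked == full * 64 / 1024                                   # 1 MB
```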
diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py
index 59645c80e8089d63a53abd8a27c7784f2f90cd8d..1c599049d9ebfceab3f2c5342b0f25cace31beaf 100644
--- a/colossalai/autochunk/search_chunk.py
+++ b/colossalai/autochunk/search_chunk.py
@@ -8,7 +8,7 @@ from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
-from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
+from .utils import NodeMgr, get_logger, is_non_compute_node, is_non_compute_node_except_placeholder
class SearchChunk(object):
@@ -121,8 +121,10 @@ class SearchChunk(object):
# check if peak node already in chunk info
if chunk_regions is not None:
for i in chunk_regions:
- if i["region"][0] < peak_region[0] <= i["region"][1] or \
- i["region"][0] < peak_region[1] <= i["region"][1]:
+ if (
+ i["region"][0] < peak_region[0] <= i["region"][1]
+ or i["region"][0] < peak_region[1] <= i["region"][1]
+ ):
return None
active_node_num = [len(i) for i in active_node]
@@ -146,9 +148,9 @@ class SearchChunk(object):
region = i["region"]
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
return None
- elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
+ elif region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]:
chunk_region_start = region[1] + 1
- elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
+ elif region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]:
chunk_region_end = region[0] - 1
return chunk_region_start, chunk_region_end
@@ -171,7 +173,7 @@ class SearchChunk(object):
chunk_infos: possible regions found
"""
start_traces = input_trace[start_idx]
- if len(start_traces) > 1: # TODO need to be removed
+ if len(start_traces) > 1: # TODO need to be removed
return []
end_trace = output_trace[end_idx]
end_node = self.node_mgr.get_node_by_idx(end_idx)
@@ -180,8 +182,9 @@ class SearchChunk(object):
for end_dim, _ in enumerate(end_trace["indice"]):
for start_node, start_trace in start_traces.items():
for start_dim, _ in enumerate(start_trace["indice"]):
- if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim,
- end_idx):
+ if not self.trace_flow.check_region_start_end(
+ start_node, start_dim, start_idx, end_node, end_dim, end_idx
+ ):
continue
# flow search
chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
@@ -203,7 +206,7 @@ class SearchChunk(object):
"""
possible_chunk_region = []
output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
- input_trace = [] # trace of a node's input nodes
+ input_trace = [] # trace of a node's input nodes
for _, n in enumerate(self.node_mgr.get_node_list()):
cur_trace = {}
for arg in n.args:
@@ -215,7 +218,8 @@ class SearchChunk(object):
for end_idx in range(peak_region[1], max_chunk_region[1] + 1):
# skip non compute nodes
if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node(
- self.node_mgr.get_node_by_idx(end_idx)):
+ self.node_mgr.get_node_by_idx(end_idx)
+ ):
continue
# select free dim
chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
@@ -279,15 +283,18 @@ class SearchChunk(object):
chunk_infos.append(chunk_info)
mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(
- self.node_mgr.get_node_list(), chunk_infos)
+ self.node_mgr.get_node_list(), chunk_infos
+ )
if self.print_progress:
- get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
- (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))
+ get_logger().info(
+ "AutoChunk find chunk region %d = (%d, %d)"
+ % (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1])
+ )
if self.print_mem:
self.print_mem = False
- self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),
- chunk_infos,
- print_mem=True)
+ self.estimate_memory.estimate_chunk_inference_mem(
+ self.node_mgr.get_node_list(), chunk_infos, print_mem=True
+ )
return chunk_infos
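The interval test used above, which rejects a candidate peak region whose endpoints fall inside an already-selected chunk region, in isolation as a sketch:

```python
# Half-open on the left, matching the condition in the hunk above.
def overlaps(existing_region, peak_region):
    lo, hi = existing_region
    return lo < peak_region[0] <= hi or lo < peak_region[1] <= hi

assert overlaps((10, 20), (15, 30))      # start endpoint falls inside (10, 20]
assert not overlaps((10, 20), (21, 30))  # fully disjoint regions are accepted
```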
diff --git a/colossalai/autochunk/select_chunk.py b/colossalai/autochunk/select_chunk.py
index 94a29bfd56911eb9df749f284ae64fbd1b5d7a18..8a60ba681f70192a21e2f03ffdc656f6916dbd3d 100644
--- a/colossalai/autochunk/select_chunk.py
+++ b/colossalai/autochunk/select_chunk.py
@@ -5,7 +5,6 @@ from .utils import NodeMgr, is_non_compute_node
class SelectChunk(object):
-
def __init__(
self,
trace_indice: TraceIndice,
@@ -20,7 +19,7 @@ class SelectChunk(object):
self.node_mgr = node_mgr
if max_memory is not None:
self.stratge = "fit_memory"
- self.max_memory = max_memory # MB
+ self.max_memory = max_memory # MB
else:
self.stratge = "min_memory"
@@ -57,16 +56,18 @@ class SelectChunk(object):
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
- cur_chunk_region_peak = cur_mem[cur_region["region"][0]:cur_region["region"][1] + 1]
+ cur_chunk_region_peak = cur_mem[cur_region["region"][0] : cur_region["region"][1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
if cur_chunk_region_max_peak < self.max_memory:
- regions_dict.append({
- "chunk_info": region,
- "chunk_max_mem": cur_chunk_region_max_peak,
- "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
- "reorder_chunk_info": cur_region,
- "reorder_node_list": cur_node_list,
- })
+ regions_dict.append(
+ {
+ "chunk_info": region,
+ "chunk_max_mem": cur_chunk_region_max_peak,
+ "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
+ "reorder_chunk_info": cur_region,
+ "reorder_node_list": cur_node_list,
+ }
+ )
# no region found
if len(regions_dict) == 0:
raise RuntimeError("Search failed. Try a larger memory threshold.")
@@ -90,13 +91,15 @@ class SelectChunk(object):
chunk_size *= 2
reorder_chunk_info["chunk_size"] = chunk_size
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
- cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
- cur_chunk_infos)[0]
- cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1])
+ cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
+ chunk_region_dict["reorder_node_list"], cur_chunk_infos
+ )[0]
+ cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1])
# search exact size
chunk_info = chunk_region_dict["chunk_info"]
- chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict,
- chunk_infos)
+ chunk_info["chunk_size"] = self._chunk_size_binary_search(
+ chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
+ )
return chunk_info
def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
@@ -109,9 +112,10 @@ class SelectChunk(object):
mid = int((left + right) / 2 + 0.5)
chunk_info["chunk_size"] = mid
cur_chunk_infos = chunk_infos + [chunk_info]
- cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
- cur_chunk_infos)[0]
- cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1])
+ cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
+ chunk_region_dict["reorder_node_list"], cur_chunk_infos
+ )[0]
+ cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1])
if cur_chunk_max_mem >= self.max_memory:
right = mid - gap
else:
@@ -139,8 +143,10 @@ class SelectChunk(object):
return None
# get max possible chunk region
- max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
- max([i["region"][1] for i in possible_chunk_regions]))
+ max_possible_chunk_region = (
+ min([i["region"][0] for i in possible_chunk_regions]),
+ max([i["region"][1] for i in possible_chunk_regions]),
+ )
# get mem for chunk region
regions_dict_list = []
@@ -149,15 +155,17 @@ class SelectChunk(object):
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
- cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
+ cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0] : max_possible_chunk_region[1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
- regions_dict_list.append({
- "chunk_info": region,
- "chunk_max_mem": cur_chunk_region_max_peak,
- "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
- "reorder_chunk_info": cur_region,
- "reorder_node_list": cur_node_list,
- })
+ regions_dict_list.append(
+ {
+ "chunk_info": region,
+ "chunk_max_mem": cur_chunk_region_max_peak,
+ "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
+ "reorder_chunk_info": cur_region,
+ "reorder_node_list": cur_node_list,
+ }
+ )
# select the min mem
chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list]
@@ -175,7 +183,9 @@ class SelectChunk(object):
return False
for i in chunk_infos:
region = i["region"]
- if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or
- (chunk_region_start < region[0] and chunk_region_end < region[0])):
+ if not (
+ (chunk_region_start > region[1] and chunk_region_end > region[1])
+ or (chunk_region_start < region[0] and chunk_region_end < region[0])
+ ):
return False
return True
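`_chunk_size_binary_search` above looks for the largest chunk size whose estimated peak memory stays under `max_memory`. A self-contained sketch of that search against a toy, monotonic memory model; `estimate_peak_mb` is a stand-in for `estimate_chunk_inference_mem`:

```python
# Toy monotonic memory model standing in for the real estimator.
def estimate_peak_mb(chunk_size: int) -> float:
    return 100.0 + 2.0 * chunk_size

def chunk_size_binary_search(left: int, right: int, max_memory: float) -> int:
    # shrink [left, right] until the largest chunk size under the cap remains
    while right > left:
        mid = int((left + right) / 2 + 0.5)   # round up, as in the source
        if estimate_peak_mb(mid) >= max_memory:
            right = mid - 1
        else:
            left = mid
    return left

best = chunk_size_binary_search(1, 128, 300.0)
assert best == 99 and estimate_peak_mb(best) < 300.0
```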
diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py
index db25267e9b4200414c3cb89afaacd529a1534e06..8b36c99bbadd0f53e7148d2880e3016473cdebea 100644
--- a/colossalai/autochunk/trace_flow.py
+++ b/colossalai/autochunk/trace_flow.py
@@ -16,7 +16,6 @@ from .utils import (
class TraceFlow(object):
-
def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
self.trace_indice = trace_indice
self.node_mgr = node_mgr
@@ -64,7 +63,7 @@ class TraceFlow(object):
return False
return True
- def _assgin_single_node_flow(
+ def _assign_single_node_flow(
self,
arg_node: Node,
start_idx: int,
@@ -151,7 +150,7 @@ class TraceFlow(object):
return True
def _get_all_node_info(self, end_dim, start_idx, end_idx):
- cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node
+ cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
while len(cur_node_list) > 0:
@@ -177,7 +176,7 @@ class TraceFlow(object):
if get_node_shape(arg) is None:
continue
arg_list.append(arg)
- flow_flag = self._assgin_single_node_flow(
+ flow_flag = self._assign_single_node_flow(
arg,
start_idx,
end_idx,
@@ -266,7 +265,7 @@ class TraceFlow(object):
maybe_prepose_nodes.sort(
key=lambda x: self.node_mgr.find_node_idx(x),
reverse=True,
- ) # from last node to first node
+ ) # from last node to first node
prepose_nodes = []
# set every node as root, search its args, if all legal, turn root and args as prepose nodes
while len(maybe_prepose_nodes) > 0:
@@ -315,7 +314,7 @@ class TraceFlow(object):
chunk_info["args"]["prepose_nodes"] = prepose_nodes
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
- # we need to log input nodes to avoid deleteing them in the loop
+ # we need to log input nodes to avoid deleting them in the loop
chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
# also need to get some prepose node's arg out of non_chunk_inputs
for n in chunk_info["args"]["prepose_nodes"]:
@@ -328,7 +327,8 @@ class TraceFlow(object):
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = find_chunk_compute_input_and_output_nodes(
- self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1))
+ self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
+ )
# get every node's chunk dim and fix dim
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
@@ -366,13 +366,14 @@ class TraceFlow(object):
# find non chunk inputs
chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
- # reassgin reshape size, some size may have changed due to chunk
- chunk_info = self._reassgin_reshape_size(chunk_info)
+ # reassign reshape size, some size may have changed due to chunk
+ chunk_info = self._reassign_reshape_size(chunk_info)
return chunk_info
- def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int,
- chunk_info: Dict):
+ def _get_other_output_info(
+ self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int, chunk_info: Dict
+ ):
start_node = self.node_mgr.get_node_by_idx(start_idx)
# loop all outputs
for output in outputs:
@@ -384,8 +385,8 @@ class TraceFlow(object):
# skip non tensor
if get_node_shape(output) is None:
# log shape tensor
- if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int):
- chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out'])
+ if len(output.meta["fwd_out"]) > 0 and isinstance(output.meta["fwd_out"][0], int):
+ chunk_info["outputs_non_tensor"][output] = str(output.meta["fwd_out"])
continue
# loop every dim of outputs, try to find a legal one
for output_dim in range(len(get_node_shape(output))):
@@ -421,17 +422,18 @@ class TraceFlow(object):
for k, v in new_all_node_info.items():
if k in chunk_info["node_chunk_dim"]:
chunk_info["node_chunk_dim"][k]["fix_dim"] = list(
- set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"]))
+ set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"])
+ )
else:
chunk_info["node_chunk_dim"][k] = v
chunk_info["outputs"].append(output)
chunk_info["outputs_dim"].append(output_dim)
return True
- def _reassgin_reshape_size(self, chunk_info):
+ def _reassign_reshape_size(self, chunk_info):
"""
Some shape args in reshape may have changed due to chunk
- reassgin those changed shape
+        reassign those changed shapes
"""
chunk_region = chunk_info["region"]
reshape_size = {}
@@ -443,8 +445,11 @@ class TraceFlow(object):
if node.args[0] in chunk_info["inputs_non_chunk"]:
continue
reshape_args = flat_list(node.args[1:])
- if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len(
- reshape_args[0].meta['fwd_out']) > 1:
+ if (
+ len(reshape_args) == 1
+ and get_node_shape(reshape_args[0]) is None
+ and len(reshape_args[0].meta["fwd_out"]) > 1
+ ):
continue
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
new_shape = ""
@@ -462,16 +467,17 @@ class TraceFlow(object):
chunk_info["reshape_size"] = reshape_size
return chunk_info
- def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
- end_idx: int) -> bool:
+ def check_region_start_end(
+ self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int, end_idx: int
+ ) -> bool:
"""
check if region start and end is legal
"""
# dim cannot be None
- if (get_node_shape(end_node) is None or get_node_shape(start_node) is None):
+ if get_node_shape(end_node) is None or get_node_shape(start_node) is None:
return False
# dim size cannot be 1
- if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
+ if get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1:
return False
# must have users
if len(end_node.users) == 0:
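The legality rules in `check_region_start_end` (shapes must be known, the chunk dims cannot have size 1, and the end node must have users) condensed into a sketch with plain tuples standing in for FX nodes:

```python
def region_endpoints_legal(start_shape, start_dim, end_shape, end_dim, end_has_users):
    if start_shape is None or end_shape is None:                # dims must be known
        return False
    if start_shape[start_dim] == 1 or end_shape[end_dim] == 1:  # size-1 dims can't chunk
        return False
    return end_has_users                                        # dead outputs are illegal

assert region_endpoints_legal((4, 64), 1, (4, 64), 1, True)
assert not region_endpoints_legal((4, 1), 1, (4, 64), 1, True)
```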
diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py
index c7fce4c8bee1f2886d63c9ca6bb35c332cb293d8..378c54acf782090f11bbbe7991e4bab317c1a6f4 100644
--- a/colossalai/autochunk/trace_indice.py
+++ b/colossalai/autochunk/trace_indice.py
@@ -1,5 +1,5 @@
import copy
-from typing import Dict, List, Tuple
+from typing import Dict, List
from torch.fx.node import Node
@@ -18,7 +18,7 @@ class TraceIndice(object):
dim(x1)=dim(x2)=dim(x3)=[a, b, c]
This class will record every node's dims' indice, compute and source.
- Attibutes:
+ Attributes:
node_list (List)
indice_trace_list (List): [{"indice": [...], "compute": [...], "source": [...]}, {...}]
indice_view_list (Dict): not used for now
@@ -397,7 +397,7 @@ class TraceIndice(object):
input_node = node.args[0]
assert len(get_node_shape(input_node)) == 4
- # assgin index
+ # assign index
self._assign_indice_as_input(node, node_idx, input_node)
self._del_dim(node_idx, 1)
self._add_dim(node_idx, 1)
@@ -412,10 +412,10 @@ class TraceIndice(object):
node_idx (int)
"""
# get conv input
- assert node.kwargs['size'] is None
+ assert node.kwargs["size"] is None
assert len(get_node_shape(node)) == 4
- # assgin index
+ # assign index
self._assign_indice_as_input(node, node_idx)
self._mark_computation(node, node_idx, [-1, -2])
@@ -461,7 +461,7 @@ class TraceIndice(object):
nodes_in.append(node_in)
self._inherit_more_indice_from_node_with_exclude(node_in, node)
- def _assgin_no_change_indice(self, node, idx):
+ def _assign_no_change_indice(self, node, idx):
self._assign_indice_as_input(node, idx)
for node_in in node.args:
if type(node_in) == type(node):
@@ -792,7 +792,7 @@ class TraceIndice(object):
self._add_dim(node_idx, i)
dim_from.reverse()
- # inheirt indice from current node
+ # inherit indice from current node
if len(dim_from) != 0 and len(dim_to) != 0:
if dim_diff == 1:
if origin_shape[dim_from[0]] == 1:
@@ -826,7 +826,7 @@ class TraceIndice(object):
# clear compute
for dim_compute in trace["compute"]:
for i in range(len(dim_compute) - 1, -1, -1):
- if (dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes):
+ if dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes:
dim_compute.pop(i)
continue
# clear source
@@ -852,7 +852,7 @@ class TraceIndice(object):
elif "split" == node_name:
self._assign_split_indice(node, idx)
elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]):
- self._assgin_no_change_indice(node, idx)
+ self._assign_no_change_indice(node, idx)
elif "new_ones" == node_name:
self._assign_all_indice(node, idx)
elif "flatten" == node_name:
@@ -876,10 +876,24 @@ class TraceIndice(object):
self._assign_matmul_indice(node, idx)
elif "softmax" == node_name:
self._assign_softmax_indice(node, idx)
- elif any(n == node_name for n in [
- "mul", "add", "sigmoid", "relu", "sub", "truediv", "pow", "dropout", "where", "tanh", "exp",
- "sin", "cos"
- ]):
+ elif any(
+ n == node_name
+ for n in [
+ "mul",
+ "add",
+ "sigmoid",
+ "relu",
+ "sub",
+ "truediv",
+ "pow",
+ "dropout",
+ "where",
+ "tanh",
+ "exp",
+ "sin",
+ "cos",
+ ]
+ ):
self._assign_elementwise_indice(node, idx)
elif "einsum" == node_name:
self._assign_einsum_indice(node, idx)
@@ -914,13 +928,13 @@ class TraceIndice(object):
elif "conv2d" == node_name:
self._assign_conv2d_indice(node, idx)
elif "identity" == node_name:
- self._assgin_no_change_indice(node, idx)
+ self._assign_no_change_indice(node, idx)
elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu", "gelu"]):
self._assign_elementwise_indice(node, idx)
else:
raise NotImplementedError(node_name, "module not implemented yet!")
elif node.op == "get_attr":
- self._assign_all_indice(node, idx) # get param
+ self._assign_all_indice(node, idx) # get param
elif node.op == "output":
continue
else:
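For reference, the per-node record documented in the `Attributes` section of `TraceIndice` above has roughly the following shape; the concrete values here are invented for illustration:

```python
# Invented values: one record per node, one entry per tensor dim.
trace_record = {
    "indice": [0, 1, 2],                 # symbolic index id for each tensor dim
    "compute": [[], [5], []],            # node indices that computed along a dim
    "source": [{0: 0}, {0: 1}, {0: 2}],  # where each dim's indice came from
}
# elementwise ops inherit this record from their input unchanged, which is
# what _assign_elementwise_indice / _assign_no_change_indice encode
```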
diff --git a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py
index 064baa047155ac399b14ae0fcdb044db1125d70b..f6f803a5ce0a80f337443faaa7ffdea741688800 100644
--- a/colossalai/autochunk/utils.py
+++ b/colossalai/autochunk/utils.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
+from typing import Any, Dict, List, Union
from torch.fx.node import Node
@@ -10,7 +10,6 @@ logger = get_dist_logger()
class NodeMgr(object):
-
def __init__(self, nodes_list: List[Node]) -> None:
self._node_list = nodes_list
self._node_dict = {}
@@ -174,16 +173,22 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Union[List,
# we treat that input node as the input of the checkpoint function
for node in nodes:
for input_node in node._input_nodes.keys():
- if (input_node not in nodes and input_node not in input_nodes
- and not is_non_compute_node_except_placeholder(input_node)):
+ if (
+ input_node not in nodes
+ and input_node not in input_nodes
+ and not is_non_compute_node_except_placeholder(input_node)
+ ):
input_nodes.append(input_node)
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
for node in nodes:
for output_node in node.users.keys():
- if (output_node not in nodes and node not in output_nodes
- and not is_non_compute_node_except_placeholder_output(output_node)):
+ if (
+ output_node not in nodes
+ and node not in output_nodes
+ and not is_non_compute_node_except_placeholder_output(output_node)
+ ):
output_nodes.append(node)
return input_nodes, output_nodes
@@ -238,7 +243,10 @@ def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
for node in node_list:
if get_node_shape(node) is not None:
out.append(node)
- elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance(
- node.meta['fwd_out'][0], int):
+ elif (
+ len(node.meta["fwd_out"]) > 0
+ and isinstance(node.meta["fwd_out"], list)
+ and isinstance(node.meta["fwd_out"][0], int)
+ ):
out.append(node)
return out
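The boundary detection in `find_chunk_compute_input_and_output_nodes` reduces to two set computations: producers outside the region that feed it, and region nodes with users outside it. A minimal sketch on plain adjacency dicts (`producers_of`/`users_of` are hypothetical stand-ins for FX's `_input_nodes`/`users`):

```python
def find_region_io(region_nodes, producers_of, users_of):
    region = set(region_nodes)
    inputs = {p for n in region for p in producers_of[n] if p not in region}
    outputs = {n for n in region if any(u not in region for u in users_of[n])}
    return inputs, outputs

producers = {"a": [], "b": ["a"], "c": ["b"]}
users = {"a": ["b"], "b": ["c"], "c": []}
# region {"b"}: "a" feeds it from outside, and "b" is consumed outside by "c"
assert find_region_io(["b"], producers, users) == ({"a"}, {"b"})
```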
diff --git a/colossalai/booster/accelerator.py b/colossalai/booster/accelerator.py
index fc2c4a40068b50cb49db8a3f9c33f22b8f307966..92990907bc2e5d24c67c3864682195ce22adca89 100644
--- a/colossalai/booster/accelerator.py
+++ b/colossalai/booster/accelerator.py
@@ -1,12 +1,11 @@
import torch
import torch.nn as nn
-__all__ = ['Accelerator']
+__all__ = ["Accelerator"]
_supported_devices = [
- 'cpu',
- 'cuda',
-
+ "cpu",
+ "cuda",
# To be supported
# 'xpu',
# 'npu',
@@ -25,21 +24,22 @@ class Accelerator:
def __init__(self, device: str):
self.device = device
- assert self.device in _supported_devices, f"Device {self.device} is not supported yet, supported devices include {_supported_devices}"
+ assert (
+ self.device in _supported_devices
+ ), f"Device {self.device} is not supported yet, supported devices include {_supported_devices}"
def bind(self):
"""
Set the default device for the current process.
"""
- if self.device == 'cpu':
+ if self.device == "cpu":
pass
- elif self.device == 'cuda':
+ elif self.device == "cuda":
# TODO(FrankLeeeee): use global environment to check if it is a dist job
# if is_distributed:
# local_rank = EnvTable().get_local_rank()
# torch.cuda.set_device(torch.device(f'cuda:{local_rank}'))
- torch.cuda.set_device(torch.device('cuda'))
- pass
+ torch.cuda.set_device(torch.device("cuda"))
else:
raise ValueError(f"Device {self.device} is not supported yet")
diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index c14e602deaf5ce60808a87bbbd238c9419b9d502..d73bc5babd8000689576e6f291bb6db82b4287d7 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -1,6 +1,6 @@
import warnings
from contextlib import contextmanager
-from typing import Callable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, Union
import torch
import torch.nn as nn
@@ -8,13 +8,16 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
+import colossalai.interface.pretrained as pretrained_utils
from colossalai.checkpoint_io import GeneralCheckpointIO
+from colossalai.interface import ModelWrapper, OptimizerWrapper
from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
from .plugin import Plugin
+from .plugin.pp_plugin_base import PipelinePluginBase
-__all__ = ['Booster']
+__all__ = ["Booster"]
class Booster:
@@ -22,56 +25,67 @@ class Booster:
Booster is a high-level API for training neural networks. It provides a unified interface for
training with different precision, accelerator, and plugin.
- Examples:
- >>> colossalai.launch(...)
- >>> plugin = GeminiPlugin(stage=3, ...)
- >>> booster = Booster(precision='fp16', plugin=plugin)
- >>>
- >>> model = GPT2()
- >>> optimizer = Adam(model.parameters())
- >>> dataloader = Dataloader(Dataset)
- >>> lr_scheduler = LinearWarmupScheduler()
- >>> criterion = GPTLMLoss()
- >>>
- >>> model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader)
- >>>
- >>> for epoch in range(max_epochs):
- >>> for input_ids, attention_mask in dataloader:
- >>> outputs = model(input_ids, attention_mask)
- >>> loss = criterion(outputs.logits, input_ids)
- >>> booster.backward(loss, optimizer)
- >>> optimizer.step()
- >>> lr_scheduler.step()
- >>> optimizer.zero_grad()
+ ```python
+ # Following is pseudocode
+
+ colossalai.launch(...)
+ plugin = GeminiPlugin(...)
+ booster = Booster(precision='fp16', plugin=plugin)
+
+ model = GPT2()
+ optimizer = HybridAdam(model.parameters())
+ dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+ lr_scheduler = LinearWarmupScheduler()
+ criterion = GPTLMLoss()
+
+ model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler)
+
+ for epoch in range(max_epochs):
+ for input_ids, attention_mask in dataloader:
+ outputs = model(input_ids.cuda(), attention_mask.cuda())
+ loss = criterion(outputs.logits, input_ids)
+ booster.backward(loss, optimizer)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+ ```
Args:
- device (str or torch.device): The device to run the training. Default: 'cuda'.
+        device (str or torch.device): The device to run the training. Default: None.
+                                       If no plugin is used, or the plugin doesn't control the device,
+                                       this argument is used as the training device ('cuda' is used if it is None).
mixed_precision (str or MixedPrecision): The mixed precision to run the training. Default: None.
If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'.
'fp16' would use PyTorch AMP while `fp16_apex` would use Nvidia Apex.
plugin (Plugin): The plugin to run the training. Default: None.
"""
- def __init__(self,
- device: str = 'cuda',
- mixed_precision: Union[MixedPrecision, str] = None,
- plugin: Optional[Plugin] = None) -> None:
+ def __init__(
+ self,
+ device: Optional[str] = None,
+ mixed_precision: Optional[Union[MixedPrecision, str]] = None,
+ plugin: Optional[Plugin] = None,
+ ) -> None:
if plugin is not None:
assert isinstance(
- plugin, Plugin), f'Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}.'
+ plugin, Plugin
+ ), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}."
self.plugin = plugin
# set accelerator
if self.plugin and self.plugin.control_device():
self.accelerator = None
- warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
+ if device is not None:
+ warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.")
else:
+ device = device or "cuda"
self.accelerator = Accelerator(device)
# set precision
if self.plugin and self.plugin.control_precision():
- warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
+ if mixed_precision is not None:
+ warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.")
self.mixed_precision = None
elif mixed_precision is None:
self.mixed_precision = None
@@ -85,7 +99,7 @@ class Booster:
self.mixed_precision = mixed_precision
else:
raise ValueError(
- f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.'
+ f"Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}."
)
if self.plugin is not None and self.plugin.control_checkpoint_io():
@@ -96,79 +110,216 @@ class Booster:
def boost(
self,
model: nn.Module,
- optimizer: Optimizer,
- criterion: Callable = None,
- dataloader: DataLoader = None,
- lr_scheduler: LRScheduler = None,
+ optimizer: Optional[Optimizer] = None,
+ criterion: Optional[Callable] = None,
+ dataloader: Optional[DataLoader] = None,
+ lr_scheduler: Optional[LRScheduler] = None,
) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
"""
- Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
+        Wrap and inject features into the passed-in model, optimizer, criterion, lr_scheduler, and dataloader.
Args:
- model (nn.Module): The model to be boosted.
- optimizer (Optimizer): The optimizer to be boosted.
- criterion (Callable): The criterion to be boosted.
- dataloader (DataLoader): The dataloader to be boosted.
- lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
+            model (nn.Module): Convert model into a wrapped model for distributed training.
+                The model might be decorated or partitioned by the plugin's strategy after execution of this method.
+            optimizer (Optimizer, optional): Convert optimizer into a wrapped optimizer for distributed training.
+                The optimizer's param groups or states might be decorated or partitioned by the plugin's strategy after execution of this method. Defaults to None.
+ criterion (Callable, optional): The function that calculates loss. Defaults to None.
+ dataloader (DataLoader, optional): The prepared dataloader for training. Defaults to None.
+            lr_scheduler (LRScheduler, optional): The learning rate scheduler for training. Defaults to None.
+
+ Returns:
+ List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]: The list of boosted input arguments.
"""
# TODO(FrankLeeeee): consider multi-model and multi-optimizer case
# TODO(FrankLeeeee): consider multi-dataloader case
+ pretrained_path = pretrained_utils.get_pretrained_path(model)
# transform model for mixed precision
if self.plugin:
model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
- model, optimizer, criterion, dataloader, lr_scheduler)
+ model, optimizer, criterion, dataloader, lr_scheduler
+ )
if self.plugin and not self.plugin.control_device():
# transform model for accelerator
- model = self.accelerator.configure(model)
+ model = self.accelerator.configure_model(model)
if self.mixed_precision and (self.plugin is None or self.plugin and not self.plugin.control_precision()):
# transform model for mixed precision
# when mixed_precision is specified and the plugin is not given or does not control the precision
model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
+ if pretrained_path:
+ self.load_model(model, pretrained_path)
+ # clear pretrained path attr
+ orig_model = model.unwrap() if isinstance(model, ModelWrapper) else model
+ pretrained_utils.set_pretrained_path(orig_model, None)
+
return model, optimizer, criterion, dataloader, lr_scheduler
def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
- # TODO: implement this method with plugin
+ """Execution of backward during training step.
+
+ Args:
+ loss (torch.Tensor): The loss for backpropagation.
+ optimizer (Optimizer): The optimizer to be updated.
+ """
+ # TODO(frank lee): implement this method with plugin
optimizer.backward(loss)
- def execute_pipeline(self,
- data_iter: Iterator,
- model: nn.Module,
- criterion: Callable[[torch.Tensor], torch.Tensor],
- optimizer: Optimizer,
- return_loss: bool = True,
- return_outputs: bool = False) -> Tuple[Optional[torch.Tensor], ...]:
- # TODO: implement this method
- # run pipeline forward backward pass
- # return loss or outputs if needed
- pass
-
- def no_sync(self, model: nn.Module) -> contextmanager:
- assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
- assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
- return self.plugin.no_sync(model)
-
- def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+ def execute_pipeline(
+ self,
+ data_iter: Iterator,
+ model: nn.Module,
+ criterion: Callable[[Any, Any], torch.Tensor],
+ optimizer: Optional[Optimizer] = None,
+ return_loss: bool = True,
+ return_outputs: bool = False,
+ ) -> Dict[str, Any]:
+ """
+ Execute forward & backward when utilizing pipeline parallel.
+ Return loss or Huggingface style model outputs if needed.
+
+        Warning: This function is tailored for pipeline parallelism.
+        As a result, please don't do the forward/backward pass in the conventional way (model(input)/loss.backward())
+        when doing pipeline parallel training with the booster, as that will cause unexpected errors.
+
+ Args:
+            data_iter (Iterator): The iterator for getting the next batch of data. Usually there are two ways to obtain this argument:
+                                 1. wrap the dataloader into an iterator: iter(dataloader)
+                                 2. get the next batch from the dataloader and wrap it into an iterator: iter([batch])
+            model (nn.Module): The model to execute forward/backward; it should be a model wrapped by a plugin that supports pipeline.
+            criterion (Callable[[Any, Any], torch.Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and return a loss tensor.
+                'lambda y, x: loss_fn(y)' can turn a normal loss function into a valid two-argument criterion here.
+ optimizer (Optimizer, optional): The optimizer for execution of backward. Can be None when only doing forward (i.e. evaluation). Defaults to None.
+ return_loss (bool, optional): Whether to return loss in the dict returned by this method. Defaults to True.
+            return_outputs (bool, optional): Whether to return Huggingface style model outputs in the dict returned by this method. Defaults to False.
+
+ Returns:
+ Dict[str, Any]: Output dict in the form of {'loss': ..., 'outputs': ...}.
+ ret_dict['loss'] is the loss of forward if return_loss is set to True, else None.
+                ret_dict['outputs'] is the Huggingface style model outputs during forward if return_outputs is set to True, else None.
+ """
+ assert isinstance(
+ self.plugin, PipelinePluginBase
+ ), f"The plugin {self.plugin.__class__.__name__} does not support pipeline."
+ return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs)
+
+ def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
+ """Context manager to disable gradient synchronization across DP process groups.
+        Supports torch DDP and Low Level ZeRO-1 for now.
+
+ Args:
+            model (nn.Module): The model whose gradient synchronization is to be disabled, for DDP.
+            optimizer (OptimizerWrapper): The optimizer whose gradient synchronization is to be disabled, for ZeRO-1.
+
+ Returns:
+ contextmanager: Context to disable gradient synchronization.
+ """
+ assert (
+ self.plugin is not None
+ ), f"no_sync is only enabled when a plugin is provided and the plugin supports no_sync."
+ assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync."
+ return self.plugin.no_sync(model, optimizer)
+
+ def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
+ """Load model from checkpoint.
+
+ Args:
+ model (nn.Module or ModelWrapper): A model boosted by Booster.
+ checkpoint (str): Path to the checkpoint. It must be a local path.
+ It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
+ strict (bool, optional): whether to strictly enforce that the keys
+ in :attr:`state_dict` match the keys returned by this module's
+ :meth:`~torch.nn.Module.state_dict` function. Defaults to True.
+ """
self.checkpoint_io.load_model(model, checkpoint, strict)
- def save_model(self,
- model: nn.Module,
- checkpoint: str,
- prefix: str = None,
- shard: bool = False,
- size_per_shard: int = 1024):
- self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard)
+ def save_model(
+ self,
+ model: Union[nn.Module, ModelWrapper],
+ checkpoint: str,
+ shard: bool = False,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ use_safetensors: bool = False,
+ ) -> None:
+ """Save model to checkpoint.
+
+ Args:
+ model (nn.Module or ModelWrapper): A model boosted by Booster.
+ checkpoint (str): Path to the checkpoint. It must be a local path.
+ It is a file path if ``shard=False``. Otherwise, it is a directory path.
+            shard (bool, optional): Whether to save the checkpoint in a sharded way.
+ If true, the checkpoint will be a folder with the same format as Huggingface transformers checkpoint. Otherwise, it will be a single file. Defaults to False.
+            gather_dtensor (bool, optional): Whether to gather the distributed tensor to the first device. Default: True.
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
+            size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
+            use_safetensors (bool, optional): Whether to use safetensors. If set to True, the checkpoint will be saved in the safetensors format. Default: False.
+ """
+ self.checkpoint_io.save_model(
+ model,
+ checkpoint=checkpoint,
+ shard=shard,
+ gather_dtensor=gather_dtensor,
+ prefix=prefix,
+ size_per_shard=size_per_shard,
+ use_safetensors=use_safetensors,
+ )
+
-    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+    def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
+        """Load optimizer from checkpoint.
+
+        Args:
+            optimizer (Optimizer): An optimizer boosted by Booster.
+            checkpoint (str): Path to the checkpoint. It must be a local path.
+                It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
+ """
self.checkpoint_io.load_optimizer(optimizer, checkpoint)
- def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
- self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)
+ def save_optimizer(
+ self,
+ optimizer: Optimizer,
+ checkpoint: str,
+ shard: bool = False,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ ) -> None:
+ """
+ Save optimizer to checkpoint.
+
+ Args:
+ optimizer (Optimizer): An optimizer boosted by Booster.
+ checkpoint (str): Path to the checkpoint. It must be a local path.
+ It is a file path if ``shard=False``. Otherwise, it is a directory path.
+            shard (bool, optional): Whether to save the checkpoint in a sharded way.
+                If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
+            gather_dtensor (bool): Whether to gather the distributed tensor to the first device. Default: True.
+ prefix (str, optional): A prefix added to parameter and buffer
+ names to compose the keys in state_dict. Defaults to None.
+ size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
+ """
+ self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
+
-    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
+        """Save lr scheduler to checkpoint.
+
+        Args:
+ lr_scheduler (LRScheduler): A lr scheduler boosted by Booster.
+ checkpoint (str): Path to the checkpoint. It must be a local file path.
+ """
self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)
- def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+ def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
+ """Load lr scheduler from checkpoint.
+
+ Args:
+ lr_scheduler (LRScheduler): A lr scheduler boosted by Booster.
+ checkpoint (str): Path to the checkpoint. It must be a local file path.
+ """
self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)
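
The booster APIs above compose as follows. A minimal sketch, assuming a distributed environment launched via `colossalai.launch_from_torch`; the model, dataloader, and `loss_fn` are placeholders, and the `HybridParallelPlugin` arguments (the plugin is introduced later in this patch) are illustrative only:

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

model, optimizer, train_dataloader, loss_fn = ...  # user-defined placeholders

plugin = HybridParallelPlugin(tp_size=1, pp_size=2, num_microbatches=4)
booster = Booster(plugin=plugin)
model, optimizer, _, train_dataloader, _ = booster.boost(
    model, optimizer, dataloader=train_dataloader
)

# Per the docstring, criterion takes (outputs, inputs) and returns a loss tensor.
criterion = lambda outputs, inputs: loss_fn(outputs)

for batch in train_dataloader:
    # Wrap a single batch into an iterator, as suggested in the docstring above.
    ret = booster.execute_pipeline(iter([batch]), model, criterion, optimizer, return_loss=True)
    optimizer.step()
    optimizer.zero_grad()

# Sharded checkpoints are directories; unsharded ones are single files.
booster.save_model(model, "model_ckpt", shard=True, size_per_shard=1024)
booster.save_optimizer(optimizer, "optim_ckpt", shard=True)
```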
diff --git a/colossalai/booster/mixed_precision/__init__.py b/colossalai/booster/mixed_precision/__init__.py
index 3cf0ad28cdbe78b24e1c8fd18fd6bd473e095bd6..68c6221ec809865284826c189f98b5ddf90a3975 100644
--- a/colossalai/booster/mixed_precision/__init__.py
+++ b/colossalai/booster/mixed_precision/__init__.py
@@ -1,19 +1,27 @@
from .bf16 import BF16MixedPrecision
from .fp8 import FP8MixedPrecision
from .fp16_apex import FP16ApexMixedPrecision
+from .fp16_naive import FP16NaiveMixedPrecision
from .fp16_torch import FP16TorchMixedPrecision
from .mixed_precision_base import MixedPrecision
__all__ = [
- 'MixedPrecision', 'mixed_precision_factory', 'FP16_Apex_MixedPrecision', 'FP16_Torch_MixedPrecision',
- 'FP32_MixedPrecision', 'BF16_MixedPrecision', 'FP8_MixedPrecision'
+ "MixedPrecision",
+ "mixed_precision_factory",
+ "FP16_Apex_MixedPrecision",
+ "FP16_Torch_MixedPrecision",
+ "FP32_MixedPrecision",
+ "BF16_MixedPrecision",
+ "FP8_MixedPrecision",
+ "FP16NaiveMixedPrecision",
]
_mixed_precision_mapping = {
- 'fp16': FP16TorchMixedPrecision,
- 'fp16_apex': FP16ApexMixedPrecision,
- 'bf16': BF16MixedPrecision,
- 'fp8': FP8MixedPrecision
+ "fp16": FP16TorchMixedPrecision,
+ "fp16_apex": FP16ApexMixedPrecision,
+ "fp16_naive": FP16NaiveMixedPrecision,
+ "bf16": BF16MixedPrecision,
+ "fp8": FP8MixedPrecision,
}
@@ -29,5 +37,5 @@ def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision:
return _mixed_precision_mapping[mixed_precision_type]()
else:
raise ValueError(
- f'Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}'
+ f"Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}"
)
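
A quick sketch of how the factory above is consumed; the `"fp16_naive"` key is the one registered in this hunk:

```python
from colossalai.booster.mixed_precision import mixed_precision_factory

amp = mixed_precision_factory("fp16")  # -> FP16TorchMixedPrecision instance
naive_amp = mixed_precision_factory("fp16_naive")  # -> newly registered FP16NaiveMixedPrecision

# Unsupported keys raise a ValueError listing the supported types.
try:
    mixed_precision_factory("int8")
except ValueError as err:
    print(err)
```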
diff --git a/colossalai/booster/mixed_precision/fp16_apex.py b/colossalai/booster/mixed_precision/fp16_apex.py
index 266a750734b14ade3c5f53fc71ef276d09e1ec83..2fa7b54cdd3094fc2ded6fabd77e35ad77d56aa3 100644
--- a/colossalai/booster/mixed_precision/fp16_apex.py
+++ b/colossalai/booster/mixed_precision/fp16_apex.py
@@ -1,5 +1,40 @@
+from typing import Any, Optional, Union
+
+import torch
+
from .mixed_precision_base import MixedPrecision
class FP16ApexMixedPrecision(MixedPrecision):
- pass
+ """
+ Precision for mixed precision training in FP16 using apex AMP.
+
+ Args:
+        opt_level (str, optional, default="O1"): Pure or mixed precision optimization level. Accepted values are "O0", "O1", "O2", and "O3", explained in detail in the Apex AMP documentation.
+        cast_model_type (torch.dtype, optional, default=None): Casts your model's parameters and buffers to the desired type.
+        patch_torch_functions (bool, optional, default=None): Patch all Torch functions and Tensor methods to perform Tensor Core-friendly ops like GEMMs and convolutions in FP16, and any ops that benefit from FP32 precision in FP32.
+        keep_batchnorm_fp32 (bool or str, optional, default=None): To enhance precision and enable cudnn batchnorm (which improves performance), it's often beneficial to keep batchnorm weights in FP32 even if the rest of the model is FP16.
+        master_weights (bool, optional, default=None): Maintain FP32 master weights to accompany any FP16 model weights. FP32 master weights are stepped by the optimizer to enhance precision and capture small gradients.
+        loss_scale (float or str, optional, default=None): If loss_scale is a float value, use this value as the static (fixed) loss scale. If loss_scale is the string "dynamic", adaptively adjust the loss scale over time. Dynamic loss scale adjustments are performed by Amp automatically.
+        cast_model_outputs (torch.dtype, optional, default=None): Option to ensure that the outputs of your model(s) are always cast to a particular type regardless of opt_level.
+        num_losses (int, optional, default=1): Option to tell Amp in advance how many losses/backward passes you plan to use. When used in conjunction with the loss_id argument to `amp.scale_loss`, enables Amp to use a different loss scale per loss/backward pass, which can improve stability. If num_losses is left to 1, Amp will still support multiple losses/backward passes, but use a single global loss scale for all of them.
+        verbosity (int, default=1): Set to 0 to suppress Amp-related output.
+        min_loss_scale (float, default=None): Sets a floor for the loss scale values that can be chosen by dynamic loss scaling. The default value of None means that no floor is imposed. If dynamic loss scaling is not used, min_loss_scale is ignored.
+        max_loss_scale (float, default=2.**24): Sets a ceiling for the loss scale values that can be chosen by dynamic loss scaling. If dynamic loss scaling is not used, max_loss_scale is ignored.
+ """
+
+ def __init__(
+ self,
+ opt_level: Optional[str] = "O1",
+        cast_model_type: Optional[torch.dtype] = None,
+        patch_torch_functions: Optional[bool] = None,
+        keep_batchnorm_fp32: Optional[Union[bool, str]] = None,
+        master_weights: Optional[bool] = None,
+        loss_scale: Optional[Union[float, str]] = None,
+        cast_model_outputs: Any = None,
+        num_losses: Optional[int] = 1,
+        verbosity: int = 1,
+        min_loss_scale: Optional[float] = None,
+ max_loss_scale: float = 2.0**24,
+ ) -> None:
+ pass
diff --git a/colossalai/booster/mixed_precision/fp16_naive.py b/colossalai/booster/mixed_precision/fp16_naive.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5624a9d7477a407246e72d3adf6a66400e46e9e
--- /dev/null
+++ b/colossalai/booster/mixed_precision/fp16_naive.py
@@ -0,0 +1,28 @@
+from .mixed_precision_base import MixedPrecision
+
+
+class FP16NaiveMixedPrecision(MixedPrecision):
+ """
+ Precision for mixed precision training in FP16 using naive AMP.
+
+ Args:
+ log_num_zeros_in_grad(bool): return number of zeros in the gradients.
+ initial_scale(int): initial scale of gradient scaler.
+ growth_factor(int): the growth rate of loss scale.
+ backoff_factor(float): the decrease rate of loss scale.
+ hysteresis(int): delay shift in dynamic loss scaling.
+ max_scale(int): maximum loss scale allowed.
+ verbose(bool): if set to `True`, will print debug info.
+ """
+
+ def __init__(
+ self,
+ log_num_zeros_in_grad: bool,
+ initial_scale: int,
+ growth_factor: int,
+ backoff_factor: float,
+ hysteresis: int,
+ max_scale: int,
+        verbose: bool = False,
+ ) -> None:
+ pass
diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py
index 9999aa5e0eb475b8303b76382d9287b3ac876696..7dce6e6da33e8e4d9b0ac23452afb8fa24299884 100644
--- a/colossalai/booster/mixed_precision/fp16_torch.py
+++ b/colossalai/booster/mixed_precision/fp16_torch.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -9,7 +9,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
from .mixed_precision_base import MixedPrecision
-__all__ = ['FP16_Torch_MixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule']
+__all__ = ["FP16_Torch_MixedPrecision", "TorchAMPOptimizer", "TorchAMPModule"]
class TorchAMPOptimizer(OptimizerWrapper):
@@ -29,17 +29,21 @@ class TorchAMPOptimizer(OptimizerWrapper):
calls that may cause the scale to increase. Default: 2000.
"""
- def __init__(self,
- optim: Optimizer,
- init_scale: float = 2.**16,
- growth_factor: float = 2.0,
- backoff_factor: float = 0.5,
- growth_interval: int = 2000) -> None:
+ def __init__(
+ self,
+ optim: Optimizer,
+ init_scale: float = 2.0**16,
+ growth_factor: float = 2.0,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 2000,
+ ) -> None:
super().__init__(optim)
- self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval)
+ self.scaler = torch.cuda.amp.GradScaler(
+ init_scale=init_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ )
def backward(self, loss: Tensor, *args, **kwargs) -> None:
scaled_loss = self.scale_loss(loss)
@@ -60,12 +64,14 @@ class TorchAMPOptimizer(OptimizerWrapper):
self.unscale_grad()
super().clip_grad_by_value(clip_value, *args, **kwargs)
- def clip_grad_by_norm(self,
- max_norm: Union[float, int],
- norm_type: Union[float, int] = 2.0,
- error_if_nonfinite: bool = False,
- *args,
- **kwargs) -> None:
+ def clip_grad_by_norm(
+ self,
+ max_norm: Union[float, int],
+ norm_type: Union[float, int] = 2.0,
+ error_if_nonfinite: bool = False,
+ *args,
+ **kwargs,
+ ) -> None:
self.unscale_grad()
super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
@@ -102,23 +108,30 @@ class FP16TorchMixedPrecision(MixedPrecision):
calls that may cause the scale to increase. Default: 2000.
"""
- def __init__(self,
- init_scale: float = 2.**16,
- growth_factor: float = 2.0,
- backoff_factor: float = 0.5,
- growth_interval: int = 2000) -> None:
+ def __init__(
+ self,
+ init_scale: float = 2.0**16,
+ growth_factor: float = 2.0,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 2000,
+ ) -> None:
super().__init__()
- self.torch_amp_kwargs = dict(init_scale=init_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval)
-
- def configure(self,
- model: nn.Module,
- optimizer: Optimizer,
- criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+ self.torch_amp_kwargs = dict(
+ init_scale=init_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ )
+
+ def configure(
+ self,
+ model: nn.Module,
+ optimizer: Optional[Optimizer] = None,
+ criterion: Optional[Callable] = None,
+ ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
model = TorchAMPModule(model)
- optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
+ if optimizer is not None:
+ optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
if criterion is not None:
criterion = TorchAMPModule(criterion)
return model, optimizer, criterion
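
A minimal sketch of the newly optional optimizer in `FP16TorchMixedPrecision.configure`, assuming a CUDA device; the linear model is a placeholder:

```python
import torch.nn as nn
from torch.optim import SGD

from colossalai.booster.mixed_precision import FP16TorchMixedPrecision

model = nn.Linear(16, 4).cuda()
optimizer = SGD(model.parameters(), lr=1e-3)

mp = FP16TorchMixedPrecision(init_scale=2.0**16, growth_interval=2000)
model, optimizer, _ = mp.configure(model, optimizer)  # optimizer becomes a TorchAMPOptimizer

# Evaluation-only setups can now omit the optimizer entirely.
eval_model, _, _ = FP16TorchMixedPrecision().configure(nn.Linear(16, 4).cuda())
```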
diff --git a/colossalai/booster/mixed_precision/mixed_precision_base.py b/colossalai/booster/mixed_precision/mixed_precision_base.py
index 2490e9811ccf3ef71a1dcb90d30ddb794fc82d04..a86fdfc17eaf4df45de733e03c6384be6ccd2779 100644
--- a/colossalai/booster/mixed_precision/mixed_precision_base.py
+++ b/colossalai/booster/mixed_precision/mixed_precision_base.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Callable, Tuple
+from typing import Callable, Optional, Tuple
import torch.nn as nn
from torch.optim import Optimizer
@@ -13,9 +13,11 @@ class MixedPrecision(ABC):
"""
@abstractmethod
- def configure(self,
- model: nn.Module,
- optimizer: Optimizer,
- criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+ def configure(
+ self,
+ model: nn.Module,
+ optimizer: Optional[Optimizer] = None,
+ criterion: Optional[Callable] = None,
+ ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
# TODO: implement this method
pass
diff --git a/colossalai/booster/plugin/__init__.py b/colossalai/booster/plugin/__init__.py
index aa45bcb59ad7b9b25a06bea5fa1c0f776e6870ae..62f3708fc62972051c05d8ffdfbd0ff0f8ca410e 100644
--- a/colossalai/booster/plugin/__init__.py
+++ b/colossalai/booster/plugin/__init__.py
@@ -1,6 +1,15 @@
from .gemini_plugin import GeminiPlugin
+from .hybrid_parallel_plugin import HybridParallelPlugin
from .low_level_zero_plugin import LowLevelZeroPlugin
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin
-__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin']
+__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
+
+import torch
+from packaging import version
+
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
+ from .torch_fsdp_plugin import TorchFSDPPlugin
+
+ __all__.append("TorchFSDPPlugin")
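
Since `TorchFSDPPlugin` is exported only on PyTorch >= 1.12.0, downstream code can mirror the same version guard; an illustrative sketch (falling back to `TorchDDPPlugin` is just one option):

```python
import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse("1.12.0"):
    from colossalai.booster.plugin import TorchFSDPPlugin as SelectedPlugin
else:
    from colossalai.booster.plugin import TorchDDPPlugin as SelectedPlugin

plugin = SelectedPlugin()
```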
diff --git a/colossalai/booster/plugin/dp_plugin_base.py b/colossalai/booster/plugin/dp_plugin_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2dd00453e32592b980b0a34c967829fc94e2cde
--- /dev/null
+++ b/colossalai/booster/plugin/dp_plugin_base.py
@@ -0,0 +1,66 @@
+import random
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+from .plugin_base import Plugin
+
+
+class DPPluginBase(Plugin):
+ """This is a base class for all DP plugins. It sets up world size and rank, and provides data loader creation."""
+
+ def __init__(self) -> None:
+ super().__init__()
+ assert (
+ dist.is_initialized()
+ ), "torch.distributed is not initialized, please use colossalai.launch to create the distributed environment"
+ self.rank = dist.get_rank()
+ self.world_size = dist.get_world_size()
+
+ def prepare_dataloader(
+ self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+ ):
+ r"""
+        Prepare a dataloader for distributed training. The dataset will be wrapped by
+        `torch.utils.data.DataLoader` with a `torch.utils.data.distributed.DistributedSampler`.
+
+        Args:
+            dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
+            batch_size (int): The number of samples per batch.
+            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
+            seed (int, optional): Random worker seed for sampling. Defaults to 1024.
+            drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
+                is not divisible by the batch size. If False and the size of the dataset is not divisible by
+                the batch size, then the last batch will be smaller. Defaults to False.
+            pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
+            num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
+            kwargs (dict): Optional parameters for ``torch.utils.data.DataLoader``; see the
+                ``torch.utils.data.DataLoader`` documentation for more details.
+
+ Returns:
+ :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
+ """
+ _kwargs = kwargs.copy()
+ sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
+
+ # Deterministic dataloader
+ def seed_worker(worker_id):
+ worker_seed = seed
+ np.random.seed(worker_seed)
+ torch.manual_seed(worker_seed)
+ random.seed(worker_seed)
+
+ return DataLoader(
+ dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ worker_init_fn=seed_worker,
+ drop_last=drop_last,
+ pin_memory=pin_memory,
+ num_workers=num_workers,
+ **_kwargs,
+ )
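
A sketch of `prepare_dataloader` in use, assuming the distributed environment is already launched; `GeminiPlugin` (refactored below to extend `DPPluginBase`) inherits this method, and the dataset here is synthetic:

```python
import torch
from torch.utils.data import TensorDataset

from colossalai.booster.plugin import GeminiPlugin

dataset = TensorDataset(torch.randn(128, 16), torch.randint(0, 4, (128,)))
plugin = GeminiPlugin()

# Each rank reads a distinct shard through DistributedSampler; the fixed seed
# makes worker-side shuffling and augmentation reproducible across runs.
dataloader = plugin.prepare_dataloader(dataset, batch_size=8, shuffle=True, seed=1024)
```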
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index deda00d8a7b3b6849a1cf5b1db0b796a1a1c7d89..ca722a0768dc0b6dd5b3f5085980a7a2c4b6a57b 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -1,143 +1,278 @@
-import random
-import warnings
-from typing import Callable, List, Optional, Tuple, Union
+import gc
+import logging
+import os
+from pathlib import Path
+from typing import Callable, Iterator, List, Optional, Tuple
-import numpy as np
import torch
-import torch.distributed as dist
import torch.nn as nn
-from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
-from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
-from colossalai.checkpoint_io.utils import save_state_dict
+from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
+from colossalai.checkpoint_io.utils import (
+ get_model_base_filenames,
+ get_optimizer_base_filenames,
+ load_shard_state_dict,
+ save_config_file,
+ save_state_dict,
+ save_state_dict_shards,
+)
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
-from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats
-from .plugin_base import Plugin
+from .dp_plugin_base import DPPluginBase
-__all__ = ['GeminiPlugin']
+__all__ = ["GeminiPlugin"]
+SUPPORTED_PRECISION = ["fp16", "bf16"]
+PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}
-class GeminiCheckpointIO(GeneralCheckpointIO):
+class GeminiCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
+ def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+ """
+        Save unsharded model to checkpoint but only on master process.
+        The model should be unwrapped in self.save_model via ModelWrapper.unwrap.
+        As there is communication when getting the state dict, model.state_dict() must be called on all processes.
+ """
+ assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
+ state_dict = model.state_dict(only_rank_0=True)
+ if self.coordinator.is_master():
+ save_state_dict(state_dict, checkpoint, use_safetensors)
+
def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
"""
Load model from checkpoint with automatic unwrapping.
+ The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
"""
- # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
- return super().load_unsharded_model(model, checkpoint, strict=strict)
+ assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
+ super().load_unsharded_model(model, checkpoint, strict=strict)
- def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+ def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
"""
- Save model to checkpoint but only on master process.
+ Save unsharded optimizer state dict to checkpoint.
+        After calling optimizer.state_dict(), the complete optimizer states will be collected on the master rank.
+        As there is communication when getting the state dict, optimizer.state_dict() must be called on all processes.
+        The saving process will only be executed by the master rank.
"""
- # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
- # as there is communication when get state dict, this must be called on all processes
- state_dict = model.state_dict(only_rank_0=True)
+ assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
+ state_dict = optimizer.state_dict()
if self.coordinator.is_master():
- save_state_dict(state_dict, checkpoint, use_safetensors)
+ save_state_dict(state_dict, checkpoint, use_safetensors=False)
- def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+ def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
"""
- Save optimizer to checkpoint but only on master process.
+        Load unsharded optimizer from checkpoint file.
+        Each process only loads the optimizer states of the parameters it controls.
"""
- # TODO(ver217): optimizer state dict is sharded
- super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
+ assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
+ super().load_unsharded_optimizer(optimizer, checkpoint)
- def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+ def save_sharded_model(
+ self,
+ model: GeminiDDP,
+ checkpoint_path: str,
+ gather_dtensor: bool = False,
+ prefix: Optional[str] = None,
+ max_shard_size: int = 1024,
+ use_safetensors: bool = False,
+ ):
"""
- Save model to checkpoint but only on master process.
+ Save sharded model.
+ As there is communication when getting state dict, model.state_dict() must be called on all processes.
"""
- if self.coordinator.is_master():
- super().save_lr_scheduler(lr_scheduler, checkpoint)
+ assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
+ if os.path.isfile(checkpoint_path):
+ logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+ return
+
+ Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
+
+ state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
+ weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
+ index_file = CheckpointIndexFile(checkpoint_path)
+
+        # Save shards of model weights.
+ is_master = self.coordinator.is_master()
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint_path,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=is_master,
+ use_safetensors=use_safetensors,
+ )
+ # only save the index file on the master rank
+ if self.coordinator.is_master():
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ save_config_file(model.unwrap(), checkpoint_path)
+ logging.info(
+ f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
+
+ def load_sharded_model(
+ self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
+ ):
+ """
+        Load sharded model from multiple checkpoint files.
+ """
+ assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
+ return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
-class GeminiModel(ModelWrapper):
+ def save_sharded_optimizer(
+ self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+ ):
+ """
+ Save sharded optimizer state dict to checkpoint folder.
+ As there is communication when getting state dict, this must be called on all processes.
+ """
+ assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
+
+ if os.path.isfile(checkpoint):
+ logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+ return
+
+ Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+ # Preparing file paths and index file.
+ states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+ index_file = CheckpointIndexFile(checkpoint)
+
+ # Store the information of param groups to param_group_file.
+ index_file.append_meta_data("param_groups", param_group_file)
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ param_groups = optimizer.get_param_groups_for_saving()
+ torch.save(param_groups, group_file_path)
+
+ # States are broken into shards within max_shard_size.
+ state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)
+
+ # Save shards of optimizer states.
+ is_master = self.coordinator.is_master()
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=is_master,
+ use_safetensors=False,
+ )
- def __init__(self, module: nn.Module, gemini_config: dict, verbose: bool = False) -> None:
- super().__init__(module)
- self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config, verbose=verbose)
+ # Wrap up index file. Only save it on master rank.
+ if self.coordinator.is_master():
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ logging.info(
+ f"The optimizer is going to be split to checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
+
+ def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
+ """
+        Load sharded optimizer from checkpoint folder, with the index file given.
+        Each process only loads the optimizer states of the parameters it controls.
+ """
+ assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
+ if not os.path.isfile(checkpoint_index_file):
+ logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
- def unwrap(self):
- # as save/load state dict is coupled with the GeminiDDP, we only return GeminiDDP model
- return self.module
+ assert isinstance(optimizer, GeminiOptimizer)
+ # Read checkpoint index file.
+ ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
-class GeminiOptimizer(OptimizerWrapper):
+ # Load param_groups.
+ param_group_path = ckpt_index_file.get_param_group_filename()
+ if param_group_path is None:
+ raise RuntimeError(
+ f"Invalid index file path {checkpoint_index_file} for an optimizer. \
+ Lacking param group file under current directory."
+ )
+ saved_param_groups = torch.load(param_group_path)
+ optimizer.load_param_groups(saved_param_groups)
- def __init__(self,
- module: GeminiDDP,
- optimizer: Optimizer,
- zero_optim_config: dict,
- optim_kwargs: dict,
- verbose: bool = False) -> None:
- optimizer = zero_optim_wrapper(module,
- optimizer,
- optim_config=zero_optim_config,
- **optim_kwargs,
- verbose=verbose)
- super().__init__(optimizer)
+ checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
- def backward(self, loss: Tensor, *args, **kwargs):
- self.optim.backward(loss)
+ # Load optimizer states from shard files under checkpoint path.
+ # For each file, only load the states managed by current process.
+ for shard_file in checkpoint_files:
+ state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+ optimizer.load_param_states(state_dict_shard)
+ del state_dict_shard
+ gc.collect()
- def clip_grad_by_norm(self,
- max_norm: Union[float, int],
- norm_type: Union[float, int] = 2,
- error_if_nonfinite: bool = False,
- *args,
- **kwargs) -> Tensor:
- warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm')
+ optimizer.optimizer_loading_epilogue()
- def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
- raise NotImplementedError('Gemini does not support clip_grad_by_value')
+ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+ """
+        Save lr scheduler to checkpoint but only on master process.
+ """
+ if self.coordinator.is_master():
+ super().save_lr_scheduler(lr_scheduler, checkpoint)
-class GeminiPlugin(Plugin):
+class GeminiPlugin(DPPluginBase):
"""
Plugin for Gemini.
- Example:
- >>> from colossalai.booster import Booster
- >>> from colossalai.booster.plugin import GeminiPlugin
- >>>
- >>> model, train_dataset, optimizer, criterion = ...
- >>> plugin = GeminiPlugin()
+ ```python
+ from colossalai.booster import Booster
+ from colossalai.booster.plugin import GeminiPlugin
+
+ model, train_dataset, optimizer, criterion = ...
+ plugin = GeminiPlugin()
- >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
- >>> booster = Booster(plugin=plugin)
- >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+ train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+ booster = Booster(plugin=plugin)
+ model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+ ```
Args:
- device (torch.device): device to place the model.
- placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
+ chunk_config_dict (dict, optional): chunk configuration dictionary.
+ chunk_init_device (torch.device, optional): device to initialize the chunk.
+        placement_policy (str, optional): Either "static" or "auto". Defaults to "static".
+        shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement.
+            If `shard_param_frac` is 1.0, it's equal to ZeRO-3. If `shard_param_frac` is 0.0, it's equal to ZeRO-2. Defaults to 1.0.
+        offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement.
+            If `shard_param_frac` is 1.0 and `offload_optim_frac` is 0.0, it's equal to the old "cuda" placement. Defaults to 0.0.
+        offload_param_frac (float, optional): fraction of parameters to be offloaded. Only for "static" placement.
+            For efficiency, this argument is useful only when `shard_param_frac` is 1.0 and `offload_optim_frac` is 1.0.
+            If `shard_param_frac` is 1.0, `offload_optim_frac` is 1.0 and `offload_param_frac` is 1.0, it's equal to the old "cpu" placement.
+            When using static placement, we recommend tuning `shard_param_frac` first and then `offload_optim_frac`.
+            Defaults to 0.0.
+        warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
+        steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9.
+        precision (str, optional): The precision used in training. Supports 'fp16' and 'bf16'. Defaults to 'fp16'.
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
- search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
+ search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32.
hidden_dim (int, optional): the hidden dimension of DNN.
Users can provide this argument to speed up searching.
If users do not know this argument before training, it is ok. We will use a default value 1024.
- min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
- If the aggregate size of parameters is still samller than the minimum chunk size,
+ min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
+ If the aggregate size of parameters is still smaller than the minimum chunk size,
all parameters will be compacted into one small chunk.
memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
which will be used when using hybrid CPU optimizer.
This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
Defaults to 0.0.
- initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
+ initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16.
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
@@ -152,17 +287,24 @@ class GeminiPlugin(Plugin):
def __init__(
self,
- device: Optional[torch.device] = None,
- placement_policy: str = "cpu",
+ chunk_config_dict: Optional[dict] = None,
+ chunk_init_device: Optional[torch.device] = None,
+ placement_policy: str = "static",
+ shard_param_frac: float = 1.0, # only for static placement
+ offload_optim_frac: float = 0.0, # only for static placement
+ offload_param_frac: float = 0.0, # only for static placement
+ warmup_non_model_data_ratio: float = 0.8, # only for auto placement
+ steady_cuda_cap_ratio: float = 0.9, # only for auto placement
+ precision: str = "fp16",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
- search_range_mb: int = 32,
+ search_range_m: int = 32,
hidden_dim: Optional[int] = None,
- min_chunk_size_mb: float = 32,
+ min_chunk_size_m: float = 32,
memstats: Optional[MemStats] = None,
gpu_margin_mem_ratio: float = 0.0,
- initial_scale: float = 2**32,
+ initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
@@ -173,32 +315,40 @@ class GeminiPlugin(Plugin):
norm_type: float = 2.0,
verbose: bool = False,
) -> None:
-
- assert dist.is_initialized(
- ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
- self.rank = dist.get_rank()
- self.world_size = dist.get_world_size()
+ super().__init__()
+ assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
self.gemini_config = dict(
- device=(device or get_current_device()),
+ chunk_config_dict=chunk_config_dict,
+ chunk_init_device=(chunk_init_device or get_current_device()),
placement_policy=placement_policy,
+ shard_param_frac=shard_param_frac,
+ offload_optim_frac=offload_optim_frac,
+ offload_param_frac=offload_param_frac,
+ warmup_non_model_data_ratio=warmup_non_model_data_ratio,
+ steady_cuda_cap_ratio=steady_cuda_cap_ratio,
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=strict_ddp_mode,
- search_range_mb=search_range_mb,
+ search_range_m=search_range_m,
hidden_dim=hidden_dim,
- min_chunk_size_mb=min_chunk_size_mb,
+ min_chunk_size_m=min_chunk_size_m,
memstats=memstats,
+ mixed_precision=PRECISION_STR_TO_DTYPE[precision],
+ )
+ self.zero_optim_config = dict(
+ gpu_margin_mem_ratio=gpu_margin_mem_ratio,
+ )
+ self.optim_kwargs = dict(
+ initial_scale=initial_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ min_scale=min_scale,
+ max_scale=max_scale,
+ max_norm=max_norm,
+ norm_type=norm_type,
)
- self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,)
- self.optim_kwargs = dict(initial_scale=initial_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval,
- hysteresis=hysteresis,
- min_scale=min_scale,
- max_scale=max_scale,
- max_norm=max_norm,
- norm_type=norm_type)
self.verbose = verbose
def support_no_sync(self) -> bool:
@@ -208,74 +358,22 @@ class GeminiPlugin(Plugin):
return True
def supported_precisions(self) -> List[str]:
- return ['fp16']
+ return SUPPORTED_PRECISION
def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
- return ['cuda']
-
- def prepare_train_dataloader(self,
- dataset,
- batch_size,
- shuffle=False,
- seed=1024,
- drop_last=False,
- pin_memory=False,
- num_workers=0,
- **kwargs):
- r"""
- Prepare a dataloader for distributed training. The dataloader will be wrapped by
- `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
-
- Note:
- 1. Evaluation datasets should not be passed to this function.
-
- Args:
- dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
- shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
- seed (int, optional): Random worker seed for sampling, defaults to 1024.
- add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
- drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
- is not divisible by the batch size. If False and the size of dataset is not divisible by
- the batch size, then the last batch will be smaller, defaults to False.
- pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
- num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
- kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
- `DataLoader `_.
-
- Returns:
- :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
- """
- _kwargs = kwargs.copy()
- sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
-
- # Deterministic dataloader
- def seed_worker(worker_id):
- worker_seed = seed
- np.random.seed(worker_seed)
- torch.manual_seed(worker_seed)
- random.seed(worker_seed)
-
- return DataLoader(dataset,
- batch_size=batch_size,
- sampler=sampler,
- worker_init_fn=seed_worker,
- drop_last=drop_last,
- pin_memory=pin_memory,
- num_workers=num_workers,
- **_kwargs)
+ return ["cuda"]
def configure(
self,
model: nn.Module,
- optimizer: Optimizer,
- criterion: Callable = None,
- dataloader: DataLoader = None,
- lr_scheduler: LRScheduler = None,
- ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
-
+ optimizer: Optional[Optimizer] = None,
+ criterion: Optional[Callable] = None,
+ dataloader: Optional[DataLoader] = None,
+ lr_scheduler: Optional[LRScheduler] = None,
+ ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if not isinstance(model, ModelWrapper):
# convert model to sync bn
# FIXME(ver217): gemini does not support sync bn
@@ -287,11 +385,12 @@ class GeminiPlugin(Plugin):
# model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
# wrap the model with Gemini
- model = GeminiModel(model, self.gemini_config, self.verbose)
+ model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)
- if not isinstance(optimizer, OptimizerWrapper):
- optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
- self.verbose)
+ if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
+ optimizer = GeminiOptimizer(
+ optimizer, model, **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
+ )
return model, optimizer, criterion, dataloader, lr_scheduler
@@ -300,3 +399,6 @@ class GeminiPlugin(Plugin):
def get_checkpoint_io(self) -> CheckpointIO:
return GeminiCheckpointIO()
+
+ def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
+ raise NotImplementedError
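
A hedged sketch of the sharded checkpoint round trip that `GeminiCheckpointIO` enables; the model is a placeholder and the paths are illustrative:

```python
from torch.optim import Adam

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

model = ...  # any nn.Module, built before boosting

plugin = GeminiPlugin(placement_policy="static", shard_param_frac=1.0, precision="bf16")
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, Adam(model.parameters()))

# Sharded save: each target is a directory holding shards plus an index file.
booster.save_model(model, "model_ckpt", shard=True, size_per_shard=1024)
booster.save_optimizer(optimizer, "optim_ckpt", shard=True)

# Loading goes through the boosted objects; a directory path means a sharded checkpoint.
booster.load_model(model, "model_ckpt")
booster.load_optimizer(optimizer, "optim_ckpt")
```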
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
new file mode 100644
index 0000000000000000000000000000000000000000..479ccc3eb36e4ae9bb9df430313e912258717cd4
--- /dev/null
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -0,0 +1,579 @@
+import random
+from contextlib import nullcontext
+from functools import partial
+from types import MethodType
+from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+from torch.nn import Module, SyncBatchNorm
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
+from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.zero.low_level import LowLevelZeroOptimizer
+
+from .pp_plugin_base import PipelinePluginBase
+
+DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
+
+
+def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
+ if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
+ return x.to(dtype)
+ return x
+
+
+class HybridParallelModule(ModelWrapper):
+ def __init__(
+ self,
+ module: Module,
+ precision: str,
+ shard_config: ShardConfig,
+ dp_group: ProcessGroup,
+ use_ddp: bool,
+ ddp_config: dict,
+ custom_policy: Policy,
+ ) -> None:
+ self.stage_manager = shard_config.pipeline_stage_manager
+ self.dp_group = dp_group
+
+ shardformer = ShardFormer(shard_config)
+ if custom_policy is not None:
+ assert isinstance(custom_policy, object)
+ module, self.shared_params = shardformer.optimize(module, policy=custom_policy)
+
+ # setting process groups for shared parameters
+ self.shared_param_process_groups = []
+ for shared_param in self.shared_params:
+ if len(shared_param) > 0:
+ self.shared_param_process_groups.append(
+ self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))
+ )
+
+ # setting mixed_precision
+ self.mixed_precision = None
+ if precision == "fp16":
+ self.mixed_precision = torch.float16
+ elif precision == "bf16":
+ self.mixed_precision = torch.bfloat16
+ if self.mixed_precision is not None:
+ module = module.to(self.mixed_precision)
+ module = module.cuda()
+
+ # setting input type cast when using mixed precision
+ self.convert_fn = None
+ if self.mixed_precision is not None:
+ self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision)
+
+ # setting ddp configs
+ if use_ddp:
+ # convert model to sync bn
+ module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group)
+ # wrap the model with PyTorch DDP
+ module = DDP(module, process_group=dp_group, **ddp_config)
+
+ super().__init__(module)
+
+ def sync_shared_params(self):
+ for shared_param, group in zip(self.shared_params, self.shared_param_process_groups):
+ if self.stage_manager.stage in shared_param:
+ param = shared_param[self.stage_manager.stage]
+ dist.all_reduce(param.grad, group=group)
+ dist.barrier()
+
+ def no_sync(self) -> Iterator[None]:
+ # no sync grads across data parallel
+ return nullcontext()
+
+ def sync_grads(self):
+ # sync grad across data parallel
+ if self.dp_group.size() == 1:
+ return
+ for p in self.module.parameters():
+ if p.grad is not None:
+ dist.all_reduce(p.grad, group=self.dp_group)
+ p.grad.div_(self.dp_group.size())
+
+ def forward(self, *args, **kwargs):
+ if self.convert_fn is not None:
+ args = tree_map(self.convert_fn, args)
+ kwargs = tree_map(self.convert_fn, kwargs)
+ return super().forward(*args, **kwargs)
+
+ def unwrap(self):
+ module = super().unwrap()
+ if isinstance(module, DDP):
+ module = module.module
+ return module
+
+
+def get_param_info(optim: Optimizer):
+ # Get a backup of necessary information of parameters for future use, which includes:
+ # 1. A complete param_group, with params in the form of param_id
+ # 2. A mapping from param address (obtained using id(param)) to integer param_id
+ # 3. A mapping from integer param_id to param address.
+ # 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding.
+ # When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer.
+
+ if optim is None:
+ return {}
+ param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
+ start_index = 0
+ for group in optim.param_groups:
+ packed_group = {k: v for k, v in group.items() if k != "params"}
+ packed_group["params"] = []
+
+ for param_id, param in enumerate(group["params"], start_index):
+ original_shape = param.shape if isinstance(param, torch.Tensor) else None
+ packed_group["params"].append(param_id)
+ param_info["param2id"][id(param)] = param_id
+ param_info["id2param"][param_id] = id(param)
+ param_info["param2shape"][id(param)] = original_shape
+
+ param_info["param_groups"].append(packed_group)
+ start_index += len(group["params"])
+
+ return param_info
+
+
+def init_pipeline_optimizer(optim: Optimizer, model: Module):
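+    # Keep only the parameters owned by this pipeline stage in each param group,
+    # so the stage-local optimizer never steps parameters held by other stages.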
+ model_params = set(model.parameters())
+ new_param_groups = []
+ for group in optim.param_groups:
+ params = [p for p in group["params"] if p in model_params]
+ new_param_groups.append({**group, "params": params})
+ optim.__setstate__({"param_groups": new_param_groups})
+
+
+class HybridParallelNaiveOptimizer(OptimizerWrapper):
+ def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
+ self.param_info = param_info
+ if use_pipeline:
+ init_pipeline_optimizer(optim, model)
+ super().__init__(optim)
+
+ def update_master_params(self, model: Module):
+ pass
+
+ def get_working_to_master_map(self):
+ return None
+
+ def get_master_to_working_map(self):
+ return None
+
+
+class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
+ def __init__(
+ self,
+ optim: Optimizer,
+ model: Module,
+ use_pipeline: bool,
+ param_info: OrderedDict,
+ precision: str = "fp16",
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ max_norm: float = 0,
+ ):
+ self.param_info = param_info
+ if use_pipeline:
+ init_pipeline_optimizer(optim, model)
+ super().__init__(
+ optim,
+ precision,
+ initial_scale,
+ min_scale,
+ growth_factor,
+ backoff_factor,
+ growth_interval,
+ hysteresis,
+ max_scale,
+ max_norm,
+ )
+
+
+class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ model: Module,
+ use_pipeline: bool,
+ param_info: OrderedDict,
+ initial_scale: int = 2**16, # grad scaler config
+ min_scale: int = 1,
+ growth_factor: float = 2.0,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 2000,
+ hysteresis: int = 2,
+ max_scale: int = 2**24,
+ clip_grad_norm: float = 0.0, # grad clipping
+ verbose: bool = False,
+ reduce_bucket_size: int = 1024 * 1024, # communication
+ communication_dtype: Optional[torch.dtype] = None,
+ overlap_communication: bool = True,
+ partition_grad: bool = False, # stage 2 flag
+ cpu_offload: bool = False, # cpu offload
+ dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
+ tp_process_group: Optional[ProcessGroup] = None, # if using tp
+ forced_dtype: Optional[torch.dtype] = None,
+ ):
+ self.param_info = param_info
+ if use_pipeline:
+ init_pipeline_optimizer(optimizer, model)
+ super().__init__(
+ optimizer,
+ initial_scale,
+ min_scale,
+ growth_factor,
+ backoff_factor,
+ growth_interval,
+ hysteresis,
+ max_scale,
+ clip_grad_norm,
+ verbose,
+ reduce_bucket_size,
+ communication_dtype,
+ overlap_communication,
+ partition_grad,
+ cpu_offload,
+ dp_process_group,
+ tp_process_group,
+ forced_dtype,
+ )
+
+
+class HybridParallelPlugin(PipelinePluginBase):
+ """
+ Plugin for Hybrid Parallel Training.
+    Tensor parallelism, pipeline parallelism, and data parallelism (DDP/ZeRO) can be picked and combined in this plugin.
+    The sizes of tp and pp should be passed in by the user; the size of dp is then calculated automatically as dp_size = world_size / (tp_size * pp_size).
+
+ ```python
+ from colossalai.booster import Booster
+ from colossalai.booster.plugin import HybridParallelPlugin
+
+ model, train_dataset, optimizer, criterion = ...
+ plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
+
+ train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+ booster = Booster(plugin=plugin)
+ model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
+ ```
+
+ Args:
+ tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
+ pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
+ precision (str, optional): Specifies the precision of parameters during training.
+            Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16', otherwise the model is trained with 'fp32'.
+ Defaults to 'fp16'.
+        zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be chosen from [0, 1, 2].
+ When set to 0, ZeRO will not be used. Defaults to 0.
+ enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
+ Currently all the optimization methods include fused normalization, flash attention and JIT.
+ Defaults to False.
+ enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
+ enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
+        enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False.
+ enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
+ enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
+ num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
+ microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
+ Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
+ If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
+ initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
+ min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
+ growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
+ backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
+ growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
+ hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
+ max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
+ max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
+ broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
+ ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
+ find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
+ check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
+ gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
+ static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
+ zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
+ cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
+ communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
+ overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
+ custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ tp_size: int,
+ pp_size: int,
+ precision: str = "fp16",
+ zero_stage: int = 0,
+ enable_all_optimization: bool = False,
+ enable_fused_normalization: bool = False,
+ enable_flash_attention: bool = False,
+ enable_jit_fused: bool = False,
+ enable_sequence_parallelism: bool = False,
+ enable_sequence_overlap: bool = False,
+ num_microbatches: Optional[int] = None,
+ microbatch_size: Optional[int] = None,
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ max_norm: float = 0,
+ broadcast_buffers: bool = True,
+ ddp_bucket_cap_mb: int = 25,
+ find_unused_parameters: bool = False,
+ check_reduction: bool = False,
+ gradient_as_bucket_view: bool = False,
+ static_graph: bool = False,
+ zero_bucket_size_in_m: int = 12,
+ cpu_offload: bool = False,
+ communication_dtype: Optional[torch.dtype] = None,
+ overlap_communication: bool = True,
+        custom_policy: Optional[Policy] = None,
+ ) -> None:
+ super().__init__()
+ assert (
+ dist.get_world_size() % (tp_size * pp_size) == 0
+ ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+
+ if enable_sequence_parallelism:
+            assert tp_size > 1, "Tensor parallelism must be enabled (tp_size > 1) when using sequence parallelism"
+
+ self.tp_size = tp_size
+ self.pp_size = pp_size
+ self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+ self.precision = precision
+ self.zero_stage = zero_stage
+ self.cpu_offload = cpu_offload
+ self.enable_all_optimization = enable_all_optimization
+ self.enable_fused_normalization = enable_fused_normalization
+ self.enable_flash_attention = enable_flash_attention
+ self.enable_jit_fused = enable_jit_fused
+ self.enable_sequence_parallelism = enable_sequence_parallelism
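+ # 3D process mesh laid out as (data parallel, pipeline parallel, tensor parallel),
+ # matching the DP_AXIS/PP_AXIS/TP_AXIS constants used below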
+ self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
+ self.stage_manager = None
+ self.schedule = None
+ self.custom_policy = custom_policy
+ assert zero_stage in (0, 1, 2)
+ if self.pp_size > 1:
+ assert (
+ num_microbatches is not None or microbatch_size is not None
+ ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
+ assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
+ self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
+ self.schedule = OneForwardOneBackwardSchedule(
+ self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
+ )
+ self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+ self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
+ self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
+ self.shard_config = ShardConfig(
+ tensor_parallel_process_group=self.tp_group,
+ pipeline_stage_manager=self.stage_manager,
+ enable_tensor_parallelism=self.tp_size > 1,
+ enable_all_optimization=self.enable_all_optimization,
+ enable_fused_normalization=self.enable_fused_normalization,
+ enable_flash_attention=self.enable_flash_attention,
+ enable_jit_fused=self.enable_jit_fused,
+ enable_sequence_parallelism=enable_sequence_parallelism,
+ enable_sequence_overlap=enable_sequence_overlap,
+ )
+ self.amp_config = dict(
+ initial_scale=initial_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ min_scale=min_scale,
+ max_scale=max_scale,
+ )
+
+ self.ddp_config = dict(
+ broadcast_buffers=broadcast_buffers,
+ bucket_cap_mb=ddp_bucket_cap_mb,
+ find_unused_parameters=find_unused_parameters,
+ check_reduction=check_reduction,
+ gradient_as_bucket_view=gradient_as_bucket_view,
+ static_graph=static_graph,
+ )
+
+ self.zero_config = dict(
+ reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
+ communication_dtype=communication_dtype,
+ overlap_communication=overlap_communication,
+ cpu_offload=cpu_offload,
+ partition_grad=(self.zero_stage == 2),
+ )
+
+ self.max_norm = max_norm
+
+ @property
+ def enable_pipeline_parallelism(self) -> bool:
+ return self.pp_size > 1
+
+ def supported_devices(self) -> List[str]:
+ return ["cuda"]
+
+ def supported_precisions(self) -> List[str]:
+ return ["fp16", "bf16", "fp32"]
+
+ def control_device(self) -> bool:
+ return True
+
+ def control_precision(self) -> bool:
+ return True
+
+ def support_no_sync(self) -> bool:
+ return False
+
+ def control_checkpoint_io(self) -> bool:
+ return True
+
+ def configure(
+ self,
+ model: Module,
+ optimizer: Optional[Optimizer] = None,
+ criterion: Optional[Callable] = None,
+ dataloader: Optional[DataLoader] = None,
+ lr_scheduler: Optional[LRScheduler] = None,
+ ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
+ param_info = get_param_info(optimizer)
+ if not isinstance(model, ModelWrapper):
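+ # plain DDP is only needed for pure data parallelism,
+ # i.e. no pipeline stages and no ZeRO sharding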
+ use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
+ model = HybridParallelModule(
+ model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
+ )
+ if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
+ if self.zero_stage == 0:
+ if self.precision in ["fp16", "bf16"]:
+ optimizer = HybridParallelAMPOptimizer(
+ optimizer,
+ model,
+ use_pipeline=self.enable_pipeline_parallelism,
+ param_info=param_info,
+ precision=self.precision,
+ max_norm=self.max_norm,
+ **self.amp_config,
+ )
+ else:
+ optimizer = HybridParallelNaiveOptimizer(
+ optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
+ )
+ else:
+ assert self.dp_size > 1, "ZeRO requires data parallel size to be greater than 1."
+ assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
+ optimizer = HybridParallelZeroOptimizer(
+ optimizer,
+ model,
+ use_pipeline=self.enable_pipeline_parallelism,
+ param_info=param_info,
+ dp_process_group=self.dp_group,
+ tp_process_group=self.tp_group,
+ verbose=True,
+ clip_grad_norm=self.max_norm,
+ **self.zero_config,
+ **self.amp_config,
+ )
+ # inject update_master_params
+ model.update_master_params = MethodType(optimizer.update_master_params, model)
+ return model, optimizer, criterion, dataloader, lr_scheduler
+
+ def execute_pipeline(
+ self,
+ data_iter: Iterator,
+ model: HybridParallelModule,
+ criterion: Callable[[Any, Any], torch.Tensor],
+ optimizer: Optional[
+ Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, HybridParallelZeroOptimizer]
+ ] = None,
+ return_loss: bool = True,
+ return_outputs: bool = False,
+ ) -> dict:
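+ """Run the pipelined forward-backward schedule over one batch of microbatches drawn from ``data_iter``.
+
+ A minimal usage sketch (assuming ``model``, ``optimizer`` and ``criterion``
+ have been boosted with this plugin):
+
+ outputs = plugin.execute_pipeline(iter(dataloader), model, criterion, optimizer)
+ optimizer.step()
+ optimizer.zero_grad()
+ """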
+ assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
+ # return loss or outputs if needed
+ ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
+ with ctx:
+ outputs = self.schedule.forward_backward_step(
+ model, data_iter, criterion, optimizer, return_loss, return_outputs
+ )
+ model.sync_shared_params()
+ if isinstance(optimizer, HybridParallelZeroOptimizer):
+ optimizer.sync_grad()
+ else:
+ model.sync_grads()
+ return outputs
+
+ def prepare_dataloader(
+ self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+ ):
+ r"""
+ Prepare a dataloader for distributed training. The dataloader will be wrapped by
+ `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
+
+ Args:
+ dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
+ batch_size (int): The batch size for each data-parallel rank.
+ shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
+ seed (int, optional): Random worker seed for sampling, defaults to 1024.
+ drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
+ is not divisible by the batch size. If False and the size of dataset is not divisible by
+ the batch size, then the last batch will be smaller, defaults to False.
+ pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
+ num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
+ kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details can be found in
+ `DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_.
+
+ Returns:
+ :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
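+
+ Example (a minimal sketch, assuming ``plugin`` and ``train_dataset`` already exist):
+ >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8, shuffle=True)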
+ """
+ _kwargs = kwargs.copy()
+ sampler = DistributedSampler(
+ dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
+ )
+
+ # Deterministic dataloader
+ def seed_worker(worker_id):
+ worker_seed = seed
+ np.random.seed(worker_seed)
+ torch.manual_seed(worker_seed)
+ random.seed(worker_seed)
+
+ return DataLoader(
+ dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ worker_init_fn=seed_worker,
+ drop_last=drop_last,
+ pin_memory=pin_memory,
+ num_workers=num_workers,
+ **_kwargs,
+ )
+
+ def get_checkpoint_io(self) -> CheckpointIO:
+ return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+
+ def no_sync(self, model: Module) -> Iterator[None]:
+ raise NotImplementedError
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 969c430bd317600091e69a68276a5f14006d5420..dffa4ce164efe2ca0abb6106e76ea416dbf9ce3d 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -1,111 +1,233 @@
-import random
-import warnings
-from typing import Callable, List, Optional, Tuple, Union
+import logging
+import os
+from functools import partial
+from pathlib import Path
+from types import MethodType
+from typing import Callable, Iterator, List, Optional, Tuple
-import numpy as np
import torch
-import torch.distributed as dist
import torch.nn as nn
-from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
-from colossalai.checkpoint_io import CheckpointIO
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
+from colossalai.checkpoint_io.utils import (
+ get_optimizer_base_filenames,
+ get_shard_filename,
+ load_param_groups_into_optimizer,
+ load_shard_state_dict,
+ load_states_into_optimizer,
+ save_param_groups,
+ save_state_dict,
+ sharded_optimizer_loading_epilogue,
+)
+from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
-from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import LowLevelZeroOptimizer
-from .plugin_base import Plugin
+from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO
-__all__ = ['LowLevelZeroPlugin']
+__all__ = ["LowLevelZeroPlugin"]
-def _convert_to_fp16(x):
+def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
- return x.half()
+ return x.to(dtype)
return x
-class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
-
- def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
- """
- Save optimizer to checkpoint but only on master process.
- """
- # TODO(ver217): optimizer state dict is sharded
- super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
+SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]
-class LowLevelZeroModel(ModelWrapper):
-
- def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
+class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
+ def __init__(self, module: nn.Module, precision: str) -> None:
super().__init__(module)
- self.convert_inputs = (precision == 'fp16')
- module = zero_model_wrapper(module, zero_stage=stage)
- if precision == 'fp16':
- module = module.half()
+ self.dtype = None
+ if precision == "fp16":
+ self.dtype = torch.float16
+ elif precision == "bf16":
+ self.dtype = torch.bfloat16
+ if self.dtype is not None:
+ module = module.to(self.dtype)
module = module.to(get_current_device())
self.module = module
+ self.convert_fn = None
+ if self.dtype is not None:
+ self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
def forward(self, *args, **kwargs):
- if self.convert_inputs:
- args = tree_map(_convert_to_fp16, args)
- kwargs = tree_map(_convert_to_fp16, kwargs)
+ if self.convert_fn is not None:
+ args = tree_map(self.convert_fn, args)
+ kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)
-class LowLevelZeroOptimizer(OptimizerWrapper):
-
- def __init__(self,
- module: nn.Module,
- optimizer: Optimizer,
- zero_optim_config: dict,
- optim_kwargs: dict,
- verbose: bool = False) -> None:
- optimizer = zero_optim_wrapper(module,
- optimizer,
- optim_config=zero_optim_config,
- **optim_kwargs,
- verbose=verbose)
- super().__init__(optimizer)
-
- def backward(self, loss: Tensor, *args, **kwargs):
- self.optim.backward(loss)
-
- def clip_grad_by_norm(self,
- max_norm: Union[float, int],
- norm_type: Union[float, int] = 2,
- error_if_nonfinite: bool = False,
- *args,
- **kwargs) -> Tensor:
- warnings.warn(f'LowLevelZero controls grad clipping by itself, so you should not use clip_grad_by_norm')
+class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
+ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
+ """Save optimizer to checkpoint but only on master process.
- def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
- raise NotImplementedError('LowLevelZero does not support clip_grad_by_value')
+ Args:
+ optimizer (OptimizerWrapper): Optimizer to save state_dict
+ checkpoint (str): Path to save checkpoint
+ gather_dtensor (bool): Whether to gather_dtensor, not used
+ """
+ assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
+ # the `state_dict` in LowLevelZeroOptimizer involves collective communication:
+ # if only the master rank collected the state_dict and saved it,
+ # the communication calls across ranks would not match
+ state_dict = optimizer.state_dict()
+ if self.coordinator.is_master():
+ save_state_dict(state_dict, checkpoint, use_safetensors=False)
+
+ def save_sharded_optimizer(
+ self,
+ optimizer: OptimizerWrapper,
+ checkpoint: str,
+ gather_dtensor: bool = False,
+ prefix: str = None,
+ size_per_shard: int = 1024,
+ ):
+ """
+ Save sharded Zero-optimizer checkpoint under the given checkpointing path.
+ The following files will be created under the path:
+ - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
+ - A group file (pytorch_optim_group.bin) recording information of param_groups
+ - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
+ Args:
+ optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
+ checkpoint (str): Path to save optimizer state_dict
+ gather_dtensor (bool): Whether to gather_dtensor, not used
+ prefix (str): Prefix of the files to save
+ size_per_shard (int): Maximum size of each file that stores optimizer state tensors
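+
+ A sketch of the resulting layout (shard count depends on size_per_shard):
+ checkpoint/
+ pytorch_optim.bin.index.json
+ pytorch_optim_group.bin
+ pytorch_optim-00001.bin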
+ """
+ assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
+ if os.path.isfile(checkpoint):
+ logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+ return
+
+ Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+ # this `state_dict` only provides 'param_groups'
+ state_dict = optimizer.optim.state_dict()
+ # state sharding is handled by the low-level zero optimizer
+ sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)
+
+ # Preparing file paths and index file.
+ states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+ index_file = CheckpointIndexFile(checkpoint)
+
+ # Store the information of param groups to param_group_file.
+ index_file.append_meta_data("param_groups", param_group_file)
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ save_param_groups(state_dict, group_file_path)
+
+ # Save shards of optimizer states.
+ total_size = 0
+ for idx, shard_pair in enumerate(sharded_state):
+ shard, current_size = shard_pair
+ shard_file = get_shard_filename(states_name, idx)
+ total_size = total_size + current_size
+ for param_id in shard.keys():
+ index_file.append_weight_map(str(param_id), shard_file)
+
+ checkpoint_file_path = os.path.join(checkpoint, shard_file)
+ if self.coordinator.is_master():
+ save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
+
+ # Wrap up index file.
+ index_file.append_meta_data("total_size", total_size)
+ if self.coordinator.is_master():
+ index_file.write_index_file(save_index_file)
+ logging.info(
+ f"The optimizer is going to be split to checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
+
+ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
+ """Load sharded optimizer with the given path to index file.
-class LowLevelZeroPlugin(Plugin):
+ Args:
+ optimizer (OptimizerWrapper): Optimizer to load state_dict
+ index_file_path (str): Path to the index file
+ prefix (str): Not used.
+ """
+ assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before loading!"
+ optimizer = optimizer.unwrap()
+
+ # Read checkpoint index file.
+ ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
+
+ # Load param_groups
+ param_group_path = ckpt_index_file.get_param_group_filename()
+ if param_group_path is None:
+ raise RuntimeError(
+ f"Invalid index file path {index_file_path} for an optimizer. \
+ Lacking param group file under current directory."
+ )
+ id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
+
+ checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+ for shard_file in checkpoint_files:
+ state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+ # shard state dict
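+ # each rank keeps a 1/world_size slice of every flattened state tensor:
+ # pad the full tensor to a multiple of world_size, split it evenly,
+ # and keep only this rank's slice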
+ for param_idx, state in state_dict.items():
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor) and k != "step":
+ padding_size = (
+ self.coordinator.world_size - v.numel() % self.coordinator.world_size
+ ) % self.coordinator.world_size
+ with torch.no_grad():
+ v = v.flatten()
+ if padding_size > 0:
+ v = torch.nn.functional.pad(v, [0, padding_size])
+ v_list = v.split(v.numel() // self.coordinator.world_size)
+ state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
+ load_states_into_optimizer(optimizer, state_dict, id_map)
+ sharded_optimizer_loading_epilogue(optimizer)
+
+ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
+ assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+ super().load_unsharded_model(model, checkpoint, strict)
+ model.update_master_params()
+
+ def load_sharded_model(
+ self,
+ model: ModelWrapper,
+ checkpoint_index_file: Path,
+ strict: bool = False,
+ use_safetensors: bool = False,
+ load_sub_module: bool = True,
+ ):
+ assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+ super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+ model.update_master_params()
+
+
+class LowLevelZeroPlugin(DPPluginBase):
"""
Plugin for low level zero.
- Example:
- >>> from colossalai.booster import Booster
- >>> from colossalai.booster.plugin import LowLevelZeroPlugin
- >>>
- >>> model, train_dataset, optimizer, criterion = ...
- >>> plugin = LowLevelZeroPlugin()
+ ```python
+ from colossalai.booster import Booster
+ from colossalai.booster.plugin import LowLevelZeroPlugin
- >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
- >>> booster = Booster(plugin=plugin)
- >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+ model, train_dataset, optimizer, criterion = ...
+ plugin = LowLevelZeroPlugin()
+
+ train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+ booster = Booster(plugin=plugin)
+ model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+ ```
Args:
- strage (int, optional): ZeRO stage. Defaults to 1.
- precision (str, optional): precision. Support 'fp16' and 'fp32'. Defaults to 'fp16'.
+ stage (int, optional): ZeRO stage. Defaults to 1.
+ precision (str, optional): precision. Support 'fp16', 'bf16' and 'fp32'. Defaults to 'fp16'.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
@@ -126,7 +248,7 @@ class LowLevelZeroPlugin(Plugin):
def __init__(
self,
stage: int = 1,
- precision: str = 'fp16',
+ precision: str = "fp16",
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
@@ -142,113 +264,64 @@ class LowLevelZeroPlugin(Plugin):
cpu_offload: bool = False,
verbose: bool = False,
) -> None:
-
- assert dist.is_initialized(
- ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
- assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
- assert precision in ('fp16', 'fp32'), f'LowLevelZeroPlugin only supports fp16/fp32 training'
-
- self.rank = dist.get_rank()
- self.world_size = dist.get_world_size()
-
+ super().__init__()
+ assert stage in (1, 2), "LowLevelZeroPlugin only supports stage 1/2 training"
+ assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training"
+ assert norm_type == 2.0, "LowLevelZeroPlugin only supports norm_type=2.0 now"
self.stage = stage
self.precision = precision
- self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
- communication_dtype=communication_dtype,
- overlap_communication=overlap_communication,
- cpu_offload=cpu_offload)
- self.optim_kwargs = dict(initial_scale=initial_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval,
- hysteresis=hysteresis,
- min_scale=min_scale,
- max_scale=max_scale,
- max_norm=max_norm,
- norm_type=norm_type)
+ self.zero_optim_kwargs = dict(
+ initial_scale=initial_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ min_scale=min_scale,
+ max_scale=max_scale,
+ clip_grad_norm=max_norm,
+ reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
+ communication_dtype=communication_dtype,
+ overlap_communication=overlap_communication,
+ cpu_offload=cpu_offload,
+ partition_grad=(stage == 2),
+ )
self.verbose = verbose
+ # embed the stage in the class name for clearer error messages
+ setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
+
def support_no_sync(self) -> bool:
- return False
+ return self.stage == 1
def control_precision(self) -> bool:
return True
def supported_precisions(self) -> List[str]:
- return ['fp16', 'fp32']
+ return SUPPORTED_PRECISION
def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
- return ['cuda']
-
- def prepare_train_dataloader(self,
- dataset,
- batch_size,
- shuffle=False,
- seed=1024,
- drop_last=False,
- pin_memory=False,
- num_workers=0,
- **kwargs):
- r"""
- Prepare a dataloader for distributed training. The dataloader will be wrapped by
- `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
-
- Note:
- 1. Evaluation datasets should not be passed to this function.
-
- Args:
- dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
- shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
- seed (int, optional): Random worker seed for sampling, defaults to 1024.
- add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
- drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
- is not divisible by the batch size. If False and the size of dataset is not divisible by
- the batch size, then the last batch will be smaller, defaults to False.
- pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
- num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
- kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
- `DataLoader `_.
-
- Returns:
- :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
- """
- _kwargs = kwargs.copy()
- sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
-
- # Deterministic dataloader
- def seed_worker(worker_id):
- worker_seed = seed
- np.random.seed(worker_seed)
- torch.manual_seed(worker_seed)
- random.seed(worker_seed)
-
- return DataLoader(dataset,
- batch_size=batch_size,
- sampler=sampler,
- worker_init_fn=seed_worker,
- drop_last=drop_last,
- pin_memory=pin_memory,
- num_workers=num_workers,
- **_kwargs)
+ return ["cuda"]
def configure(
self,
model: nn.Module,
- optimizer: Optimizer,
- criterion: Callable = None,
- dataloader: DataLoader = None,
- lr_scheduler: LRScheduler = None,
- ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
-
+ optimizer: Optional[Optimizer] = None,
+ criterion: Optional[Callable] = None,
+ dataloader: Optional[DataLoader] = None,
+ lr_scheduler: Optional[LRScheduler] = None,
+ ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if not isinstance(model, ModelWrapper):
- model = LowLevelZeroModel(model, self.stage, self.precision)
+ model = LowLevelZeroModel(model, self.precision)
- if not isinstance(optimizer, OptimizerWrapper):
- optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
- self.verbose)
+ if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
+ optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
+ optimizer, **self.zero_optim_kwargs, verbose=self.verbose
+ )
+ # inject update_master_params
+ model.update_master_params = MethodType(optimizer.update_master_params, model)
return model, optimizer, criterion, dataloader, lr_scheduler
@@ -257,3 +330,7 @@ class LowLevelZeroPlugin(Plugin):
def get_checkpoint_io(self) -> CheckpointIO:
return LowLevelZeroCheckpointIO()
+
+ def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
+ assert isinstance(optimizer, LowLevelZeroOptimizer)
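+ # only supported for ZeRO stage 1 (see support_no_sync); a minimal sketch:
+ # with plugin.no_sync(model, optimizer):
+ # booster.backward(loss, optimizer) # gradients accumulate locally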
+ return optimizer.optim.no_sync()
diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py
index 7a222022c1b264585b03ea4c402e580a909c4af7..4e570cbe8abc613bf88282d4f478f5698da211ff 100644
--- a/colossalai/booster/plugin/plugin_base.py
+++ b/colossalai/booster/plugin/plugin_base.py
@@ -1,19 +1,18 @@
from abc import ABC, abstractmethod
-from typing import Callable, List, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
from colossalai.checkpoint_io import CheckpointIO
from colossalai.interface import OptimizerWrapper
-__all__ = ['Plugin']
+__all__ = ["Plugin"]
class Plugin(ABC):
-
@abstractmethod
def supported_devices(self) -> List[str]:
pass
@@ -38,11 +37,11 @@ class Plugin(ABC):
def configure(
self,
model: nn.Module,
- optimizer: Optimizer,
- criterion: Callable = None,
- dataloader: DataLoader = None,
- lr_scheduler: LRScheduler = None,
- ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+ optimizer: Optional[Optimizer] = None,
+ criterion: Optional[Callable] = None,
+ dataloader: Optional[DataLoader] = None,
+ lr_scheduler: Optional[LRScheduler] = None,
+ ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
# implement this method
pass
@@ -51,11 +50,31 @@ class Plugin(ABC):
"""
Whether the plugin controls the checkpoint io
"""
- pass
@abstractmethod
def get_checkpoint_io(self) -> CheckpointIO:
"""
Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True.
"""
- pass
+
+ @abstractmethod
+ def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
+ """
+ Context manager to disable gradient synchronization.
+ """
+
+ @abstractmethod
+ def prepare_dataloader(
+ self,
+ dataset: Dataset,
+ batch_size: int,
+ shuffle: bool = False,
+ seed: int = 1024,
+ drop_last: bool = False,
+ pin_memory: bool = False,
+ num_workers: int = 0,
+ **kwargs,
+ ):
+ """Prepare a dataloader for distributed training. The dataloader will be wrapped by
+ `torch.utils.data.DataLoader`
+ """
diff --git a/colossalai/booster/plugin/pp_plugin_base.py b/colossalai/booster/plugin/pp_plugin_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d91eb95b4097265f7e89e0598b6592b2da58848
--- /dev/null
+++ b/colossalai/booster/plugin/pp_plugin_base.py
@@ -0,0 +1,22 @@
+from abc import abstractmethod
+from typing import Any, Callable, Iterator, Optional
+
+import torch
+
+from colossalai.interface import ModelWrapper, OptimizerWrapper
+
+from .plugin_base import Plugin
+
+
+class PipelinePluginBase(Plugin):
+ @abstractmethod
+ def execute_pipeline(
+ self,
+ data_iter: Iterator,
+ model: ModelWrapper,
+ criterion: Callable[[Any, Any], torch.Tensor],
+ optimizer: Optional[OptimizerWrapper] = None,
+ return_loss: bool = True,
+ return_outputs: bool = False,
+ ) -> dict:
+ pass
diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py
index c5e310c7e7695893266784e9b90018b0d373f6f2..738634473dbc669e377dca614c2e890ee603baa9 100644
--- a/colossalai/booster/plugin/torch_ddp_plugin.py
+++ b/colossalai/booster/plugin/torch_ddp_plugin.py
@@ -1,50 +1,52 @@
-import random
-from typing import Callable, List, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple
-import numpy as np
-import torch
-import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
-from .plugin_base import Plugin
+from .dp_plugin_base import DPPluginBase
-__all__ = ['TorchDDPPlugin']
+__all__ = ["TorchDDPPlugin"]
class TorchDDPCheckpointIO(GeneralCheckpointIO):
-
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
- def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
"""
- Load model from checkpoint with automatic unwrapping.
+ Load model from checkpoint.
"""
- # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
- return super().load_unsharded_model(model, checkpoint, strict=strict)
+ assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+ super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
- def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
"""
- # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
+ assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
- super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
+ super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors)
+
+ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+ """
+ Load optimizer from checkpoint.
+ """
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+ super().load_unsharded_optimizer(optimizer, checkpoint)
- def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if self.coordinator.is_master():
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
@@ -55,9 +57,67 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
+ def save_sharded_model(
+ self,
+ model: ModelWrapper,
+ checkpoint_path: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ max_shard_size: int = 1024,
+ use_safetensors: bool = False,
+ ):
+ """
+ Save model to checkpoint but only on master process.
+ """
+ assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+ if self.coordinator.is_master():
+ super().save_sharded_model(
+ model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
+ )
-class TorchDDPModel(ModelWrapper):
+ def load_sharded_model(
+ self,
+ model: ModelWrapper,
+ checkpoint_index_file: str,
+ strict: bool = False,
+ use_safetensors: bool = False,
+ load_sub_module: bool = True,
+ ):
+ """
+ Load model from sharded checkpoint.
+ """
+ assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+ super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)
+
+ def save_sharded_optimizer(
+ self,
+ optimizer: OptimizerWrapper,
+ checkpoint: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ ):
+ """
+ Save optimizer to sharded checkpoint but only on master process.
+ """
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
+ if self.coordinator.is_master():
+ super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
+
+ def load_sharded_optimizer(
+ self,
+ optimizer: Optimizer,
+ index_file_path: str,
+ prefix: Optional[str] = None,
+ ):
+ """
+ Load optimizer from sharded checkpoint.
+ """
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+ super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
+
+class TorchDDPModel(ModelWrapper):
def __init__(self, module: nn.Module, *args, **kwargs) -> None:
super().__init__(module)
self.module = DDP(module, *args, **kwargs)
@@ -66,20 +126,21 @@ class TorchDDPModel(ModelWrapper):
return self.module.module
-class TorchDDPPlugin(Plugin):
+class TorchDDPPlugin(DPPluginBase):
"""
Plugin for PyTorch DDP.
- Example:
- >>> from colossalai.booster import Booster
- >>> from colossalai.booster.plugin import TorchDDPPlugin
- >>>
- >>> model, train_dataset, optimizer, criterion = ...
- >>> plugin = TorchDDPPlugin()
+ ```python
+ from colossalai.booster import Booster
+ from colossalai.booster.plugin import TorchDDPPlugin
- >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
- >>> booster = Booster(plugin=plugin)
- >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+ model, train_dataset, optimizer, criterion = ...
+ plugin = TorchDDPPlugin()
+
+ train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+ booster = Booster(plugin=plugin)
+ model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+ ```
Args:
broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Defaults to True.
@@ -90,24 +151,24 @@ class TorchDDPPlugin(Plugin):
static_graph (bool, optional): Whether to use static graph. Defaults to False.
"""
- def __init__(self,
- broadcast_buffers: bool = True,
- bucket_cap_mb: int = 25,
- find_unused_parameters: bool = False,
- check_reduction: bool = False,
- gradient_as_bucket_view: bool = False,
- static_graph: bool = False) -> None:
-
- assert dist.is_initialized(
- ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
- self.rank = dist.get_rank()
- self.world_size = dist.get_world_size()
- self.ddp_kwargs = dict(broadcast_buffers=broadcast_buffers,
- bucket_cap_mb=bucket_cap_mb,
- find_unused_parameters=find_unused_parameters,
- check_reduction=check_reduction,
- gradient_as_bucket_view=gradient_as_bucket_view,
- static_graph=static_graph)
+ def __init__(
+ self,
+ broadcast_buffers: bool = True,
+ bucket_cap_mb: int = 25,
+ find_unused_parameters: bool = False,
+ check_reduction: bool = False,
+ gradient_as_bucket_view: bool = False,
+ static_graph: bool = False,
+ ) -> None:
+ super().__init__()
+ self.ddp_kwargs = dict(
+ broadcast_buffers=broadcast_buffers,
+ bucket_cap_mb=bucket_cap_mb,
+ find_unused_parameters=find_unused_parameters,
+ check_reduction=check_reduction,
+ gradient_as_bucket_view=gradient_as_bucket_view,
+ static_graph=static_graph,
+ )
def support_no_sync(self) -> bool:
return True
@@ -116,73 +177,22 @@ class TorchDDPPlugin(Plugin):
return False
def supported_precisions(self) -> List[str]:
- return ['fp16', 'fp16_apex', 'bf16', 'fp8']
+ return ["fp16", "fp16_apex", "bf16", "fp8"]
def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
- return ['cuda']
-
- def prepare_train_dataloader(self,
- dataset,
- batch_size,
- shuffle=False,
- seed=1024,
- drop_last=False,
- pin_memory=False,
- num_workers=0,
- **kwargs):
- r"""
- Prepare a dataloader for distributed training. The dataloader will be wrapped by
- `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
-
- Note:
- 1. Evaluation datasets should not be passed to this function.
-
- Args:
- dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
- shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
- seed (int, optional): Random worker seed for sampling, defaults to 1024.
- add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
- drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
- is not divisible by the batch size. If False and the size of dataset is not divisible by
- the batch size, then the last batch will be smaller, defaults to False.
- pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
- num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
- kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
- `DataLoader `_.
-
- Returns:
- :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
- """
- _kwargs = kwargs.copy()
- sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
-
- # Deterministic dataloader
- def seed_worker(worker_id):
- worker_seed = seed
- np.random.seed(worker_seed)
- torch.manual_seed(worker_seed)
- random.seed(worker_seed)
-
- return DataLoader(dataset,
- batch_size=batch_size,
- sampler=sampler,
- worker_init_fn=seed_worker,
- drop_last=drop_last,
- pin_memory=pin_memory,
- num_workers=num_workers,
- **_kwargs)
+ return ["cuda"]
def configure(
self,
model: nn.Module,
- optimizer: Optimizer,
- criterion: Callable = None,
- dataloader: DataLoader = None,
- lr_scheduler: LRScheduler = None,
- ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+ optimizer: Optional[Optimizer] = None,
+ criterion: Optional[Callable] = None,
+ dataloader: Optional[DataLoader] = None,
+ lr_scheduler: Optional[LRScheduler] = None,
+ ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
# cast model to cuda
model = model.cuda()
@@ -192,7 +202,7 @@ class TorchDDPPlugin(Plugin):
# wrap the model with PyTorch DDP
model = TorchDDPModel(model, **self.ddp_kwargs)
- if not isinstance(optimizer, OptimizerWrapper):
+ if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer)
return model, optimizer, criterion, dataloader, lr_scheduler
@@ -202,3 +212,7 @@ class TorchDDPPlugin(Plugin):
def get_checkpoint_io(self) -> CheckpointIO:
return TorchDDPCheckpointIO()
+
+ def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
+ assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin."
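+ # delegates to DDP's native no_sync context manager; a minimal sketch:
+ # with plugin.no_sync(model, optimizer):
+ # loss.backward() # gradient all-reduce is skipped inside this context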
+ return model.module.no_sync()
diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ea7593a5cc548be469fddd32ada454cdea1f118
--- /dev/null
+++ b/colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -0,0 +1,237 @@
+import warnings
+from pathlib import Path
+from typing import Callable, Iterable, Iterator, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from packaging import version
+from torch.distributed import ProcessGroup
+
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+ from torch.distributed.fsdp import StateDictType
+ from torch.distributed.fsdp.fully_sharded_data_parallel import (
+ BackwardPrefetch,
+ CPUOffload,
+ FullStateDictConfig,
+ MixedPrecision,
+ ShardingStrategy,
+ )
+else:
+ raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils.data import DataLoader
+
+from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils
+from colossalai.cluster import DistCoordinator
+from colossalai.interface import ModelWrapper, OptimizerWrapper
+
+from .dp_plugin_base import DPPluginBase
+
+__all__ = ["TorchFSDPPlugin"]
+
+
+class TorchFSDPCheckpointIO(GeneralCheckpointIO):
+ def __init__(self) -> None:
+ super().__init__()
+ self.coordinator = DistCoordinator()
+
+ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
+ assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
+ model = model.unwrap()
+ checkpoint = utils.load_state_dict(checkpoint)
+ model.load_state_dict(checkpoint)
+
+ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
+ assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
+ checkpoint = utils.load_state_dict(checkpoint)
+ fsdp_model = optimizer.unwrap_model()
+ sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
+ optimizer.load_state_dict(sharded_osd)
+
+ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+ """
+ Save model to checkpoint but only on master process.
+ """
+ assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!"
+ model = model.unwrap()
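+ # gather the full (unsharded) state dict on rank 0, offloading to CPU
+ # to avoid materializing the whole model on a single GPU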
+ cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+ with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
+ full_model_state = model.state_dict()
+ utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
+
+ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+ """
+ Save optimizer to checkpoint but only on master process.
+ """
+ assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
+ fsdp_model = optimizer.unwrap_model()
+ full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
+ utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
+
+ def save_sharded_model(
+ self,
+ model: nn.Module,
+ checkpoint: str,
+ gather_dtensor: bool,
+ prefix: Optional[str],
+ size_per_shard: int,
+ use_safetensors: bool,
+ ):
+ """
+ Save model to checkpoint but only on master process.
+ """
+ raise NotImplementedError("Sharded model checkpoint is not supported yet.")
+
+ def load_sharded_model(
+ self,
+ model: nn.Module,
+ checkpoint_index_file: Path,
+ strict: bool = False,
+ use_safetensors: bool = False,
+ load_sub_module: bool = True,
+ ):
+ """
+ Load model from sharded checkpoint.
+ """
+ raise NotImplementedError("Sharded model checkpoint is not supported yet.")
+
+ def save_sharded_optimizer(
+ self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
+ ):
+ """
+ Save optimizer to checkpoint but only on master process.
+ """
+ raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+
+ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
+ """
+ Load optimizer from sharded checkpoint.
+ """
+ raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+
+ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+ """
+ Save lr scheduler to checkpoint but only on master process.
+ """
+ if self.coordinator.is_master():
+ super().save_lr_scheduler(lr_scheduler, checkpoint)
+
+
+class TorchFSDPModel(ModelWrapper):
+ def __init__(self, module: nn.Module, *args, **kwargs) -> None:
+ super().__init__(module)
+ self.module = FSDP(module, *args, **kwargs)
+
+ def unwrap(self):
+ return self.module
+
+
+class FSDPOptimizerWrapper(OptimizerWrapper):
+ def __init__(self, optimizer: Optimizer, model: nn.Module):
+ self.model = model
+ super().__init__(optimizer)
+
+ def unwrap_model(self) -> nn.Module:
+ return self.model
+
+
+class TorchFSDPPlugin(DPPluginBase):
+ """
+ Plugin for PyTorch FSDP.
+
+ ```python
+ from colossalai.booster import Booster
+ from colossalai.booster.plugin import TorchFSDPPlugin
+
+ model, train_dataset, optimizer, criterion = ...
+ plugin = TorchFSDPPlugin()
+
+ train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+ booster = Booster(plugin=plugin)
+ model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+ ```
+
+ Args:
+ See https://pytorch.org/docs/stable/fsdp.html for details.
+ """
+
+ if version.parse(torch.__version__) >= version.parse("1.12.0"):
+
+ def __init__(
+ self,
+ process_group: Optional[ProcessGroup] = None,
+ sharding_strategy: Optional[ShardingStrategy] = None,
+ cpu_offload: Optional[CPUOffload] = None,
+ auto_wrap_policy: Optional[Callable] = None,
+ backward_prefetch: Optional[BackwardPrefetch] = None,
+ mixed_precision: Optional[MixedPrecision] = None,
+ ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
+ param_init_fn: Optional[Callable[[nn.Module], None]] = None,
+ sync_module_states: bool = False,
+ ):
+ super().__init__()
+ self.fsdp_kwargs = dict(
+ process_group=process_group,
+ sharding_strategy=sharding_strategy,
+ cpu_offload=cpu_offload,
+ auto_wrap_policy=auto_wrap_policy,
+ backward_prefetch=backward_prefetch,
+ mixed_precision=mixed_precision,
+ ignored_modules=ignored_modules,
+ param_init_fn=param_init_fn,
+ sync_module_states=sync_module_states,
+ )
+
+ else:
+ raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
+
+ def support_no_sync(self) -> bool:
+ return False
+
+ def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
+ raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
+
+ def control_precision(self) -> bool:
+ return True
+
+ def supported_precisions(self) -> List[str]:
+ return ["fp16", "bf16"]
+
+ def control_device(self) -> bool:
+ return True
+
+ def supported_devices(self) -> List[str]:
+ return ["cuda"]
+
+ def configure(
+ self,
+ model: nn.Module,
+ optimizer: Optional[Optimizer] = None,
+ criterion: Optional[Callable] = None,
+ dataloader: Optional[DataLoader] = None,
+ lr_scheduler: Optional[LRScheduler] = None,
+ ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
+ # wrap the model with PyTorch FSDP
+ fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
+
+ if optimizer is not None:
+ if len(optimizer.param_groups) > 1:
+ warnings.warn(
+ "TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used."
+ )
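+ # FSDP flattens and shards parameters, so the optimizer must be re-created
+ # over the wrapped module's parameters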
+ optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
+
+ if not isinstance(optimizer, FSDPOptimizerWrapper):
+ optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)
+
+ return fsdp_model, optimizer, criterion, dataloader, lr_scheduler
+
+ def control_checkpoint_io(self) -> bool:
+ return True
+
+ def get_checkpoint_io(self) -> CheckpointIO:
+ return TorchFSDPCheckpointIO()
diff --git a/colossalai/builder/__init__.py b/colossalai/builder/__init__.py
deleted file mode 100644
index cf09e1e7a31a15e979c5358c5da27683a0ccb2f9..0000000000000000000000000000000000000000
--- a/colossalai/builder/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .builder import build_from_config, build_from_registry, build_gradient_handler
-
-__all__ = ['build_gradient_handler', 'build_from_config', 'build_from_registry']
diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py
index c25048e25754eb6fc55db1b1de4ac7b21d05bda3..19b61730bded60230d80cbca434f9e06c0fd66c5 100644
--- a/colossalai/checkpoint_io/__init__.py
+++ b/colossalai/checkpoint_io/__init__.py
@@ -1,5 +1,6 @@
from .checkpoint_io_base import CheckpointIO
from .general_checkpoint_io import GeneralCheckpointIO
+from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
from .index_file import CheckpointIndexFile
-__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO']
+__all__ = ["CheckpointIO", "CheckpointIndexFile", "GeneralCheckpointIO", "HybridParallelCheckpointIO"]
diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py
index cb853559c48c1a6ceb41231e312a2480fd68bd97..780117598e183a98b20baa593b8f80179b63d360 100644
--- a/colossalai/checkpoint_io/checkpoint_io_base.py
+++ b/colossalai/checkpoint_io/checkpoint_io_base.py
@@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
from pathlib import Path
-from typing import Union
-from typing import Optional
+from typing import Optional, Union
import torch
import torch.nn as nn
@@ -12,7 +11,7 @@ from colossalai.interface import ModelWrapper
from .utils import has_index_file
-__all__ = ['CheckpointIO']
+__all__ = ["CheckpointIO"]
class CheckpointIO(ABC):
@@ -62,10 +61,9 @@ class CheckpointIO(ABC):
# ======================================
# Public methods
# ======================================
- def load_model(self,
- model: Union[nn.Module, ModelWrapper],
- checkpoint: str,
- strict: bool = True) -> Union[nn.Module, ModelWrapper]:
+ def load_model(
+ self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
+ ) -> Union[nn.Module, ModelWrapper]:
"""
Load model from checkpoint.
@@ -84,15 +82,11 @@ class CheckpointIO(ABC):
# containing no distributed tensors, dtensor -> full tensor conversion
# should be done offline via our CLI
# the existence of index file means it is a sharded checkpoint
- ckpt_path = Path(checkpoint)
index_file_exists, index_file_path = has_index_file(checkpoint)
# return the origin model instead of the unwrapped model
origin_model = model
- if isinstance(model, ModelWrapper):
- model = model.unwrap()
-
if index_file_exists:
self.load_sharded_model(model, index_file_path, strict)
else:
@@ -100,14 +94,16 @@ class CheckpointIO(ABC):
return origin_model
- def save_model(self,
- model: Union[nn.Module, ModelWrapper],
- checkpoint: str,
- shard: bool = False,
- gather_dtensor: bool = True,
- variant: str = None,
- size_per_shard: int = 1024,
- use_safetensors: bool = False):
+ def save_model(
+ self,
+ model: Union[nn.Module, ModelWrapper],
+ checkpoint: str,
+ shard: bool = False,
+ gather_dtensor: bool = True,
+ prefix: str = None,
+ size_per_shard: int = 1024,
+ use_safetensors: bool = False,
+ ):
"""
Save model to checkpoint.
@@ -130,46 +126,49 @@ class CheckpointIO(ABC):
multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
that the checkpoint path is a directory path instead of a file path.
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
- variant (str): If specified, weights are saved in the format pytorch_model..bin. Default: None.
+ prefix (str): If specified, weights are saved in the format pytorch_model.<prefix>.bin. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved in safetensors format.
"""
- if isinstance(model, ModelWrapper):
- model = model.unwrap()
-
if shard:
- self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors)
+ self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
else:
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
- def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+ def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
"""
Load optimizer from checkpoint.
Args:
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. Its format is kept compatible with that of the model checkpoints.
+ prefix (str, optional): A prefix added to parameter and buffer
+ names to compose the keys in state_dict. Defaults to None.
+ size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
+
index_file_exists, index_file_path = has_index_file(checkpoint)
if Path(checkpoint).is_dir() and not index_file_exists:
# if the checkpoint is a directory and there is no index file, raise error
- raise ValueError(f'Cannot find index file in {checkpoint}')
+ raise ValueError(f"Cannot find index file in {checkpoint}")
if index_file_exists:
# the existence of index file means it is a sharded checkpoint
- self.load_sharded_optimizer(optimizer, index_file_path)
+ self.load_sharded_optimizer(optimizer, index_file_path, prefix)
else:
self.load_unsharded_optimizer(optimizer, checkpoint)
- def save_optimizer(self,
- optimizer: Optimizer,
- checkpoint: str,
- shard: bool = False,
- gather_dtensor=True,
- prefix: str = None,
- size_per_shard: int = 1024):
+ def save_optimizer(
+ self,
+ optimizer: Optimizer,
+ checkpoint: str,
+ shard: bool = False,
+ gather_dtensor=True,
+ prefix: str = None,
+ size_per_shard: int = 1024,
+ ):
"""
Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
@@ -185,6 +184,7 @@ class CheckpointIO(ABC):
prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
"""
+
if shard:
self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
else:
@@ -204,7 +204,6 @@ class CheckpointIO(ABC):
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
"""
- pass
@abstractmethod
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
@@ -217,11 +216,17 @@ class CheckpointIO(ABC):
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
"""
- pass
@abstractmethod
- def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
- size_per_shard: int, use_safetensors: bool):
+ def save_sharded_model(
+ self,
+ model: nn.Module,
+ checkpoint: str,
+ gather_dtensor: bool,
+ prefix: Optional[str],
+ size_per_shard: int,
+ use_safetensors: bool,
+ ):
"""
Save model to sharded checkpoint.
@@ -233,7 +238,6 @@ class CheckpointIO(ABC):
size_per_shard (int): size per shard in MB.
use_safetensors (bool): whether to use safe tensors.
"""
- pass
@abstractmethod
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
@@ -246,14 +250,13 @@ class CheckpointIO(ABC):
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
use_safetensors (bool): whether to use safe tensors.
"""
- pass
# ========================================================
# Abstract methods for optimizer loading/saving implementation
# ========================================================
@abstractmethod
- def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
+ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
"""
Load optimizer from sharded checkpoint.
@@ -261,9 +264,7 @@ class CheckpointIO(ABC):
optimizer (Optimizer): optimizer to be loaded.
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
prefix (str): prefix for the optimizer checkpoint.
- size_per_shard (int): size per shard in MB.
"""
- pass
@abstractmethod
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
@@ -274,11 +275,11 @@ class CheckpointIO(ABC):
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
"""
- pass
@abstractmethod
- def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
- size_per_shard: int):
+ def save_sharded_optimizer(
+ self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+ ):
"""
Save optimizer to sharded checkpoint.
@@ -289,7 +290,6 @@ class CheckpointIO(ABC):
prefix (str): prefix for the optimizer checkpoint.
size_per_shard (int): size per shard in MB.
"""
- pass
@abstractmethod
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
@@ -301,7 +301,6 @@ class CheckpointIO(ABC):
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
gather_dtensor (bool): whether to gather the distributed tensor to the first device.
"""
- pass
# ============================================
# methods for loading and saving lr scheduler
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index bf584f45d045228c5d6d1e02b470c7696f5194db..a652d9b4538ec7b9270db87cf82ff988e1dde1c4 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -1,35 +1,41 @@
+import gc
+import logging
+import os
+from functools import reduce
from pathlib import Path
+from typing import Optional
import torch.nn as nn
from torch.optim import Optimizer
-import logging
-import os
-import json
-import gc
-from typing import Optional
from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
- has_index_file,
- load_state_dict,
- save_state_dict,
+ get_model_base_filenames,
+ get_optimizer_base_filenames,
is_safetensors_available,
- shard_checkpoint,
+ load_param_groups_into_optimizer,
load_shard_state_dict,
+ load_state_dict,
load_state_dict_into_model,
- add_variant
- )
-from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME
-
+ load_states_into_optimizer,
+ save_config_file,
+ save_param_groups,
+ save_state_dict,
+ save_state_dict_shards,
+ shard_model_checkpoint,
+ shard_optimizer_checkpoint,
+ sharded_optimizer_loading_epilogue,
+)
-__all__ = ['GeneralCheckpointIO']
+__all__ = ["GeneralCheckpointIO"]
class GeneralCheckpointIO(CheckpointIO):
"""
Checkpoint IO
"""
+
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
checkpoint = load_state_dict(checkpoint)
model.load_state_dict(checkpoint, strict=strict)
@@ -44,12 +50,30 @@ class GeneralCheckpointIO(CheckpointIO):
# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)
- def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
- raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
+ """
+ Load sharded optimizer with the given path to index file.
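+
+        A minimal usage sketch (hypothetical names; the checkpoint is assumed to have
+        been produced by ``save_sharded_optimizer``):
+            >>> ckpt_io = GeneralCheckpointIO()
+            >>> ckpt_io.load_sharded_optimizer(optimizer, "optim_ckpt/pytorch_optim.bin.index.json", prefix="")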
+ """
+
+ # Read checkpoint index file.
+ ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
- def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
- checkpoint = load_state_dict(checkpoint)
- optimizer.load_state_dict(checkpoint)
+ # Load param_groups
+ param_group_path = ckpt_index_file.get_param_group_filename()
+ if param_group_path is None:
+ raise RuntimeError(
+ f"Invalid index file path {index_file_path} for an optimizer. \
+ Lacking param group file under current directory."
+ )
+ id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
+
+ checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+ for shard_file in checkpoint_files:
+ state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+ load_states_into_optimizer(optimizer, state_dict, id_map)
+
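+        # The epilogue below presumably finalizes the manual state injection, e.g. by
+        # recasting loaded tensors and restoring the optimizer's internal bookkeeping.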
+ sharded_optimizer_loading_epilogue(optimizer)
def save_sharded_optimizer(
self,
@@ -59,7 +83,56 @@ class GeneralCheckpointIO(CheckpointIO):
prefix: str,
size_per_shard: int,
):
- raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+ """
+ Save sharded optimizer checkpoint under the given checkpointing path.
+ The following files will be created under the path:
+ - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
+ - A group file (pytorch_optim_group.bin) recording information of param_groups
+        - Multiple files (pytorch_optim-000XX.bin) that store the optimizer's state tensors in shards
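+
+        A minimal usage sketch (hypothetical names, assuming "optim_ckpt" is a writable directory):
+            >>> ckpt_io = GeneralCheckpointIO()
+            >>> ckpt_io.save_sharded_optimizer(optimizer, "optim_ckpt", gather_dtensor=False, prefix=None, size_per_shard=1024)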
+ """
+
+ if os.path.isfile(checkpoint):
+ logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+ return
+
+ Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+        # Obtain the full optimizer state dict and break it into shards within max_shard_size.
+ state_dict = optimizer.state_dict()
+ sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard)
+
+ # Preparing file paths and index file.
+ states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+ index_file = CheckpointIndexFile(checkpoint)
+
+ # Store the information of param groups to param_group_file.
+ index_file.append_meta_data("param_groups", param_group_file)
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ save_param_groups(state_dict, group_file_path)
+
+ # Save shards of optimizer states.
+ # In general cases, is_master is set to True to get the right behavior.
+ total_size = save_state_dict_shards(
+ sharded_state_dict=sharded_state,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=True,
+ use_safetensors=False,
+ )
+
+ # Wrap up index file.
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ logging.info(
+            f"The optimizer is going to be split into checkpoint shards. "
+            f"You can find where each parameter has been saved in the "
+ f"index located at {save_index_file}."
+ )
+
+ def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+ checkpoint = load_state_dict(checkpoint)
+ optimizer.load_state_dict(checkpoint)
def save_unsharded_optimizer(
self,
@@ -70,45 +143,59 @@ class GeneralCheckpointIO(CheckpointIO):
# TODO(FrankLeeeee): handle distributed tensors
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
-
- def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False,
- variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False):
- """
+ def save_sharded_model(
+ self,
+ model: nn.Module,
+ checkpoint_path: str,
+ gather_dtensor: bool = False,
+ prefix: Optional[str] = None,
+ max_shard_size: int = 1024,
+ use_safetensors: bool = False,
+ ):
+ """
        Implement this method so that it matches the Hugging Face sharded-checkpoint format:
        save the model to multiple files, i.e., a sharded checkpoint.
"""
if os.path.isfile(checkpoint_path):
logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
return
-
+
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
-
+
# shard checkpoint
state_dict = model.state_dict()
- weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
- weights_name = add_variant(weights_name, variant)
- shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
-
- # Save the model
- for shard_file, shard in shards.items():
- checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
- save_state_dict(shard, checkpoint_file_path, use_safetensors)
-
- # save index file
- save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
-
- save_index_file = os.path.join(checkpoint_path, add_variant(save_index_file, variant))
- with open(save_index_file, "w", encoding="utf-8") as f:
- content = json.dumps(index, indent=2, sort_keys=True) + "\n"
- f.write(content)
+ state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size)
+ weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
+ index_file = CheckpointIndexFile(checkpoint_path)
+
+        # Save shards of model weights.
+ # In general cases, is_master is set to True to get the right behavior.
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint_path,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=True,
+ use_safetensors=use_safetensors,
+ )
+
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ save_config_file(model, checkpoint_path, is_master=True)
logging.info(
- f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
- f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
+            f"The model is going to be split into checkpoint shards. "
+            f"You can find where each parameter has been saved in the "
f"index located at {save_index_file}."
)
-
- def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
+ def load_sharded_model(
+ self,
+ model: nn.Module,
+ checkpoint_index_file: Path,
+ strict: bool = False,
+ use_safetensors: bool = False,
+ load_sub_module: bool = True,
+ ):
"""
        Load a sharded model, i.e., load the model from multiple checkpoint files.
"""
@@ -118,21 +205,26 @@ class GeneralCheckpointIO(CheckpointIO):
if use_safetensors and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
-
+
# read checkpoint index file
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
- checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
- missing_keys = ckpt_index_file.get_all_param_names()
+ checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+ missing_keys = []
for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
- load_state_dict_into_model(model, state_dict, missing_keys, strict)
+ load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)
del state_dict
gc.collect()
- if strict and len(missing_keys) > 0:
- error_msgs = 'Missing key(s) in state_dict: {}. '.format(
- ', '.join('"{}"'.format(k) for k in missing_keys))
- raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
- self.__class__.__name__, "\n\t".join(error_msgs)))
-
+        if strict:
+            # A key counts as missing only if it was absent from every loaded shard.
+            remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
+            if len(remain_keys) > 0:
+                error_msgs = [
+                    "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in remain_keys))
+                ]
+                raise RuntimeError(
+                    "Error(s) in loading state_dict for {}:\n\t{}".format(
+                        self.__class__.__name__, "\n\t".join(error_msgs)
+                    )
+                )
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
new file mode 100644
index 0000000000000000000000000000000000000000..779ff42d75a1a905e2a08d456f73462fa1f2a1b6
--- /dev/null
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -0,0 +1,876 @@
+import copy
+import logging
+import os
+from pathlib import Path
+from shutil import rmtree
+from typing import Dict, Iterator, Optional, OrderedDict, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed import ProcessGroup
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+
+from colossalai.cluster import DistCoordinator
+from colossalai.interface import ModelWrapper, OptimizerWrapper
+
+from .general_checkpoint_io import GeneralCheckpointIO
+from .index_file import CheckpointIndexFile
+from .utils import (
+ StateDictSharder,
+ gather_distributed_param,
+ get_model_base_filenames,
+ get_optimizer_base_filenames,
+ is_safetensors_available,
+ load_shard_state_dict,
+ load_state_dict,
+ load_state_dict_into_model,
+ load_states_into_optimizer,
+ save_config_file,
+ save_param_groups,
+ save_state_dict,
+ save_state_dict_shards,
+ search_tp_partition_dim,
+ sharded_optimizer_loading_epilogue,
+)
+
+try:
+ from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
+except ImportError:
+ _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
+
+
+class HybridParallelCheckpointIO(GeneralCheckpointIO):
+ """
+ CheckpointIO for Hybrid Parallel Training.
+
+ Args:
+ dp_group (ProcessGroup): Process group along data parallel dimension.
+ pp_group (ProcessGroup): Process group along pipeline parallel dimension.
+ tp_group (ProcessGroup): Process group along tensor parallel dimension.
+ zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2].
+        verbose (bool, optional): Whether to print a logging message when saving/loading has been successfully executed. Defaults to True.
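+
+    A hypothetical construction sketch (all process groups must be initialized beforehand):
+        >>> ckpt_io = HybridParallelCheckpointIO(dp_group, pp_group, tp_group, zero_stage=1)
+        >>> ckpt_io.save_sharded_model(boosted_model, "ckpt_dir", size_per_shard=1024)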
+ """
+
+ def __init__(
+ self,
+ dp_group: ProcessGroup,
+ pp_group: ProcessGroup,
+ tp_group: ProcessGroup,
+ zero_stage: int,
+ verbose: bool = True,
+ ) -> None:
+ super().__init__()
+ self.dp_group = dp_group
+ self.pp_group = pp_group
+ self.tp_group = tp_group
+ self.dp_rank = dist.get_rank(self.dp_group)
+ self.tp_rank = dist.get_rank(self.tp_group)
+ self.pp_rank = dist.get_rank(self.pp_group)
+ self.dp_size = dist.get_world_size(dp_group)
+ self.pp_size = dist.get_world_size(pp_group)
+ self.tp_size = dist.get_world_size(tp_group)
+ self.use_zero = zero_stage > 0
+ self.verbose = verbose
+ self.coordinator = DistCoordinator()
+
+ @staticmethod
+ def _model_sharder(
+ model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024
+ ) -> Iterator[Tuple[OrderedDict, int]]:
+        # An internal method that breaks state_dict of model into shards within limited size.
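+        # The sharder below follows an accumulate-and-flush pattern: append_param returns
+        # (None, 0) while a block is still filling, or a completed block together with its
+        # size, which is then yielded; the final partially filled block is yielded at the end.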
+
+ state_dict_sharder = StateDictSharder(size_per_shard)
+
+ # Save parameters.
+ for name, param in model.named_parameters():
+ if param is None:
+ continue
+ # Gather tensor pieces when using tensor parallel.
+ param_ = gather_distributed_param(param, keep_vars=False)
+ block, block_size = state_dict_sharder.append_param(prefix + name, param_)
+ if block is not None:
+ yield block, block_size
+
+ # Save buffers.
+ for name, buf in model.named_buffers():
+ if buf is not None and name not in model._non_persistent_buffers_set:
+ buffer = buf if keep_vars else buf.detach()
+ block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
+ if block is not None:
+ yield block, block_size
+
+ # Save extra states.
+ extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+ if (
+ getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
+ is not torch.nn.Module.get_extra_state
+ ):
+ extra_state = model.get_extra_state()
+ block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
+ if block is not None:
+ yield block, block_size
+
+ # Return the last block in sharder.
+ yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
+
+ @staticmethod
+ def _optimizer_sharder(
+ optimizer: OptimizerWrapper,
+ use_zero: bool,
+ dp_group: ProcessGroup,
+ tp_group: ProcessGroup,
+ size_per_shard: int = 1024,
+ ):
+        # An internal method that breaks state_dict of optimizer into shards within limited size.
+
+ state_dict_sharder = StateDictSharder(size_per_shard)
+ param_info = optimizer.param_info
+ master_to_working_map = optimizer.get_master_to_working_map()
+
+ for param, state in optimizer.optim.state.items():
+ if param is None:
+ continue
+
+ if master_to_working_map is not None:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+
+ param_id = param_info["param2id"][id(working_param)]
+ original_shape = param_info["param2shape"][id(working_param)]
+ state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
+ state,
+ working_param,
+ original_shape=original_shape,
+ dp_group=dp_group,
+ tp_group=tp_group,
+ use_zero=use_zero,
+ inplace=False,
+ )
+
+ block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
+ if block is not None:
+ yield block, block_size
+
+ # Return the last block in sharder.
+ yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
+
+ def save_sharded_model(
+ self,
+ model: ModelWrapper,
+ checkpoint: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ use_safetensors: bool = False,
+ ) -> None:
+ """
+ Save sharded model checkpoint under the given checkpointing path.
+ The following files will be created under the path:
+ - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
+        - Multiple files that store state tensors of the model.
+        If pipeline parallelism is used, the filenames are in the form of "pytorch_model-stage-000XX-shard-000XX.bin".
+        If pipeline parallelism is not used, they are in the form of "pytorch_model-000XX.bin".
+
+
+ Args:
+ model (nn.Module): Model on local device to be saved.
+ checkpoint (str): Checkpointing path which should be a directory path.
+ gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
+            prefix (str, optional): Prefix of file to save. Defaults to None.
+ size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
+ use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
+ """
+
+ assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+ model = model.unwrap()
+
+ if os.path.isfile(checkpoint):
+ logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+ return
+
+ Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+        # Devices along the same dp_group share the same copy of the model.
+ # So only let the device with dp_rank == 0 save the model.
+ if self.dp_rank != 0:
+ return
+
+ # Then collect the sharded parameters & buffers along tp_group.
+ # Only devices with tp_rank == 0 are responsible for model saving.
+ state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
+ weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
+ index_file = CheckpointIndexFile(checkpoint)
+ control_saving = self.tp_rank == 0
+
+ if self.pp_size == 1:
+ # When pipeline is not used, save the model shards as in general checkpointIO
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=control_saving,
+ use_safetensors=use_safetensors,
+ )
+ if control_saving:
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ save_config_file(model, checkpoint)
+ if self.verbose and self.coordinator.is_master():
+ logging.info(
+ f"The model is split into checkpoint shards. "
+                    f"You can find where each parameter has been saved in the "
+ f"index located at {save_index_file}."
+ )
+
+ else:
+ # When pipeline is used, each stage produces its own shard files and index files.
+ # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
+ # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
+
+ final_index_file_path = copy.deepcopy(save_index_file)
+ tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
+ Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
+
+ # Manage filenames of sharded weights and index file for each pipeline stage.
+ weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
+ weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
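+            # e.g. for pp_rank 0, "pytorch_model.bin" becomes "pytorch_model-stage-00001-shard.bin";
+            # the per-shard index is appended later when each shard file is actually written.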
+ save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
+ save_index_file = os.path.join("tmp_index_files", save_index_file)
+
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=control_saving,
+ use_safetensors=use_safetensors,
+ use_pp_format=True,
+ )
+ if control_saving:
+ assert (
+ self.dp_rank == 0 and self.tp_rank == 0
+ ), "The saving process should have both dp_rank and tp_rank as 0."
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ else:
+ return
+
+ dist.barrier(self.pp_group)
+
+        # The global master rank integrates the index files and cleans the folder.
+ if self.pp_rank == 0:
+ final_index_file = CheckpointIndexFile(checkpoint)
+ final_index_file.append_meta_data("total_size", 0)
+
+ for filename in os.listdir(tmp_index_file_folder):
+ stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
+ final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
+ for weight, weight_filename in stage_index_file.weight_map.items():
+ final_index_file.append_weight_map(weight, weight_filename)
+
+ final_index_file.write_index_file(final_index_file_path)
+ save_config_file(model, checkpoint)
+ rmtree(tmp_index_file_folder)
+ if self.verbose and self.coordinator.is_master():
+ logging.info(
+ f"The model is split into checkpoint shards. "
+                    f"You can find where each parameter has been saved in the "
+ f"index located at {final_index_file_path}."
+ )
+
+ def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
+ """
+ Load sharded model with the given path to index file of checkpoint folder.
+
+ Args:
+ model (nn.Module): The model to be loaded.
+ checkpoint_index_file (str): Path to the index file of checkpointing folder.
+ strict (bool, optional): For name matching during loading state_dict. Defaults to False.
+            This argument should be manually set to False, since params on the same device might be stored in different files.
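+
+        A hypothetical call, assuming the checkpoint was produced by ``save_sharded_model``:
+            >>> ckpt_io.load_sharded_model(boosted_model, Path("ckpt_dir/pytorch_model.bin.index.json"))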
+ """
+ assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+ model_before_wrapping = model # backup for model before wrapping
+ model = model.unwrap()
+
+ # Check whether the checkpoint uses safetensors.
+ use_safetensors = False
+ if "safetensors" in checkpoint_index_file.name:
+ use_safetensors = True
+
+ if use_safetensors and not is_safetensors_available():
+ raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
+
+ # Read checkpoint index file.
+ ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+ ckpt_root_path = ckpt_index_file.root_path
+ weight_map = ckpt_index_file.weight_map
+ strict = False
+
+ # Load params & buffers to model.
+        # Keep a record of loaded files so that no file is repeatedly loaded.
+ loaded_file = set()
+
+ def _load(name: str):
+ if name not in weight_map:
+ raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
+ filename = weight_map[name]
+
+ # If this param/buffer has been loaded before, directly return.
+ if filename in loaded_file:
+ return
+
+ file_path = os.path.join(ckpt_root_path, filename)
+ state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
+ missing_keys = []
+
+ load_state_dict_into_model(
+ model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
+ )
+ loaded_file.add(filename)
+
+ # Load parameters.
+ for name, _ in model.named_parameters():
+ _load(name)
+
+ # Load buffers.
+ non_persistent_buffers = set()
+ for n, m in model.named_modules():
+ non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
+ for name, buf in model.named_buffers():
+ if buf is not None and name not in non_persistent_buffers:
+ _load(name)
+
+ # Load extra states.
+ extra_state_key = _EXTRA_STATE_KEY_SUFFIX
+ if (
+ getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
+ is not torch.nn.Module.get_extra_state
+ ):
+ _load(extra_state_key)
+
+ # Update master params if mixed-precision training is enabled.
+ model_before_wrapping.update_master_params()
+
+ if self.verbose and self.coordinator.is_master():
+ logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
+
+ def save_sharded_optimizer(
+ self,
+ optimizer: OptimizerWrapper,
+ checkpoint: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ ):
+ """
+ Save sharded optimizer checkpoint under the given checkpointing path.
+ The following files will be created under the path:
+ - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
+ - A group file (pytorch_optim_group.bin) recording information of param_groups
+            - Multiple files that store state tensors of the optimizer.
+            If pipeline parallelism is used, the filenames are in the form of "pytorch_optim-stage-000XX-shard-000XX.bin".
+            If pipeline parallelism is not used, they are in the form of "pytorch_optim-000XX.bin".
+
+ Args:
+ optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
+ checkpoint (str): Path to save optimizer state_dict
+ gather_dtensor (bool): Whether to gather_dtensor, not used
+            prefix (str): Prefix of file to save
+ size_per_shard (int): Max file size of each file shard that store state tensors
+ """
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
+ if os.path.isfile(checkpoint):
+ logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+ return
+
+ Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+        # Devices along the same dp_group share the same copy of optimizer states when zero is not used.
+ # In this case only let the device with dp_rank == 0 save the model.
+ if not self.use_zero and self.dp_rank != 0:
+ return
+
+        # Then collect the sharded states along dp_group (if using zero) or tp_group.
+ # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
+ state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
+ optimizer,
+ use_zero=self.use_zero,
+ dp_group=self.dp_group,
+ tp_group=self.tp_group,
+ size_per_shard=size_per_shard,
+ )
+ states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+ index_file = CheckpointIndexFile(checkpoint)
+ control_saving = self.dp_rank == 0 and self.tp_rank == 0
+
+ if self.pp_size == 1:
+ # When pipeline is not used, save the optimizer shards as in general checkpointIO
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=control_saving,
+ )
+
+ if control_saving:
+ # Store param groups.
+ index_file.append_meta_data("param_groups", param_group_file)
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ save_param_groups(optimizer.param_info, group_file_path)
+ # Store index file.
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ if self.verbose and self.coordinator.is_master():
+ logging.info(
+                    f"The optimizer is going to be split into checkpoint shards. "
+                    f"You can find where each parameter has been saved in the "
+ f"index located at {save_index_file}."
+ )
+
+ else:
+ # When pipeline is used, each stage produces its own shard files and index files.
+ # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
+ # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
+
+ final_index_file_path = copy.deepcopy(save_index_file)
+ tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
+ Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
+
+ # Manage filenames of sharded weights and index file for each pipeline stage.
+ states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
+ save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
+ save_index_file = os.path.join("tmp_index_files", save_index_file)
+
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=control_saving,
+ use_pp_format=True,
+ )
+
+ if control_saving:
+ assert (
+ self.dp_rank == 0 and self.tp_rank == 0
+ ), "The saving process should have both dp_rank and tp_rank as 0."
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ else:
+ return
+
+ dist.barrier(self.pp_group)
+
+        # The global master rank integrates the index files and cleans the folder.
+ if self.pp_rank == 0:
+ final_index_file = CheckpointIndexFile(checkpoint)
+ final_index_file.append_meta_data("total_size", 0)
+
+ for filename in os.listdir(tmp_index_file_folder):
+ stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
+ final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
+ for param_id, state_filename in stage_index_file.weight_map.items():
+ final_index_file.append_weight_map(param_id, state_filename)
+
+ # Store param groups.
+ final_index_file.append_meta_data("param_groups", param_group_file)
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ save_param_groups(optimizer.param_info, group_file_path)
+
+ final_index_file.write_index_file(final_index_file_path)
+ rmtree(tmp_index_file_folder)
+
+ if self.verbose and self.coordinator.is_master():
+ logging.info(
+                    f"The optimizer is split into checkpoint shards. "
+                    f"You can find where each parameter has been saved in the "
+ f"index located at {final_index_file_path}."
+ )
+
+ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
+ """
+ Load sharded optimizer with the given path to index file of checkpoint folder.
+
+ Args:
+ optimizer (OptimizerWrapper): The optimizer to be loaded.
+ checkpoint_index_file (str): Path to the index file of checkpointing folder.
+ prefix (str): Not used.
+ """
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+
+ def _get_param_id_from_optimizer_param(
+ param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
+ ):
+ if master_to_working_map is not None:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+ return optimizer.param_info["param2id"][id(working_param)]
+
+ # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
+ # When Zero is used, the mapped parameter objects should be fp32 master parameters.
+        # IDs are obtained through the param2id mapping saved earlier in optimizer.param_info.
+ id_map = {}
+ master_to_working_map = optimizer.get_master_to_working_map()
+ for pg in optimizer.optim.param_groups:
+ for param in pg["params"]:
+ param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+ id_map[param_id] = param
+
+ # Read checkpoint index file.
+ ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+ ckpt_root_path = ckpt_index_file.root_path
+ weight_map = ckpt_index_file.weight_map
+ weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
+
+ # Load param_groups
+ param_group_path = ckpt_index_file.get_param_group_filename()
+ if param_group_path is None:
+ raise RuntimeError(
+ f"Invalid index file path {checkpoint_index_file} for an optimizer. \
+ Lacking param group file under current directory."
+ )
+ saved_groups = torch.load(param_group_path)
+
+ updated_groups = []
+ for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
+ # obtain updated param group
+ new_pg = copy.deepcopy(saved_pg)
+            new_pg["params"] = old_pg["params"]  # The parameters in the same group shouldn't change.
+ updated_groups.append(new_pg)
+ optimizer.optim.__dict__.update({"param_groups": updated_groups})
+
+ # Load saved states to optimizer.
+        # Keep a record of loaded files so that no file is repeatedly loaded.
+ loaded_file = set()
+ for pg in optimizer.optim.param_groups:
+ for param in pg["params"]:
+ if param is None:
+ continue
+ param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+ if param_id not in weight_map:
+ continue
+ filename = weight_map[param_id]
+
+                # If this param's states have been loaded before, skip this file.
+ if filename in loaded_file:
+ continue
+
+ file_path = os.path.join(ckpt_root_path, filename)
+ state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
+ load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
+ loaded_file.add(filename)
+
+ # Then shard the loaded optimizer states if using tp/zero.
+ for param, state in optimizer.optim.state.items():
+ device = param.device
+ if master_to_working_map is not None:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+ original_shape = optimizer.param_info["param2shape"][id(working_param)]
+ sharded_state = self.shard_from_complete_optimizer_state(
+ state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
+ )
+ optimizer.optim.state[param] = sharded_state
+
+ sharded_optimizer_loading_epilogue(optimizer.optim)
+ if self.verbose and self.coordinator.is_master():
+ logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
+
+ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+ """
+ Save model state dict to a single file with given checkpointing path.
+
+ Args:
+ model (nn.Module): Model on local device to be saved.
+ checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path.
+ gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True.
+ use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
+ """
+ if self.coordinator.is_master():
+ logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
+
+ assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+ model = model.unwrap()
+
+ if self.dp_rank != 0:
+ return
+
+ # The logic of collecting parameter shards along tp degree
+ # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
+ state_dict = model.state_dict()
+
+ if self.pp_size == 1:
+ # When pipeline is not used, let master rank directly save the collected state_dict.
+ if self.tp_rank == 0:
+ save_state_dict(state_dict, checkpoint, use_safetensors)
+ else:
+ # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
+ state_dict_list = [None for _ in range(self.pp_size)]
+ dist.barrier(self.pp_group)
+ dist.all_gather_object(state_dict_list, state_dict, self.pp_group)
+
+            # Only the master rank does the saving.
+ if self.coordinator.is_master():
+ complete_state_dict = dict()
+ for _state_dict in state_dict_list:
+ complete_state_dict.update(_state_dict)
+ save_state_dict(complete_state_dict, checkpoint, use_safetensors)
+
+ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False):
+ """
+ Load model from a single file with the given path of checkpoint.
+
+ Args:
+ model (nn.Module): The model to be loaded.
+            checkpoint (str): Path to the checkpoint file.
+ strict (bool, optional): For name matching during loading state_dict. Defaults to False.
+ This argument should be manually set to False since not all params in checkpoint are needed for each device when pipeline is enabled.
+ """
+ if self.coordinator.is_master():
+ logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
+
+ assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+ strict = False
+ model_before_wrapping = model
+ model = model.unwrap()
+
+ # Load from checkpoint. Since the logic of breaking parameter shards along tp degree
+ # has been implemented by _load_from_state_dict method of ParallelModule in Shardformer,
+ # model.load_state_dict can be directly called.
+ state_dict = load_state_dict(checkpoint)
+ model.load_state_dict(state_dict, strict=strict)
+
+ # Update master params if mixed-precision training is enabled.
+ model_before_wrapping.update_master_params()
+
+ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+ """
+ Save optimizer state dict to a file with given path.
+
+ Args:
+ optimizer (OptimizerWrapper): Optimizer to save sharded state_dict.
+ checkpoint (str): Path to save optimizer state_dict.
+ gather_dtensor (bool): Whether to gather_dtensor, not used.
+ """
+ if self.coordinator.is_master():
+ logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
+
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
+
+        # Optimizer states of the parameters kept by the local device (i.e. by its pipeline stage).
+ local_states = dict()
+
+ for param, state in optimizer.optim.state.items():
+ if param is None:
+ continue
+
+ # working param is needed for obtaining correct param_id
+ master_to_working_map = optimizer.get_master_to_working_map()
+ if master_to_working_map is not None:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+
+ # gather complete state from tp shards & dp shards
+ param_id = optimizer.param_info["param2id"][id(working_param)]
+ original_shape = optimizer.param_info["param2shape"][id(working_param)]
+ local_states[param_id] = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
+ state,
+ working_param,
+ original_shape=original_shape,
+ dp_group=self.dp_group,
+ tp_group=self.tp_group,
+ use_zero=self.use_zero,
+ inplace=False,
+ device=torch.device("cuda"),
+ )
+
+ if self.pp_size == 1:
+ # When pipeline is not used, let master rank directly save the collected state_dict.
+ state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states}
+ if self.coordinator.is_master():
+ save_state_dict(state_dict, checkpoint, use_safetensors=False)
+ else:
+ # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
+ states_list = [None for _ in range(self.pp_size)]
+ dist.barrier(self.pp_group)
+ dist.all_gather_object(states_list, local_states, self.pp_group)
+
+            # Only the master rank does the saving.
+ if self.coordinator.is_master():
+ state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()}
+ for _states in states_list:
+ state_dict["state"].update(_states)
+ save_state_dict(state_dict, checkpoint, use_safetensors=False)
+
+ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+ """
+ Load optimizer from a file with given path.
+
+ Args:
+ optimizer (OptimizerWrapper): The optimizer to be loaded.
+            checkpoint (str): Path to the checkpoint file.
+ """
+
+ def _get_param_id_from_optimizer_param(
+ param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
+ ):
+ if master_to_working_map is not None:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+ return optimizer.param_info["param2id"][id(working_param)]
+
+ if self.coordinator.is_master():
+ logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
+
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+
+ # Complete optimizer state_dict loaded from checkpoint, need to be processed later.
+ state_dict = load_state_dict(checkpoint)
+
+ # Load param_groups.
+ updated_groups = []
+ saved_groups = state_dict["param_groups"]
+ for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
+ new_pg = copy.deepcopy(saved_pg)
+ new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage.
+ updated_groups.append(new_pg)
+ optimizer.optim.__dict__.update({"param_groups": updated_groups})
+
+ # Load saved states to optimizer. First discard those states not belonging to current pipeline stage.
+ master_to_working_map = optimizer.get_master_to_working_map()
+ id_map = {}
+ for pg in optimizer.optim.param_groups:
+ for param in pg["params"]:
+ param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+ id_map[param_id] = param
+ load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
+
+ # Then shard the loaded optimizer states if using tp/zero.
+ for param, state in optimizer.optim.state.items():
+ if param is None:
+ continue
+ device = param.device
+ if master_to_working_map is not None:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+ original_shape = optimizer.param_info["param2shape"][id(working_param)]
+ sharded_state = self.shard_from_complete_optimizer_state(
+ state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
+ )
+ optimizer.optim.state[param] = sharded_state
+
+ sharded_optimizer_loading_epilogue(optimizer.optim)
+
+ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+ """
+ Save lr scheduler to checkpoint but only on master process.
+ """
+ if self.coordinator.is_master():
+ super().save_lr_scheduler(lr_scheduler, checkpoint)
+
+ @staticmethod
+ def gather_from_sharded_optimizer_state(
+ state: OrderedDict,
+ param: torch.Tensor,
+ original_shape: torch.Size,
+ dp_group: ProcessGroup,
+ tp_group: ProcessGroup,
+ use_zero: bool,
+ inplace: bool,
+ device: torch.device = torch.device("cpu"),
+ ) -> OrderedDict:
+ """
+        With the given parameter and its optimizer states, gather the complete optimizer state for saving.
+
+ Args:
+ state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
+ param (torch.Tensor): The given parameter. It should be working_param when using Zero.
+ original_shape (torch.Size): The size of parameter before sharding.
+ dp_group (ProcessGroup): The process group of data parallel.
+ tp_group (ProcessGroup): The process group of tensor parallel.
+ use_zero (bool): Whether Zero is used.
+ inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
+ device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu').
+
+ Returns:
+ OrderedDict: The complete optimizer state of given parameter.
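+
+        The overall flow: when zero is used, dp shards are first all-gathered and reshaped to
+        the working param's shape; then, if a tp partition dim is found, tp shards are
+        concatenated along that dim to recover a tensor of ``original_shape``.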
+ """
+ dp_size = dist.get_world_size(dp_group)
+ tp_size = dist.get_world_size(tp_group)
+ current_shape = param.shape
+ state_ = state if inplace else copy.deepcopy(state)
+
+ for k, v in state_.items():
+ if isinstance(v, torch.Tensor) and k != "step":
+ # First gather Zero shards.
+ if use_zero:
+ v = v.cuda()
+ gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
+ dist.all_gather(gather_tensor, v, group=dp_group)
+ v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
+
+ # Then gather TP shards.
+ partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
+ if partition_dim is not None:
+ gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
+ dist.all_gather(gather_tensor, v, group=tp_group)
+ v = torch.cat(gather_tensor, dim=partition_dim)
+
+ state_[k] = v.detach().clone().to(device)
+
+ return state_
+
+ def shard_from_complete_optimizer_state(
+ self,
+ state: OrderedDict,
+ current_shape: torch.Size,
+ original_shape: torch.Size,
+ device: torch.device,
+ inplace: bool,
+ ) -> OrderedDict:
+ """
+ With complete optimizer states of a specific parameter loaded from checkpoint,
+        slice out the sharded optimizer states kept by the current device.
+
+ Args:
+ state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
+ current_shape (torch.Size): The size of parameter after sharding.
+ original_shape (torch.Size): The size of parameter before sharding.
+ device (torch.device): The destination device of loaded optimizer states.
+ inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
+
+ Returns:
+ OrderedDict: The sharded optimizer state of the given parameter.
+ """
+ state_ = state if inplace else copy.deepcopy(state)
+
+ for k, v in state_.items():
+ if isinstance(v, torch.Tensor) and k != "step":
+ # Shard state along tensor parallel group.
+ partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
+ if partition_dim is not None:
+ slice_size = current_shape[partition_dim]
+ v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
+
+ # Shard state along data parallel group when using Zero.
+ if self.use_zero:
+ padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
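+                # e.g. with v.numel() == 10 and dp_size == 4, padding_size == 2, so the
+                # padded 12-element tensor splits evenly into 3 elements per rank.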
+ with torch.no_grad():
+ v = v.flatten()
+ if padding_size > 0:
+ v = torch.nn.functional.pad(v, [0, padding_size])
+ slice_size = v.numel() // self.dp_size
+ v = v.split(slice_size, dim=0)[self.dp_rank]
+
+ state_[k] = v.detach().clone().to(device)
+
+ return state_
diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py
index 89224787a91b9824f2261aece909f6a1cb094a17..da12c146f2c309315e3e00aa85f2b72fcee2e723 100644
--- a/colossalai/checkpoint_io/index_file.py
+++ b/colossalai/checkpoint_io/index_file.py
@@ -1,10 +1,12 @@
import json
+import os
+from collections import OrderedDict
from pathlib import Path
-from typing import Any, List, Union
+from typing import Any, Dict, List, Union
from .utils import is_dtensor_checkpoint
-__all__ = ['CheckpointIndexFile']
+__all__ = ["CheckpointIndexFile"]
class CheckpointIndexFile:
@@ -18,10 +20,12 @@ class CheckpointIndexFile:
>>> index.export('new_index.json')
"""
- def __init__(self) -> None:
- self.root_path = None
- self.metadata: dict = dict()
- self.weight_map: dict = dict()
+ def __init__(self, root_path=None) -> None:
+ self.root_path = root_path
+
+ # use ordered dict to preserve the tensor checkpoint order
+ self.metadata: Dict = OrderedDict()
+ self.weight_map: Dict = OrderedDict()
@staticmethod
def from_file(index_path: Union[str, Path]):
@@ -46,7 +50,7 @@ class CheckpointIndexFile:
json_path (str): path to the json file.
"""
# load the json file
- with open(json_path, 'r') as f:
+ with open(json_path, "r") as f:
index = json.load(f)
# assign attributes if exists
@@ -71,7 +75,7 @@ class CheckpointIndexFile:
index["weight_map"] = self.weight_map
# export the index file
- with open(json_path, 'w') as f:
+ with open(json_path, "w") as f:
json.dump(index, f, indent=4)
def append_weight_map(self, param_name: str, shard_file: str):
@@ -107,7 +111,7 @@ class CheckpointIndexFile:
return True
return False
- def get_checkpoint_fileanames(self) -> List[str]:
+ def get_checkpoint_filenames(self) -> List[str]:
"""
        Get the list of checkpoint filenames in the weight map.
@@ -148,9 +152,31 @@ class CheckpointIndexFile:
"""
ckpt_path = self.weight_map[param_name]
return ckpt_path
-
+
def get_all_param_names(self):
"""
Get all the weight keys.
"""
return list(self.weight_map.keys())
+
+ def get_param_group_filename(self) -> Union[str, None]:
+ """
+ Get the file name of param_group file if this is a checkpoint for optimizer.
+ Returns:
+ str: param_group file name
+ """
+ filename = self.metadata.get("param_groups", None)
+ if filename:
+ return str(self.root_path.joinpath(filename))
+ else:
+ return None
+
+ def write_index_file(self, save_index_file):
+ """
+ Write index file.
+ """
+ save_index_file = os.path.join(self.root_path, save_index_file)
+ index = {"metadata": self.metadata, "weight_map": self.weight_map}
+ with open(save_index_file, "w", encoding="utf-8") as f:
+ content = json.dumps(index, indent=2) + "\n"
+ f.write(content)
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 37d22d08df40eaaa209ad32b097ce735a14378dc..06dab1fdb72a4d995ebd7d0900bb52afda0d2898 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -1,33 +1,51 @@
# coding=utf-8
+import os
+import re
+from collections import abc as container_abcs
+from collections import defaultdict
+from itertools import chain
from pathlib import Path
+from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
+
import torch
import torch.nn as nn
-from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple
-from colossalai.tensor.d_tensor.d_tensor import DTensor
-import re
+from packaging.version import Version
+from torch.optim import Optimizer
+
+from colossalai.tensor.d_tensor import (
+ is_customized_distributed_tensor,
+ is_distributed_tensor,
+ to_global,
+ to_global_for_customized_distributed_tensor,
+)
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
+STATES_NAME = "pytorch_optim.bin"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
+STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
+GROUP_FILE_NAME = "pytorch_optim_group.bin"
# ======================================
# General helper functions
# ======================================
+
def calculate_tensor_size(tensor: torch.Tensor) -> float:
"""
Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
If so, a new shard should be created.
Args:
- tenosr (torch.Tensor): the tensor to calculate size for.
+ tensor (torch.Tensor): the tensor to calculate size for.
Returns:
float: size of the tensor in MB.
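+
+    For example, a float32 tensor with 1,048,576 elements occupies exactly 4 MB.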
"""
return tensor.numel() * tensor.element_size() / 1024 / 1024
+
def is_safetensors_available() -> bool:
"""
Check whether safetensors is available.
@@ -36,7 +54,6 @@ def is_safetensors_available() -> bool:
bool: whether safetensors is available.
"""
try:
-        import safetensors
+        import safetensors  # noqa: F401, keep the import so the ImportError check still works
return True
except ImportError:
return False
@@ -52,7 +69,7 @@ def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool:
Returns:
bool: whether the checkpoint file is a dtensor checkpoint.
"""
- if checkpoint_file_path.endswith('.*.safetensors') or checkpoint_file_path.endswith('.*.bin'):
+ if checkpoint_file_path.endswith(".*.safetensors") or checkpoint_file_path.endswith(".*.bin"):
return True
else:
return False
@@ -68,136 +85,210 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
Returns:
bool: whether the checkpoint file is a safetensor checkpoint.
"""
- if checkpoint_file_path.endswith('.safetensors'):
+ if checkpoint_file_path.endswith(".safetensors"):
return True
else:
return False
+def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Size, tp_size: int) -> Optional[int]:
+ """
+ Given the current shape of parameter and the shape of parameter before sharding,
+ return the dimension along which the parameter is sharded when using tensor parallel.
+ If tensor parallel is not used, return None.
+
+ Args:
+ current_shape (torch.Size): The current shape of parameter after sharding.
+ original_shape (torch.Size): The shape of parameter before sharding.
+ tp_size (int): The size of tp group.
+
+ Returns:
+ Optional[int]: The dimension along which parameter is partitioned.
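+
+    For example, with original_shape (1024, 1024), current_shape (1024, 256) and tp_size 4,
+    the returned partition_dim is 1, since dim 1 was split as 1024 == 4 * 256.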
+ """
+ partition_dim = None
+ for dim, length in enumerate(original_shape):
+ if length > current_shape[dim]:
+ partition_dim = dim
+ break
+ if partition_dim is not None:
+ assert (
+ original_shape[partition_dim] == tp_size * current_shape[partition_dim]
+ ), f"The parameter isn't evenly distributed among tensor parallel group: \
+ shape before sharding {original_shape}, shape after sharding {current_shape}"
+
+ return partition_dim
+
+
# ======================================
-# Helper functions for saving shard file
+# Helper classes and functions for saving shard file
# ======================================
-def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME):
-
+
+
+class StateDictSharder:
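+    # Accumulates tensors into blocks capped at roughly `size_per_shard` MB.
+    # A sketch of the intended usage (illustrative only; `write` is a placeholder):
+    #     sharder = StateDictSharder(1024)
+    #     for name, tensor in state_dict.items():
+    #         block, block_size = sharder.append_param(name, tensor)
+    #         if block is not None:  # a full block is ready to be written out
+    #             write(block)
+    #     write(sharder.current_block)  # the final, partially filled block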
+ def __init__(self, size_per_shard: int) -> None:
+ self.max_shard_size = size_per_shard
+ self.current_block = OrderedDict()
+ self.current_block_size = 0
+
+ def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
+ tensor_size = calculate_tensor_size(tensor)
+ ret_block = None
+ ret_block_size = 0
+
+ # before we return the current block and create a new block,
+ # we need to ensure that the current block is not empty
+ if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
+ ret_block = self.current_block
+ ret_block_size = self.current_block_size
+ self.current_block = OrderedDict()
+ self.current_block_size = 0
+
+ self.current_block[name] = tensor
+ self.current_block_size += tensor_size
+ return ret_block, ret_block_size
+
+ def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]:
+        # A state might contain more than one tensor.
+ # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
+ state_size = 0
+ isDTensor = False
+ for state_tensor in state.values():
+ # When state_tensor is not of Tensor class,
+            # e.g., an SGD optimizer with momentum set to 0 can have None as its state.
+ # The calculation of tensor size should be skipped to avoid error.
+ if not isinstance(state_tensor, torch.Tensor):
+ continue
+
+ # If the states are stored as DTensors, mark isDTensor as true.
+ if is_distributed_tensor(state_tensor):
+ isDTensor = True
+ state_size += calculate_tensor_size(state_tensor)
+
+ ret_block = None
+ ret_block_size = 0
+
+ # directly return if state is stored as distributed tensor
+ if isDTensor:
+ return ret_block, ret_block_size
+
+ # before we return the current block and create a new block,
+ # we need to ensure that the current block is not empty
+ if self.current_block_size + state_size > self.max_shard_size and self.current_block_size > 0:
+ ret_block = self.current_block
+ ret_block_size = self.current_block_size
+ self.current_block = OrderedDict()
+ self.current_block_size = 0
+
+ self.current_block[param_id] = state
+ self.current_block_size += state_size
+ return ret_block, ret_block_size
+
+
+def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> torch.Tensor:
"""
- Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
- given size.
+    Gather the complete parameter for saving if the passed-in param is distributed under a tp setting.
+
+ Args:
+        param (torch.Tensor): A model parameter, which might be a distributed tensor.
+ keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False.
+
+ Returns:
+ torch.Tensor: the complete parameter
"""
- sharded_state_dicts = []
- current_block = {}
- current_block_size = 0
- total_size = 0
+ param_ = param if keep_vars else param.detach()
+ if is_distributed_tensor(param_):
+ return to_global(param_)
+ elif is_customized_distributed_tensor(param_):
+ return to_global_for_customized_distributed_tensor(param_)
+ else:
+ return param_
+
+
+def save_state_dict_shards(
+ sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
+ checkpoint: str,
+ index_file: "CheckpointIndexFile",
+ base_filename: str,
+ is_master: bool,
+ use_safetensors: bool = False,
+ use_pp_format: bool = False,
+) -> int:
+ """
+    Save a sharded state dict on the master rank only; this method can be used for both model and optimizer states.
+ Args:
+ sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
+ checkpoint (str): The path of checkpoint directory as string.
+ index_file (CheckpointIndexFile): The index file object to be updated.
+ base_filename (str): Decides the prefix of filenames of shards.
+ is_master (bool): Whether current rank is main process.
+ use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False.
+        use_pp_format (bool, optional): Whether to save the files in pipeline format, including stage information. Defaults to False.
- for key, weight in state_dict.items():
- if type(weight) != DTensor:
- weight_size = calculate_tensor_size(weight)
-
- # If this weight is going to tip up over the maximal size, we split.
- if current_block_size + weight_size > max_shard_size:
- sharded_state_dicts.append(current_block)
- current_block = {}
- current_block_size = 0
-
- current_block[key] = weight
- current_block_size += weight_size
- total_size += weight_size
-
- # Add the last block
- sharded_state_dicts.append(current_block)
-
- # If we only have one shard, we return it
- if len(sharded_state_dicts) == 1:
- return {weights_name: sharded_state_dicts[0]}, None
-
- # Otherwise, let's build the index
- weight_map = {}
- shards = {}
-
- for idx, shard in enumerate(sharded_state_dicts):
- shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
- shard_file = shard_file.replace(
- ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
- )
- shards[shard_file] = shard
+ Returns:
+ int: the total size of shards
+ """
+
+ total_size = 0
+ shard_filenames = []
+ for idx, shard_pair in enumerate(sharded_state_dict):
+ shard, current_size = shard_pair
+ if not is_master:
+ del shard
+ continue
+ shard_file = get_shard_filename(base_filename, idx)
+ total_size = total_size + current_size
for key in shard.keys():
- weight_map[key] = shard_file
+ index_file.append_weight_map(key, shard_file)
+ checkpoint_file_path = os.path.join(checkpoint, shard_file)
- # Add the metadata
- metadata = {"total_size": total_size}
- index = {"metadata": metadata, "weight_map": weight_map}
- return shards, index
+ # Only save on master rank.
+ save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
+ shard_filenames.append(shard_file)
+ del shard
-def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
+    # Clean the folder, deleting unneeded files.
+ clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
+
+ return total_size
+
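
# End-to-end save sketch combining shard_model_checkpoint (defined just below)
# with save_state_dict_shards. The module paths and the CheckpointIndexFile
# constructor signature are assumptions inferred from how it is used above.
import os
import torch
from colossalai.checkpoint_io.index_file import CheckpointIndexFile  # assumed path
from colossalai.checkpoint_io.utils import save_state_dict_shards, shard_model_checkpoint  # assumed path

ckpt_dir = "./ckpt"  # illustrative directory
os.makedirs(ckpt_dir, exist_ok=True)
index_file = CheckpointIndexFile(ckpt_dir)  # constructor signature assumed

model = torch.nn.Linear(8, 8)
shards = shard_model_checkpoint(model.state_dict(), max_shard_size=1)  # ~1 MB shards
total_size = save_state_dict_shards(
    shards, ckpt_dir, index_file, base_filename="pytorch_model.bin", is_master=True
)
print(f"wrote {total_size} MB; the weight map is held by the index file")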
+
+def shard_model_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
- load shard state dict into model
+    Splits a model state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does not
+    exceed a given size.
"""
- if use_safetensors and not checkpoint_file.suffix == ".safetensors":
- raise Exception("load the model using `safetensors`, but no file endwith .safetensors")
- if use_safetensors:
- from safetensors.torch import safe_open
- from safetensors.torch import load_file as safe_load_file
- with safe_open(checkpoint_file, framework="pt") as f:
- metadata = f.metadata()
- if metadata["format"] != "pt":
- raise NotImplementedError(
- f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
- )
- return safe_load_file(checkpoint_file)
- else:
- return torch.load(checkpoint_file)
-
-def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False):
- r"""Copies parameters and buffers from :attr:`state_dict` into
- this module and its descendants.
+ state_dict_sharder = StateDictSharder(max_shard_size)
- Args:
- state_dict (dict): a dict containing parameters and
- persistent buffers.
- """
- if not isinstance(state_dict, Mapping):
- raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
+ for key, weight in state_dict.items():
+ if not is_distributed_tensor(weight):
+ block, block_size = state_dict_sharder.append_param(key, weight)
- unexpected_keys: List[str] = []
- sub_missing_keys: List[str] = []
- error_msgs: List[str] = []
+            if block is not None:
+ yield block, block_size
- # copy state_dict so _load_from_state_dict can modify it
- metadata = getattr(state_dict, '_metadata', None)
- state_dict = OrderedDict(state_dict)
- if metadata is not None:
- state_dict._metadata = metadata
+ # Return the last block in sharder.
+ yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
- def load(module: nn.Module, state_dict, prefix=""):
- local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
- args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
- # Parameters of module and children will start with prefix. We can exit early if there are none in this
- # state_dict
- if len([key for key in state_dict if key.startswith(prefix)]) > 0:
- module._load_from_state_dict(*args)
- for name, child in module._modules.items():
- if child is not None:
- load(child, state_dict, prefix + name + ".")
+def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
+ """
+    Splits an optimizer state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does
+    not exceed a given size.
+ """
- load(model, state_dict, "")
- del load
+    # Only split state_dict['state']; state_dict['param_groups'] is not considered in this function.
+ states = state_dict["state"]
+ state_dict_sharder = StateDictSharder(max_shard_size)
+
+ for param_id, state in states.items():
+ block, block_size = state_dict_sharder.append_optim_state(param_id, state)
+        if block is not None:
+ yield block, block_size
+
+ # Return the last block in sharder.
+ yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
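
# A sketch of what shard_optimizer_checkpoint consumes: only state_dict['state']
# is split into blocks; 'param_groups' goes through save_param_groups instead.
import torch
from colossalai.checkpoint_io.utils import shard_optimizer_checkpoint  # assumed location

model = torch.nn.Linear(4, 4)
optim = torch.optim.Adam(model.parameters())
model(torch.randn(2, 4)).sum().backward()
optim.step()

osd = optim.state_dict()  # {'state': {0: {...}, 1: {...}}, 'param_groups': [...]}
for block, size in shard_optimizer_checkpoint(osd, max_shard_size=1024):
    print(sorted(block.keys()), size)  # param ids per block; size assumed in MB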
- # deal with missing key
- if len(missing_keys) > 0:
- deleted_keys = []
- for key in missing_keys:
- if key not in sub_missing_keys:
- deleted_keys.append(key)
- for key in deleted_keys:
- missing_keys.remove(key)
- if strict:
- if len(unexpected_keys) > 0:
- error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(
- ', '.join('"{}"'.format(k) for k in unexpected_keys))
- raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
- model.__class__.__name__, "\n\t".join(error_msgs)))
-
# ======================================
# Helper functions for saving state dict
# ======================================
@@ -214,14 +305,99 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
"""
if use_safetensors:
assert is_safetensors_available(), "safetensors is not available."
- assert checkpoint_file_path.endswith('.safetensors'), \
- "safetensors only supports .safetensors suffix for checkpoint file."
+ assert checkpoint_file_path.endswith(
+ ".safetensors"
+ ), "safetensors only supports .safetensors suffix for checkpoint file."
from safetensors.torch import save_file as safe_save_file
+
safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
else:
torch.save(state_dict, checkpoint_file_path)
+def save_param_groups(state_dict: dict, group_file_path: str) -> None:
+ """
+ Save information of param_groups to given file path.
+
+ Args:
+ state_dict (dict): state dict.
+ group_file_path (str): path to the group file.
+ """
+ param_groups = state_dict["param_groups"]
+ torch.save(param_groups, group_file_path)
+
+
+def clean_folder(
+ checkpoint_path: str,
+ weights_name: str,
+ shard_filenames: List[str],
+ is_master: bool = True,
+ use_pp_format: bool = False,
+):
+ """
+    Clean up unneeded files in the checkpoint directory after shards of the state_dict have been saved.
+
+ Args:
+ checkpoint_path (str): Path to the checkpoint directory.
+ weights_name (str): Decides the prefix of filenames of weight shards.
+ shard_filenames (List[str]): The list of saved shard filenames which should not be removed.
+ is_master (bool, optional): Whether current rank is main process. Defaults to True.
+        use_pp_format (bool, optional): Whether to save the files in pipeline format, including stage information. Defaults to False.
+
+ """
+ if is_master:
+ for filename in os.listdir(checkpoint_path):
+ full_filename = os.path.join(checkpoint_path, filename)
+ weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
+ filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
+ if not use_pp_format:
+ reg = re.compile(r"(.*?)-\d{5}")
+ else:
+ # When this checkpoint is created by pipeline parallel process, the pattern is a little different.
+ reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
+ if (
+ filename.startswith(weights_no_suffix)
+ and os.path.isfile(full_filename)
+ and filename not in shard_filenames
+ and reg.fullmatch(filename_no_suffix) is not None
+ ):
+ os.remove(full_filename)
+
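
# A pure-regex sketch of the pattern clean_folder uses to spot stale shards
# (no filesystem access; the filenames are illustrative):
import re

reg = re.compile(r"(.*?)-\d{5}")
current_shards = ["pytorch_model-00001.bin", "pytorch_model-00002.bin"]
stale = "pytorch_model-00003.bin"  # leftover from a previous, larger save
name_no_suffix = stale.replace(".bin", "")
would_remove = reg.fullmatch(name_no_suffix) is not None and stale not in current_shards
print(would_remove)  # True: matches the shard pattern but is not a fresh shard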
+
+def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = True):
+ """
+    Save config.json/generation_config.json if the model is a Hugging Face pretrained model.
+ This method can only be called when a model is saved in a sharded way.
+
+ Args:
+ model (nn.Module): The model whose config should be saved if it's a huggingface model.
+ checkpoint_path (str): Path to the checkpoint directory.
+ is_master (bool): Whether current rank is main process.
+ """
+ try:
+ from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype
+ from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model
+ except ImportError:
+ return
+ if not isinstance(model, PreTrainedModel):
+ return
+
+ model = unwrap_huggingface_model(model)
+
+ # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
+ dtype = get_parameter_dtype(model)
+ model.config.torch_dtype = str(dtype).split(".")[1]
+
+ # Attach architecture to the config
+ model.config.architectures = [model.__class__.__name__]
+
+ # Save the config
+ if is_master:
+ model.config.save_pretrained(checkpoint_path)
+ if model.can_generate():
+ model.generation_config.save_pretrained(checkpoint_path)
+
+
def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
"""
Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
@@ -233,7 +409,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi
size_per_shard (int): size per shard in MB.
"""
root_path = index_file.root_path
- output_root_path = root_path.joinpath('dtensor')
+ output_root_path = root_path.joinpath("dtensor")
# create directory
output_root_path.mkdir(exist_ok=True)
@@ -253,7 +429,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi
# update the weight map
# * means all shards
- ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors)
+ ckpt_file_name_in_weight_map = "dtensor/" + generate_dtensor_file_name(name, "*", use_safetensors)
index_file.append_weight_map(name, ckpt_file_name_in_weight_map)
@@ -268,15 +444,14 @@ def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
str: checkpoint file suffix.
"""
if use_safetensors:
- return '.safetensors'
+ return ".safetensors"
else:
- return '.bin'
+ return ".bin"
-def generate_checkpoint_shard_file_name(index: int,
- total_number: int,
- use_safetensors: bool,
- prefix: str = None) -> str:
+def generate_checkpoint_shard_file_name(
+ index: int, total_number: int, use_safetensors: bool, prefix: str = None
+) -> str:
"""
Generate checkpoint shard file name.
@@ -310,39 +485,190 @@ def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: boo
str: dtensor file name.
"""
suffix = get_checkpoint_file_suffix(use_safetensors)
- return f'{param_name}.{index}.{suffix}'
+ return f"{param_name}.{index}.{suffix}"
-def save_state_dict_as_shard(
- state_dict: dict,
- checkpoint_path: str,
- index: int,
- total_number: int,
- use_safetensors: bool,
- prefix: str = None,
-) -> None:
+# ========================================
+# Helper functions for loading state dict
+# ========================================
+
+
+def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
+ """
+    Load a shard of the sharded state dict from the given checkpoint file.
"""
- Save state dict as shard.
+ if use_safetensors and not checkpoint_file.suffix == ".safetensors":
+        raise Exception("loading the model with `safetensors`, but the checkpoint file does not end with .safetensors")
+ if use_safetensors:
+ from safetensors.torch import load_file as safe_load_file
+ from safetensors.torch import safe_open
+
+ with safe_open(checkpoint_file, framework="pt") as f:
+ metadata = f.metadata()
+ if metadata["format"] != "pt":
+ raise NotImplementedError(
+ f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
+ )
+ return safe_load_file(checkpoint_file)
+ else:
+ return torch.load(checkpoint_file, map_location=torch.device("cpu"))
+
+
+def load_state_dict_into_model(
+    model: nn.Module, state_dict: Mapping, missing_keys: List, strict: bool = False, load_sub_module: bool = True
+):
+ r"""Copies parameters and buffers from :attr:`state_dict` into
+ this module and its descendants.
Args:
- state_dict (dict): state dict.
- checkpoint_path (str): path to the checkpoint file.
- index (int): index of the shard.
- total_number (int): total number of shards.
- prefix (str): prefix of the shard file name.
- use_safetensors (bool): whether to use safetensors to save the checkpoint.
+ state_dict (dict): a dict containing parameters and
+ persistent buffers.
"""
- # generate the shard name
- shard_file_name = generate_checkpoint_shard_file_name(index, total_number, use_safetensors, prefix)
- shard_file_path = Path(checkpoint_path).joinpath(shard_file_name).absolute()
+ if not isinstance(state_dict, Mapping):
+ raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
+
+ unexpected_keys: List[str] = []
+ sub_missing_keys: List[str] = []
+ error_msgs: List[str] = []
- # save the shard
- save_state_dict(state_dict, str(shard_file_path), use_safetensors)
+ # copy state_dict so _load_from_state_dict can modify it
+ metadata = getattr(state_dict, "_metadata", None)
+ state_dict = OrderedDict(state_dict)
+ if metadata is not None:
+ state_dict._metadata = metadata
+ def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+ args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs)
+ # Parameters of module and children will start with prefix. We can exit early if there are none in this
+ # state_dict
+ if len([key for key in state_dict if key.startswith(prefix)]) > 0:
+ module._load_from_state_dict(*args)
+ if load_sub_module:
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, state_dict, prefix + name + ".")
-# ========================================
-# Helper functions for loading state dict
-# ========================================
+ load(model, state_dict, "", load_sub_module)
+ del load
+
+    missing_keys.extend(sub_missing_keys)
+
+ if strict:
+ if len(unexpected_keys) > 0:
+            error_msgs.insert(
+                0, "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys))
+            )
+ raise RuntimeError(
+ "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
+ )
+
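
# A minimal sketch of load_state_dict_into_model with a deliberately partial
# state dict; the module path is assumed.
import torch
from colossalai.checkpoint_io.utils import load_state_dict_into_model  # assumed location

model = torch.nn.Linear(4, 4)
missing = []
load_state_dict_into_model(model, {"weight": torch.zeros(4, 4)}, missing, strict=False)
print(missing)  # expected to contain 'bias', which the partial dict omits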
+
+def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict:
+ """
+ Load information of param_groups into an initialized optimizer.
+ """
+
+ # Load list of param_groups from given file path.
+ # The params in saved_groups are in the form of integer indices.
+ saved_groups = torch.load(param_group_path, map_location=torch.device("cpu"))
+ if not isinstance(saved_groups, List):
+ raise ValueError(f"The param_groups saved at {param_group_path} is not of List type")
+
+ # The params in param_groups are in the form of pytorch tensors.
+ # For more details, please view source code of Optimizer class in pytorch.
+ param_groups = optimizer.param_groups
+
+ # Check the compatibility of saved_groups and param_groups.
+ if len(param_groups) != len(saved_groups):
+ raise ValueError("loaded state dict has a different number of original parameter groups")
+ param_lens = (len(g["params"]) for g in param_groups)
+ saved_lens = (len(g["params"]) for g in saved_groups)
+ if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
+ raise ValueError(
+            "loaded state dict contains a parameter group that doesn't match the size of the optimizer's group"
+ )
+
+ # Creating mapping from id to parameters.
+ id_map = {
+ old_id: p
+ for old_id, p in zip(
+ chain.from_iterable((g["params"] for g in saved_groups)),
+ chain.from_iterable((g["params"] for g in param_groups)),
+ )
+ }
+
+ # Update parameter groups, setting their 'params' value.
+ def update_group(group, new_group):
+ new_group["params"] = group["params"]
+ return new_group
+
+ updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)]
+
+ optimizer.__dict__.update({"param_groups": updated_groups})
+ return id_map
+
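
# A toy illustration of the id_map built above: saved integer ids are zipped
# with the optimizer's live parameters, group by group (strings stand in for tensors).
from itertools import chain

saved_groups = [{"params": [0, 1]}]             # ids as serialized on disk
live_groups = [{"params": ["weight", "bias"]}]  # stand-ins for real tensors
id_map = dict(
    zip(
        chain.from_iterable(g["params"] for g in saved_groups),
        chain.from_iterable(g["params"] for g in live_groups),
    )
)
print(id_map)  # {0: 'weight', 1: 'bias'}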
+
+def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict, strict: bool = False):
+ r"""Copies states from `state_dict` into an Optimizer object.
+
+ Args:
+ optimizer(Optimizer): An initialized Optimizer object to be loaded
+ state_dict(dict): A mapping from tensor index (an integer)
+ to its states to be loaded (a mapping from state name to a tensor).
+ id_map(dict): A mapping from tensor index (an integer)
+ to its corresponding parameter (a tensor) whose states will be updated.
+        strict(bool, optional): If set to True, only load parameters whose ids are in id_map. Defaults to False.
+ """
+
+ # Ensure that the keys of state_dict are integers.
+ state_dict = {int(k): v for k, v in state_dict.items()}
+
+ def cast(param, value, key=None):
+        r"""Make a deep copy of value, casting all tensors to the device of param."""
+ if isinstance(value, torch.Tensor):
+ # Floating-point types are a bit special here. They are the only ones
+ # that are assumed to always match the type of params.
+            # Make sure state['step'] is not cast (see https://github.com/pytorch/pytorch/issues/74424)
+ if key != "step":
+ if param.is_floating_point():
+ value = value.to(param.dtype)
+ value = value.to(param.device)
+ return value
+ elif isinstance(value, dict):
+ return {k: cast(param, v, key=k) for k, v in value.items()}
+ elif isinstance(value, container_abcs.Iterable):
+ return type(value)(cast(param, v) for v in value)
+ else:
+ return value
+
+ # Copy state assigned to params (and cast tensors to appropriate types).
+ # State that is not assigned to params is copied as is (needed for
+ # backward compatibility).
+ new_states = defaultdict(dict)
+ for k, v in state_dict.items():
+ if k in id_map:
+ param = id_map[k]
+ new_states[param] = cast(param, v)
+ elif not strict:
+ new_states[k] = v
+
+ optimizer.state.update(new_states)
+
+
+def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
+    r"""Do the cleanup work after the state_dict has been loaded into the optimizer.
+
+ Args:
+ optimizer(Optimizer): An optimizer object whose state has just been loaded.
+ """
+
+    # Do the cleanup as in the source code of PyTorch.
+ if Version(torch.__version__) >= Version("2.0.0"):
+ optimizer._patch_step_function() # To support multiprocessing pickle/unpickle
+ else:
+ optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle.
+ optimizer.defaults.setdefault("differentiable", False)
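
# A sketch of the full restore flow formed by the three helpers above; the
# group and shard filenames are illustrative and would normally be read from
# the saved index file.
import torch
from colossalai.checkpoint_io.utils import (  # assumed location
    load_param_groups_into_optimizer,
    load_states_into_optimizer,
    sharded_optimizer_loading_epilogue,
)

model = torch.nn.Linear(4, 4)
optim = torch.optim.Adam(model.parameters())

id_map = load_param_groups_into_optimizer(optim, "./ckpt/pytorch_optim_group.bin")  # path illustrative
for shard_file in ["./ckpt/pytorch_optim-00001.bin"]:  # normally taken from the index file
    state_shard = torch.load(shard_file, map_location="cpu")
    load_states_into_optimizer(optim, state_shard, id_map)
sharded_optimizer_loading_epilogue(optim)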
def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
@@ -365,18 +691,20 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
return False, None
elif checkpoint_path.is_dir():
        # check if there is only one file ending with .index.json in this directory
- index_files = list(checkpoint_path.glob('*.index.*json'))
+ index_files = list(checkpoint_path.glob("*.index.*json"))
# if we found a .index.json file, make sure there is only one
if len(index_files) > 0:
- assert len(
- index_files
- ) == 1, f'Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}'
+ assert (
+ len(index_files) == 1
+ ), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}"
if len(index_files) == 1:
return True, index_files[0]
else:
return False, None
+ else:
+ raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.")
def load_state_dict(checkpoint_file_path: Path):
@@ -390,14 +718,17 @@ def load_state_dict(checkpoint_file_path: Path):
dict: state dict.
"""
- assert not is_dtensor_checkpoint(checkpoint_file_path), \
- f'Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline.'
+ assert not is_dtensor_checkpoint(
+ checkpoint_file_path
+ ), f"Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline."
if is_safetensor_checkpoint(checkpoint_file_path):
- assert is_safetensors_available(), \
- f'Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors.'
+ assert (
+ is_safetensors_available()
+ ), f"Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors."
# load with safetensors
from safetensors import safe_open
+
state_dict = {}
with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f:
for k in f.keys():
@@ -406,14 +737,51 @@ def load_state_dict(checkpoint_file_path: Path):
else:
# load with torch
- return torch.load(checkpoint_file_path)
-
+ return torch.load(checkpoint_file_path, map_location=torch.device("cpu"))
-def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
- if variant is not None and len(variant) > 0:
+def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
+ if prefix is not None and len(prefix) > 0:
splits = weights_name.split(".")
- splits = splits[:-1] + [variant] + splits[-1:]
+ splits = splits[:-1] + [prefix] + splits[-1:]
weights_name = ".".join(splits)
return weights_name
+
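
# Quick examples of add_prefix's filename rewriting (using the helper defined
# just above; the prefix lands before the final suffix):
print(add_prefix("pytorch_model.bin", "chunk"))  # -> pytorch_model.chunk.bin
print(add_prefix("model.safetensors", None))     # -> model.safetensors (unchanged)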
+
+def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
+ """
+    Generate base model weight filenames.
+ """
+ weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
+ weights_name = add_prefix(weights_name, prefix)
+
+ save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
+ save_index_file = add_prefix(save_index_file, prefix)
+
+ return weights_name, save_index_file
+
+
+def get_optimizer_base_filenames(prefix: str = None):
+ """
+    Generate base optimizer state filenames.
+ """
+ states_name = STATES_NAME
+ states_name = add_prefix(states_name, prefix)
+
+ save_index_file = STATES_INDEX_NAME
+ save_index_file = add_prefix(save_index_file, prefix)
+
+ param_group_file = GROUP_FILE_NAME
+ param_group_file = add_prefix(param_group_file, prefix)
+
+ return states_name, save_index_file, param_group_file
+
+
+def get_shard_filename(weights_name: str, idx: int):
+ """
+    Get the shard file name.
+ """
+ shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
+ shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors")
+ return shard_file
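
# Shard filenames produced by get_shard_filename (defined just above) are
# 1-based and zero-padded to five digits:
print(get_shard_filename("pytorch_model.bin", 0))   # -> pytorch_model-00001.bin
print(get_shard_filename("model.safetensors", 11))  # -> model-00012.safetensors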
diff --git a/colossalai/cli/__init__.py b/colossalai/cli/__init__.py
index 658e35e4c72e77f7b3161bf86b0c9600a80562e5..c7cb19c193082d54b57c6fdd40815f29636f6b25 100644
--- a/colossalai/cli/__init__.py
+++ b/colossalai/cli/__init__.py
@@ -1,3 +1,3 @@
from .cli import cli
-__all__ = ['cli']
+__all__ = ["cli"]
diff --git a/colossalai/cli/benchmark/__init__.py b/colossalai/cli/benchmark/__init__.py
deleted file mode 100644
index 618ff8c61dd41ec83d567e7cd9103f2aa9921846..0000000000000000000000000000000000000000
--- a/colossalai/cli/benchmark/__init__.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import click
-
-from colossalai.context import Config
-
-from .benchmark import run_benchmark
-from .utils import *
-
-__all__ = ['benchmark']
-
-
-@click.command()
-@click.option("-g", "--gpus", type=int, default=None, help="Total number of devices to use.")
-@click.option("-b", "--batch_size", type=int, default=8, help="Batch size of the input tensor.")
-@click.option("-s", "--seq_len", type=int, default=512, help="Sequence length of the input tensor.")
-@click.option("-d", "--dimension", type=int, default=1024, help="Hidden dimension of the input tensor.")
-@click.option("-w", "--warmup_steps", type=int, default=10, help="The number of warmup steps.")
-@click.option("-p", "--profile_steps", type=int, default=50, help="The number of profiling steps.")
-@click.option("-l", "--layers", type=int, default=2)
-@click.option("-m",
- "--model",
- type=click.Choice(['mlp'], case_sensitive=False),
- default='mlp',
- help="Select the model to benchmark, currently only supports MLP")
-def benchmark(gpus: int, batch_size: int, seq_len: int, dimension: int, warmup_steps: int, profile_steps: int,
- layers: int, model: str):
- args_dict = locals()
- args = Config(args_dict)
- run_benchmark(args)
diff --git a/colossalai/cli/benchmark/benchmark.py b/colossalai/cli/benchmark/benchmark.py
deleted file mode 100644
index 97a9f45722dd6a4c1e316d1e91f27439797ae17a..0000000000000000000000000000000000000000
--- a/colossalai/cli/benchmark/benchmark.py
+++ /dev/null
@@ -1,105 +0,0 @@
-from functools import partial
-from typing import Dict, List
-
-import click
-import torch.multiprocessing as mp
-
-import colossalai
-from colossalai.cli.benchmark.utils import find_all_configs, get_batch_data, profile_model
-from colossalai.context import Config
-from colossalai.context.random import reset_seeds
-from colossalai.core import global_context as gpc
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.testing import free_port
-from colossalai.utils import MultiTimer
-
-from .models import MLP
-
-
-def run_benchmark(args: Config) -> None:
- """
- Run benchmarking with torch.multiprocessing.
- """
-
- # sanity checks
- if args.gpus is None:
- click.echo("Error: --num_gpus is not given")
- exit()
- if args.gpus <= 1:
- click.echo("Warning: tensor parallel will be activated with at least 2 devices.")
-
- click.echo("=== Benchmarking Parameters ===")
- for k, v in args.items():
- click.echo(f'{k}: {v}')
- click.echo('')
-
- config_list = find_all_configs(args.gpus)
-
- avail_ports = [free_port() for _ in range(len(config_list))]
- run_func = partial(run_dist_profiling,
- world_size=args.gpus,
- port_list=avail_ports,
- config_list=config_list,
- hyperparams=args)
- mp.spawn(run_func, nprocs=args.gpus)
-
-
-def run_dist_profiling(rank: int, world_size: int, port_list: List[int], config_list: List[Dict],
- hyperparams: Config) -> None:
- """
- A function executed for profiling, this function should be spawn by torch.multiprocessing.
-
- Args:
- rank (int): rank of the process
- world_size (int): the number of processes
- port_list (List[int]): a list of free ports for initializing distributed networks
- config_list (List[Dict]): a list of configuration
- hyperparams (Config): the hyperparameters given by the user
-
- """
-
- # disable logging for clean output
- disable_existing_loggers()
- logger = get_dist_logger()
- logger.set_level('WARNING')
-
- for config, port in zip(config_list, port_list):
- colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- timer = MultiTimer()
-
- # 1D parallel should be skipped if in_features or out_features is not able to be divided exactly by 1D parallel size.
- if config.parallel.tensor.mode == '1d' and hyperparams.dimension % config.parallel.tensor.size != 0:
- click.echo(
- "1D parallel will be skipped because in_features or out_features is not able to be divided exactly by 1D parallel size."
- )
- continue
-
- if hyperparams.model == 'mlp':
- model = MLP(dim=hyperparams.dimension, layers=hyperparams.layers)
- else:
- if gpc.get_global_rank() == 0:
- click.echo("Error: Invalid argument for --model")
- exit()
-
- data_func = partial(get_batch_data,
- dim=hyperparams.dimension,
- batch_size=hyperparams.batch_size,
- seq_length=hyperparams.seq_len,
- mode=config.parallel.tensor.mode)
-
- fwd_time, bwd_time, max_allocated, max_cached = profile_model(model=model,
- warmup_steps=hyperparams.warmup_steps,
- profile_steps=hyperparams.profile_steps,
- data_func=data_func,
- timer=timer)
-
- gpc.destroy()
- reset_seeds()
-
- if gpc.get_global_rank() == 0:
- config_str = ', '.join([f'{k}: {v}' for k, v in config.parallel.tensor.items()])
- click.echo(f"=== {config_str} ===")
- click.echo(f"Average forward time: {fwd_time}")
- click.echo(f"Average backward time: {bwd_time}")
- click.echo(f"Max allocated GPU memory: {max_allocated}")
- click.echo(f"Max cached GPU memory: {max_cached}\n")
diff --git a/colossalai/cli/benchmark/models.py b/colossalai/cli/benchmark/models.py
deleted file mode 100644
index f8fd1c41a059806891713340f2ea4931ec9726f2..0000000000000000000000000000000000000000
--- a/colossalai/cli/benchmark/models.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import torch
-
-import colossalai.nn as col_nn
-
-
-class MLP(torch.nn.Module):
-
- def __init__(self, dim: int, layers: int):
- super().__init__()
- self.layers = torch.nn.ModuleList()
-
- for _ in range(layers):
- self.layers.append(col_nn.Linear(dim, dim))
-
- def forward(self, x):
- for layer in self.layers:
- x = layer(x)
- return x
diff --git a/colossalai/cli/benchmark/utils.py b/colossalai/cli/benchmark/utils.py
deleted file mode 100644
index 825b795f21f680bcb2e2ea5eee4b328c2e1777db..0000000000000000000000000000000000000000
--- a/colossalai/cli/benchmark/utils.py
+++ /dev/null
@@ -1,158 +0,0 @@
-import math
-import time
-import torch
-
-from colossalai.utils import MultiTimer
-from colossalai.context import ParallelMode, Config
-from typing import List, Dict, Tuple, Callable
-
-
-def get_time_stamp() -> int:
- """
- Return the time stamp for profiling.
-
- Returns:
- time_stamp (int): the time given by time.time()
- """
-
- torch.cuda.synchronize()
- time_stamp = time.time()
- return time_stamp
-
-
-def get_memory_states() -> Tuple[float]:
- """
- Return the memory statistics.
-
- Returns:
- max_allocated (float): the allocated CUDA memory
- max_cached (float): the cached CUDA memory
- """
-
- max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
- max_cached = torch.cuda.max_memory_reserved() / (1024**3)
- torch.cuda.reset_peak_memory_stats()
- torch.cuda.empty_cache()
- return max_allocated, max_cached
-
-
-def find_all_configs(device_cnt: int) -> List[Dict]:
- """
- Find all possible configurations for tensor parallelism
-
- Args:
- device_cnt (int): the number of devices
-
- Returns:
- config_list (List[Dict]): a list of configurations
- """
-
- def _is_square(num):
- # 2D parallel should be implemented with at least 2 devices.
- if num <= 1:
- return False
- return math.floor(math.sqrt(num))**2 == num
-
- def _is_cube(num):
- # 3D parallel should be implemented with at least 2 devices.
- if num <= 1:
- return False
- return math.floor(num**(1. / 3.))**3 == num
-
- config_list = []
-
- # add non-parallel config
- config = dict(parallel=dict(tensor=dict(size=device_cnt, mode=None)))
- config_list.append(config)
-
- # add 1D config
- config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='1d')))
- config_list.append(config)
-
- # add 2D config only if device_cnt is a square
- if _is_square(device_cnt):
- config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2d')))
- config_list.append(config)
-
- # check for 2.5D
- # iterate over depth
- for depth in range(1, device_cnt):
- if device_cnt % depth == 0 and _is_square(device_cnt // depth):
- config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2.5d', depth=depth)))
- config_list.append(config)
-
- # check for 3D if device_cnt is a cube
- if _is_cube(device_cnt):
- config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='3d')))
- config_list.append(config)
-
- config_list = [Config(cfg) for cfg in config_list]
- return config_list
-
-
-def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int, data_func: Callable,
- timer: MultiTimer) -> Tuple[float]:
- """
- Profile the forward and backward of a model
-
- Args:
- model (torch.nn.Module): a PyTorch model
- warmup_steps (int): the number of steps for warmup
- profile_steps (int): the number of steps for profiling
- data_func (Callable): a function to generate random data
- timer (colossalai.utils.Multitimer): a timer instance for time recording
-
- Returns:
- fwd_time (float): the average forward time taken by forward pass in second
- bwd_time (float): the average backward time taken by forward pass in second
- max_allocated (float): the maximum GPU memory allocated in GB
- max_cached (float): the maximum GPU memory cached in GB
- """
-
- def _run_step(data):
- timer.start('forward')
- out = model(data)
- timer.stop('forward', keep_in_history=True)
- timer.start('backward')
- out.mean().backward()
- timer.stop('backward', keep_in_history=True)
-
- data_list = [data_func() for _ in range(warmup_steps)]
- for data in data_list:
- _run_step(data)
- timer.reset('forward')
- timer.reset('backward')
-
- for _ in range(profile_steps):
- data = data_func()
- _run_step(data)
-
- max_allocated, max_cached = get_memory_states()
- fwd_time = timer.get_timer('forward').get_history_mean()
- bwd_time = timer.get_timer('backward').get_history_mean()
- return fwd_time, bwd_time, max_allocated, max_cached
-
-
-def get_batch_data(dim: int, batch_size: int, seq_length: int, mode: ParallelMode) -> torch.Tensor:
- """
- Return a random data of shape (batch_size, seq_length, dim) for profiling.
-
- Args:
- dim (int): hidden size
- batch_size (int): the number of data samples
- seq_length (int): the number of tokens
- mode (ParallelMode): Colossal-AI ParallelMode enum
-
- Returns:
- data (torch.Tensor): random data
- """
-
- if mode in ['2d', '2.5d']:
- batch_size = batch_size // 2
- dim = dim // 2
- elif mode == '3d':
- batch_size = batch_size // 4
- dim = dim // 2
-
- data = torch.rand(batch_size, seq_length, dim).cuda()
- return data
diff --git a/colossalai/cli/check/__init__.py b/colossalai/cli/check/__init__.py
index a86b32bb6a181a7c75bfa6682c2f7514169d9a7f..7c26ab6ade6cdc7238608232ce07c07fd977888f 100644
--- a/colossalai/cli/check/__init__.py
+++ b/colossalai/cli/check/__init__.py
@@ -1,11 +1,12 @@
import click
+
from .check_installation import check_installation
-__all__ = ['check']
+__all__ = ["check"]
@click.command(help="Check if Colossal-AI is correct based on the given option")
-@click.option('-i', '--installation', is_flag=True, help="Check if Colossal-AI is built correctly")
+@click.option("-i", "--installation", is_flag=True, help="Check if Colossal-AI is built correctly")
def check(installation):
if installation:
check_installation()
diff --git a/colossalai/cli/check/check_installation.py b/colossalai/cli/check/check_installation.py
index cb3dbbc09301012aa662264f241e1fce89470d39..772c513ffa06b1ffa148f1e95f956964e14a8716 100644
--- a/colossalai/cli/check/check_installation.py
+++ b/colossalai/cli/check/check_installation.py
@@ -9,7 +9,7 @@ import colossalai
def to_click_output(val):
# installation check output to understandable symbols for readability
- VAL_TO_SYMBOL = {True: u'\u2713', False: 'x', None: 'N/A'}
+ VAL_TO_SYMBOL = {True: "\u2713", False: "x", None: "N/A"}
if val in VAL_TO_SYMBOL:
return VAL_TO_SYMBOL[val]
@@ -31,7 +31,7 @@ def check_installation():
found_aot_cuda_ext = _check_aot_built_cuda_extension_installed()
cuda_version = _check_cuda_version()
torch_version, torch_cuda_version = _check_torch_version()
- colossalai_verison, prebuilt_torch_version_required, prebuilt_cuda_version_required = _parse_colossalai_version()
+ colossalai_version, prebuilt_torch_version_required, prebuilt_cuda_version_required = _parse_colossalai_version()
# if cuda_version is None, that means either
# CUDA_HOME is not found, thus cannot compare the version compatibility
@@ -55,9 +55,9 @@ def check_installation():
else:
torch_compatibility = _is_compatible([torch_version, prebuilt_torch_version_required])
- click.echo(f'#### Installation Report ####')
- click.echo(f'\n------------ Environment ------------')
- click.echo(f"Colossal-AI version: {to_click_output(colossalai_verison)}")
+ click.echo(f"#### Installation Report ####")
+ click.echo(f"\n------------ Environment ------------")
+ click.echo(f"Colossal-AI version: {to_click_output(colossalai_version)}")
click.echo(f"PyTorch version: {to_click_output(torch_version)}")
click.echo(f"System CUDA version: {to_click_output(cuda_version)}")
click.echo(f"CUDA version required by PyTorch: {to_click_output(torch_cuda_version)}")
@@ -69,7 +69,7 @@ def check_installation():
        f"3. If the CUDA version required by PyTorch is N/A, you probably did not install a CUDA-compatible PyTorch. This value is given by torch.version.cuda and you can go to https://pytorch.org/get-started/locally/ to download the correct version."
)
- click.echo(f'\n------------ CUDA Extensions AOT Compilation ------------')
+ click.echo(f"\n------------ CUDA Extensions AOT Compilation ------------")
click.echo(f"Found AOT CUDA Extension: {to_click_output(found_aot_cuda_ext)}")
click.echo(f"PyTorch version used for AOT compilation: {to_click_output(prebuilt_torch_version_required)}")
click.echo(f"CUDA version used for AOT compilation: {to_click_output(prebuilt_cuda_version_required)}")
@@ -81,7 +81,7 @@ def check_installation():
click.echo(f"2. If AOT compilation is not enabled, stay calm as the CUDA kernels can still be built during runtime")
click.echo(f"\n------------ Compatibility ------------")
- click.echo(f'PyTorch version match: {to_click_output(torch_compatibility)}')
+ click.echo(f"PyTorch version match: {to_click_output(torch_compatibility)}")
click.echo(f"System and PyTorch CUDA version match: {to_click_output(sys_torch_cuda_compatibility)}")
click.echo(f"System and Colossal-AI CUDA version match: {to_click_output(sys_colossalai_cuda_compatibility)}")
click.echo(f"")
@@ -106,12 +106,12 @@ def _is_compatible(versions):
return False
# split version into [major, minor, patch]
- versions = [version.split('.') for version in versions]
+ versions = [version.split(".") for version in versions]
for version in versions:
if len(version) == 2:
# x means unknown
- version.append('x')
+ version.append("x")
for idx, version_values in enumerate(zip(*versions)):
equal = len(set(version_values)) == 1
@@ -137,15 +137,15 @@ def _parse_colossalai_version():
# 1. X.X.X+torchX.XXcuXX.X (when colossalai is installed with CUDA extensions)
# 2. X.X.X (when colossalai is not installed with CUDA extensions)
# where X represents an integer.
- colossalai_verison = colossalai.__version__.split('+')[0]
+ colossalai_version = colossalai.__version__.split("+")[0]
try:
- torch_version_for_aot_build = colossalai.__version__.split('torch')[1].split('cu')[0]
- cuda_version_for_aot_build = colossalai.__version__.split('cu')[1]
+ torch_version_for_aot_build = colossalai.__version__.split("torch")[1].split("cu")[0]
+ cuda_version_for_aot_build = colossalai.__version__.split("cu")[1]
except:
torch_version_for_aot_build = None
cuda_version_for_aot_build = None
- return colossalai_verison, torch_version_for_aot_build, cuda_version_for_aot_build
+ return colossalai_version, torch_version_for_aot_build, cuda_version_for_aot_build
def _check_aot_built_cuda_extension_installed():
@@ -156,7 +156,6 @@ def _check_aot_built_cuda_extension_installed():
JIT (just-in-time) compilation will build CUDA extensions to `~/.cache/colossalai/torch_extensions` during runtime.
"""
try:
- import colossalai._C.fused_optim
found_aot_cuda_ext = True
except ImportError:
found_aot_cuda_ext = False
@@ -175,14 +174,14 @@ def _check_torch_version():
# torch version can be of two formats
# - 1.13.1+cu113
# - 1.13.1.devxxx
- torch_version = torch.__version__.split('+')[0]
- torch_version = '.'.join(torch_version.split('.')[:3])
+ torch_version = torch.__version__.split("+")[0]
+ torch_version = ".".join(torch_version.split(".")[:3])
# get cuda version in pytorch build
try:
torch_cuda_major = torch.version.cuda.split(".")[0]
torch_cuda_minor = torch.version.cuda.split(".")[1]
- torch_cuda_version = f'{torch_cuda_major}.{torch_cuda_minor}'
+ torch_cuda_version = f"{torch_cuda_major}.{torch_cuda_minor}"
except:
torch_cuda_version = None
@@ -208,7 +207,7 @@ def _check_cuda_version():
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
- cuda_version = f'{bare_metal_major}.{bare_metal_minor}'
+ cuda_version = f"{bare_metal_major}.{bare_metal_minor}"
except:
cuda_version = None
return cuda_version
diff --git a/colossalai/cli/cli.py b/colossalai/cli/cli.py
index a94e1150e49fc00a210c20b32a9bfc85eda66aa6..0d94fe59f8aef275c4bae22a6ea29fa8ec0c1978 100644
--- a/colossalai/cli/cli.py
+++ b/colossalai/cli/cli.py
@@ -1,12 +1,10 @@
import click
-from .benchmark import benchmark
from .check import check
from .launcher import run
-class Arguments():
-
+class Arguments:
def __init__(self, arg_dict):
for k, v in arg_dict.items():
self.__dict__[k] = v
@@ -19,7 +17,6 @@ def cli():
cli.add_command(run)
cli.add_command(check)
-cli.add_command(benchmark)
-if __name__ == '__main__':
+if __name__ == "__main__":
cli()
diff --git a/colossalai/cli/launcher/__init__.py b/colossalai/cli/launcher/__init__.py
index 8d9ec147d401a2e5d055852e661e189985b6db6e..0f9ead6495dbe242873d0bfded0a020518bf2d8c 100644
--- a/colossalai/cli/launcher/__init__.py
+++ b/colossalai/cli/launcher/__init__.py
@@ -5,56 +5,81 @@ from colossalai.context import Config
from .run import launch_multi_processes
-@click.command(help="Launch distributed training on a single node or multiple nodes",
- context_settings=dict(ignore_unknown_options=True))
-@click.option("-H",
- "-host",
- "--host",
- type=str,
- default=None,
-              help="the list of hostnames to launch in the format <host1>,<host2>")
+@click.command(
+ help="Launch distributed training on a single node or multiple nodes",
+ context_settings=dict(ignore_unknown_options=True),
+)
+@click.option(
+ "-H",
+ "-host",
+ "--host",
+ type=str,
+ default=None,
+    help="the list of hostnames to launch in the format <host1>,<host2>",
+)
@click.option(
"--hostfile",
type=str,
default=None,
- help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname")
-@click.option("--include",
- type=str,
- default=None,
-              help="Specify computing devices to use during execution. String format is <host1>,<host2>,"
- " only effective when used with --hostfile.")
+ help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname",
+)
+@click.option(
+ "--include",
+ type=str,
+ default=None,
+    help="Specify computing devices to use during execution. String format is <host1>,<host2>,"
+ " only effective when used with --hostfile.",
+)
@click.option(
"--exclude",
type=str,
default=None,
- help=
- "Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --includ,"
- " only effective when used with --hostfile.")
-@click.option("--num_nodes",
- type=int,
- default=-1,
- help="Total number of worker nodes to use, only effective when used with --hostfile.")
+ help="Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include,"
+ " only effective when used with --hostfile.",
+)
+@click.option(
+ "--num_nodes",
+ type=int,
+ default=-1,
+ help="Total number of worker nodes to use, only effective when used with --hostfile.",
+)
@click.option("--nproc_per_node", type=int, default=None, help="Number of GPUs to use on each node.")
-@click.option("--master_port",
- type=int,
- default=29500,
- help="(optional) Port used by PyTorch distributed for communication during distributed training.")
-@click.option("--master_addr",
- type=str,
- default="127.0.0.1",
- help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.")
+@click.option(
+ "--master_port",
+ type=int,
+ default=29500,
+ help="(optional) Port used by PyTorch distributed for communication during distributed training.",
+)
+@click.option(
+ "--master_addr",
+ type=str,
+ default="127.0.0.1",
+ help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.",
+)
@click.option(
"--extra_launch_args",
type=str,
default=None,
- help=
- "Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. "
- "This will be converted to --arg1=1 --arg2=2 during execution")
+ help="Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. "
+ "This will be converted to --arg1=1 --arg2=2 during execution",
+)
@click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection")
@click.argument("user_script", type=str)
-@click.argument('user_args', nargs=-1)
-def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str,
- master_port: int, extra_launch_args: str, ssh_port: int, user_script: str, user_args: str) -> None:
+@click.argument("user_args", nargs=-1)
+def run(
+ host: str,
+ hostfile: str,
+ num_nodes: int,
+ nproc_per_node: int,
+ include: str,
+ exclude: str,
+ master_addr: str,
+ master_port: int,
+ extra_launch_args: str,
+ ssh_port: int,
+ user_script: str,
+ user_args: str,
+) -> None:
"""
To launch multiple processes on a single node or multiple nodes via command line.
@@ -77,8 +102,8 @@ def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include:
# run with hostfile excluding the hosts selected
    colossalai run --hostfile <hostfile_path> --master_addr host1 --exclude host2 --nprocs_per_node 4 train.py
"""
- if not user_script.endswith('.py'):
- click.echo(f'Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help')
+ if not user_script.endswith(".py"):
+ click.echo(f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help")
exit()
args_dict = locals()
diff --git a/colossalai/cli/launcher/hostinfo.py b/colossalai/cli/launcher/hostinfo.py
index 065cbc37101f9705319d96120779edcbbbf6dde9..684f64f59d28b945d930602e988a2847670eda6b 100644
--- a/colossalai/cli/launcher/hostinfo.py
+++ b/colossalai/cli/launcher/hostinfo.py
@@ -1,5 +1,4 @@
import socket
-from typing import List
class HostInfo:
@@ -34,11 +33,11 @@ class HostInfo:
"""
if port is None:
- port = 22 # no port specified, lets just use the ssh port
+            port = 22  # no port specified, let's just use the ssh port
# socket.getfqdn("127.0.0.1") does not return localhost
# on some users' machines
- # thus, we directly return True if hostname is locahost, 127.0.0.1 or 0.0.0.0
+ # thus, we directly return True if hostname is localhost, 127.0.0.1 or 0.0.0.0
if hostname in ("localhost", "127.0.0.1", "0.0.0.0"):
return True
@@ -46,14 +45,11 @@ class HostInfo:
localhost = socket.gethostname()
localaddrs = socket.getaddrinfo(localhost, port)
targetaddrs = socket.getaddrinfo(hostname, port)
- for (family, socktype, proto, canonname, sockaddr) in localaddrs:
- for (rfamily, rsocktype, rproto, rcanonname, rsockaddr) in targetaddrs:
- if rsockaddr[0] == sockaddr[0]:
- return True
- return False
+
+ return localaddrs == targetaddrs
def __str__(self):
- return f'hostname: {self.hostname}, port: {self.port}'
+ return f"hostname: {self.hostname}, port: {self.port}"
def __repr__(self):
return self.__str__()
diff --git a/colossalai/cli/launcher/multinode_runner.py b/colossalai/cli/launcher/multinode_runner.py
index a51e1e371f13df492c16ac5b554d02ba6491d65b..99c4db40684480e19933b6a4da6fca90cee70aff 100644
--- a/colossalai/cli/launcher/multinode_runner.py
+++ b/colossalai/cli/launcher/multinode_runner.py
@@ -7,8 +7,13 @@ import fabric
from .hostinfo import HostInfo, HostInfoList
-def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection,
- send_conn: mp_connection.Connection, env: dict) -> None:
+def run_on_host(
+ hostinfo: HostInfo,
+ workdir: str,
+ recv_conn: mp_connection.Connection,
+ send_conn: mp_connection.Connection,
+ env: dict,
+) -> None:
"""
Use fabric connection to execute command on local or remote hosts.
@@ -22,14 +27,14 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne
fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port)
finish = False
- env_msg = ' '.join([f'{k}=\"{v}\"' for k, v in env.items()])
+ env_msg = " ".join([f'{k}="{v}"' for k, v in env.items()])
# keep listening until exit
while not finish:
# receive cmd
cmds = recv_conn.recv()
- if cmds == 'exit':
+ if cmds == "exit":
# exit from the loop
finish = True
break
@@ -46,12 +51,12 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne
else:
# execute on the remote machine
fab_conn.run(cmds, hide=False)
- send_conn.send('success')
+ send_conn.send("success")
except Exception as e:
click.echo(
f"Error: failed to run {cmds} on {hostinfo.hostname}, is localhost: {hostinfo.is_local_host}, exception: {e}"
)
- send_conn.send('failure')
+ send_conn.send("failure")
# shutdown
send_conn.send("finish")
@@ -96,8 +101,7 @@ class MultiNodeRunner:
cmd (str): the command to execute
"""
- assert hostinfo.hostname in self.master_send_conns, \
- f'{hostinfo} is not found in the current connections'
+ assert hostinfo.hostname in self.master_send_conns, f"{hostinfo} is not found in the current connections"
conn = self.master_send_conns[hostinfo.hostname]
conn.send(cmd)
@@ -107,14 +111,14 @@ class MultiNodeRunner:
"""
for hostname, conn in self.master_send_conns.items():
- conn.send('exit')
+ conn.send("exit")
def recv_from_all(self) -> dict:
"""
Receive messages from all hosts
Returns:
- msg_from_node (dict): a dictionry which contains messages from each node
+ msg_from_node (dict): a dictionary which contains messages from each node
"""
msg_from_node = dict()
diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py
index 6411b4302e95d25efc5a55da1b523b76ee6ee1e3..88f70f02ec27fb81a8fe5ab21973810638500ee4 100644
--- a/colossalai/cli/launcher/run.py
+++ b/colossalai/cli/launcher/run.py
@@ -12,7 +12,7 @@ from .hostinfo import HostInfo, HostInfoList
from .multinode_runner import MultiNodeRunner
# Constants that define our syntax
-NODE_SEP = ','
+NODE_SEP = ","
def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
@@ -34,12 +34,12 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}")
exit()
- with open(hostfile_path, 'r') as fd:
+ with open(hostfile_path, "r") as fd:
device_pool = HostInfoList()
for line in fd.readlines():
line = line.strip()
- if line == '':
+ if line == "":
# skip empty lines
continue
@@ -56,7 +56,7 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList:
- '''Parse an inclusion or exclusion string and filter a hostfile dictionary.
+ """Parse an inclusion or exclusion string and filter a hostfile dictionary.
Examples:
include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1.
@@ -69,7 +69,7 @@ def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str
Returns:
filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion
- '''
+ """
# Ensure include/exclude are mutually exclusive
if include_str and exclude_str:
@@ -136,16 +136,16 @@ def get_launch_command(
for k, v in arg_dict.items():
if v:
- ret.append(f'--{k}={v}')
+ ret.append(f"--{k}={v}")
else:
- ret.append(f'--{k}')
+ ret.append(f"--{k}")
return ret
if extra_launch_args:
extra_launch_args_dict = dict()
- for arg in extra_launch_args.split(','):
- if '=' in arg:
- k, v = arg.split('=')
+ for arg in extra_launch_args.split(","):
+ if "=" in arg:
+ k, v = arg.split("=")
extra_launch_args_dict[k] = v
else:
extra_launch_args_dict[arg] = None
@@ -154,19 +154,23 @@ def get_launch_command(
extra_launch_args = dict()
torch_version = version.parse(torch.__version__)
- assert torch_version.major == 1
+ assert torch_version.major >= 1
- if torch_version.minor < 9:
+ if torch_version.major == 1 and torch_version.minor < 9:
+ # torch distributed launch cmd with torch < 1.9
cmd = [
- sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}",
- f"--master_addr={master_addr}", f"--master_port={master_port}", f"--nnodes={num_nodes}",
- f"--node_rank={node_rank}"
+ sys.executable,
+ "-m",
+ "torch.distributed.launch",
+ f"--nproc_per_node={nproc_per_node}",
+ f"--master_addr={master_addr}",
+ f"--master_port={master_port}",
+ f"--nnodes={num_nodes}",
+ f"--node_rank={node_rank}",
]
else:
# extra launch args for torch distributed launcher with torch >= 1.9
- default_torchrun_rdzv_args = dict(rdzv_backend="c10d",
- rdzv_endpoint=f"{master_addr}:{master_port}",
- rdzv_id="colossalai-default-job")
+ default_torchrun_rdzv_args = dict(master_addr=master_addr, master_port=master_port)
# update rdzv arguments
for key in default_torchrun_rdzv_args.keys():
@@ -174,19 +178,28 @@ def get_launch_command(
value = extra_launch_args.pop(key)
default_torchrun_rdzv_args[key] = value
- if torch_version.minor < 10:
+ if torch_version.major == 1 and torch_version.minor == 9:
+ # torch distributed launch cmd with torch == 1.9
cmd = [
- sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}",
- f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
+ sys.executable,
+ "-m",
+ "torch.distributed.run",
+ f"--nproc_per_node={nproc_per_node}",
+ f"--nnodes={num_nodes}",
+ f"--node_rank={node_rank}",
]
else:
+ # torch distributed launch cmd with torch > 1.9
cmd = [
- "torchrun", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}"
+ "torchrun",
+ f"--nproc_per_node={nproc_per_node}",
+ f"--nnodes={num_nodes}",
+ f"--node_rank={node_rank}",
]
cmd += _arg_dict_to_list(default_torchrun_rdzv_args)
cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args
- cmd = ' '.join(cmd)
+ cmd = " ".join(cmd)
return cmd
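
# For torch >= 1.10 with the rdzv defaults above, the assembled command looks
# roughly like this (hypothetical argument values):
#   torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 \
#       --master_addr=127.0.0.1 --master_port=29500 train.py --epochs 10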
@@ -250,33 +263,39 @@ def launch_multi_processes(args: Config) -> None:
# run on local node if not hosts or hostfile is given
# add local node to host info list
active_device_pool = HostInfoList()
- localhost_info = HostInfo(hostname='127.0.0.1', port=args.ssh_port)
+ localhost_info = HostInfo(hostname="127.0.0.1", port=args.ssh_port)
active_device_pool.append(localhost_info)
# launch distributed processes
runner = MultiNodeRunner()
- curr_path = os.path.abspath('.')
+ curr_path = os.path.abspath(".")
# collect current path env
env = dict()
for k, v in os.environ.items():
# do not support multi-line env var
- if v and '\n' not in v:
+ if v and "\n" not in v:
env[k] = v
# establish remote connection
runner.connect(host_info_list=active_device_pool, workdir=curr_path, env=env)
+ # overwrite master addr when num_nodes > 1 and not specified
+ if len(active_device_pool) > 1 and args.master_addr == "127.0.0.1":
+ args.master_addr = active_device_pool.hostinfo_list[0].hostname
+
# execute distributed launching command
for node_id, hostinfo in enumerate(active_device_pool):
- cmd = get_launch_command(master_addr=args.master_addr,
- master_port=args.master_port,
- nproc_per_node=args.nproc_per_node,
- user_script=args.user_script,
- user_args=args.user_args,
- node_rank=node_id,
- num_nodes=len(active_device_pool),
- extra_launch_args=args.extra_launch_args)
+ cmd = get_launch_command(
+ master_addr=args.master_addr,
+ master_port=args.master_port,
+ nproc_per_node=args.nproc_per_node,
+ user_script=args.user_script,
+ user_args=args.user_args,
+ node_rank=node_id,
+ num_nodes=len(active_device_pool),
+ extra_launch_args=args.extra_launch_args,
+ )
runner.send(hostinfo=hostinfo, cmd=cmd)
# start training
@@ -298,7 +317,7 @@ def launch_multi_processes(args: Config) -> None:
# receive the stop status
msg_from_node = runner.recv_from_all()
- # printe node status
+ # print node status
click.echo("\n====== Stopping All Nodes =====")
for hostname, msg in msg_from_node.items():
click.echo(f"{hostname}: {msg}")
diff --git a/colossalai/cluster/__init__.py b/colossalai/cluster/__init__.py
index 2fbdfd3cc9996b1044720ef0c1669f5f67fbe8b3..b8176feb647b87c4c7da366a74f416d246d96fd8 100644
--- a/colossalai/cluster/__init__.py
+++ b/colossalai/cluster/__init__.py
@@ -1,5 +1,6 @@
from .device_mesh_manager import DeviceMeshManager
from .dist_coordinator import DistCoordinator
from .process_group_manager import ProcessGroupManager
+from .process_group_mesh import ProcessGroupMesh
-__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager']
+__all__ = ["DistCoordinator", "ProcessGroupManager", "DeviceMeshManager", "ProcessGroupMesh"]
diff --git a/colossalai/cluster/device_mesh_manager.py b/colossalai/cluster/device_mesh_manager.py
index 8754baa19792adf6c5ec79c115d36fac9a3f3c5d..e35aca5f4d7e827512d7b618f90f6dd0b86afd61 100644
--- a/colossalai/cluster/device_mesh_manager.py
+++ b/colossalai/cluster/device_mesh_manager.py
@@ -10,13 +10,14 @@ from colossalai.device.device_mesh import DeviceMesh
@dataclass
class DeviceMeshInfo:
- '''
+ """
This class is used to store the information used to initialize the device mesh.
Args:
physical_ids (List[int]): The physical ids of the current booster. For example, if we have the last 4 GPUs on a 8-devices cluster, then the physical ids should be [4, 5, 6, 7].
mesh_shape (Union[torch.Size, List[int], Tuple[int]]): The shape of the mesh. For example, if we have 4 GPUs and want a 2D mesh, the mesh shape should be [2, 2].
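+
+ A construction sketch (the values are illustrative):
+
+ ```python
+ # the last 4 GPUs of an 8-GPU cluster, arranged as a 2 x 2 mesh
+ info = DeviceMeshInfo(physical_ids=[4, 5, 6, 7], mesh_shape=(2, 2))
+ ```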
- '''
+ """
+
physical_ids: List[int]
mesh_shape: Union[torch.Size, List[int], Tuple[int]] = None
@@ -24,16 +25,18 @@ class DeviceMeshInfo:
if self.mesh_shape is not None:
world_size = len(self.physical_ids)
mesh_shape_numel = torch.Size(self.mesh_shape).numel()
- assert world_size == mesh_shape_numel, f'the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}'
+ assert (
+ world_size == mesh_shape_numel
+ ), f"the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}"
def initialize_device_mesh(device_mesh_info: DeviceMeshInfo):
- '''
+ """
This method is used to initialize the device mesh.
Args:
device_mesh_info (DeviceMeshInfo): The information used to initialize device mesh.
- '''
+ """
# parse the device mesh info
physical_devices = device_mesh_info.physical_ids
physical_mesh = torch.tensor(physical_devices)
@@ -67,13 +70,13 @@ class DeviceMeshManager:
Args:
name (str): name of the device mesh
device_mesh_info (DeviceMeshInfo): the information used to initialize the device mesh
- """
+ """
if name not in self.device_mesh_store:
device_mesh = initialize_device_mesh(device_mesh_info)
self.device_mesh_store[name] = device_mesh
return device_mesh
else:
- raise ValueError(f'Device mesh {name} already exists.')
+ raise ValueError(f"Device mesh {name} already exists.")
def get(self, name: str) -> DeviceMesh:
"""
@@ -88,7 +91,7 @@ class DeviceMeshManager:
if name in self.device_mesh_store:
return self.device_mesh_store[name]
else:
- raise ValueError(f'Device mesh {name} does not exist.')
+ raise ValueError(f"Device mesh {name} does not exist.")
def destroy(self, name: str) -> None:
"""
@@ -103,7 +106,7 @@ class DeviceMeshManager:
dist.destroy_process_group(pg)
del self.device_mesh_store[name]
else:
- raise ValueError(f'Device mesh {name} does not exist.')
+ raise ValueError(f"Device mesh {name} does not exist.")
def destroy_all(self):
"""
diff --git a/colossalai/cluster/dist_coordinator.py b/colossalai/cluster/dist_coordinator.py
index 99dde810e11251e16235573f6dd88e68b74a64b3..98191747e5b368b68206a41645e9b7d8e9fae4a3 100644
--- a/colossalai/cluster/dist_coordinator.py
+++ b/colossalai/cluster/dist_coordinator.py
@@ -20,14 +20,16 @@ class DistCoordinator(metaclass=SingletonMeta):
- master: the process with rank 0
- node master: the process with local rank 0 on the current node
- Example:
- >>> from colossalai.cluster.dist_coordinator import DistCoordinator
- >>> coordinator = DistCoordinator()
- >>>
- >>> if coordinator.is_master():
- >>> do_something()
- >>>
- >>> coordinator.print_on_master('hello world')
+
+ ```python
+ from colossalai.cluster.dist_coordinator import DistCoordinator
+ coordinator = DistCoordinator()
+
+ if coordinator.is_master():
+ do_something()
+
+ coordinator.print_on_master('hello world')
+ ```
Attributes:
rank (int): the rank of the current process
@@ -36,12 +38,13 @@ class DistCoordinator(metaclass=SingletonMeta):
"""
def __init__(self):
- assert dist.is_initialized(
- ), 'Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first.'
+ assert (
+ dist.is_initialized()
+ ), "Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first."
self._rank = dist.get_rank()
self._world_size = dist.get_world_size()
# this is often passed by launchers such as torchrun
- self._local_rank = os.environ.get('LOCAL_RANK', -1)
+ self._local_rank = os.environ.get("LOCAL_RANK", -1)
@property
def rank(self) -> int:
@@ -59,7 +62,9 @@ class DistCoordinator(metaclass=SingletonMeta):
"""
Assert that the local rank is set. This is often passed by launchers such as torchrun.
"""
- assert self.local_rank >= 0, 'The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process.'
+ assert (
+ self.local_rank >= 0
+ ), "The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process."
def is_master(self, process_group: ProcessGroup = None) -> bool:
"""
@@ -128,11 +133,13 @@ class DistCoordinator(metaclass=SingletonMeta):
other processes in the same process group. This is often useful when downloading is required
as we only want to download in one process to prevent file corruption.
- Example:
- >>> from colossalai.cluster import DistCoordinator
- >>> dist_coordinator = DistCoordinator()
- >>> with dist_coordinator.priority_execution():
- >>> dataset = CIFAR10(root='./data', download=True)
+
+ ```python
+ from colossalai.cluster import DistCoordinator
+ dist_coordinator = DistCoordinator()
+ with dist_coordinator.priority_execution():
+ dataset = CIFAR10(root='./data', download=True)
+ ```
Args:
executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
@@ -171,19 +178,19 @@ class DistCoordinator(metaclass=SingletonMeta):
"""
A function wrapper that only executes the wrapped function on the master process (rank 0).
- Example:
- >>> from colossalai.cluster import DistCoordinator
- >>> dist_coordinator = DistCoordinator()
- >>>
- >>> @dist_coordinator.on_master_only()
- >>> def print_on_master(msg):
- >>> print(msg)
+ ```python
+ from colossalai.cluster import DistCoordinator
+ dist_coordinator = DistCoordinator()
+
+ @dist_coordinator.on_master_only()
+ def print_on_master(msg):
+ print(msg)
+ ```
"""
is_master = self.is_master(process_group)
- # define an inner functiuon
+ # define an inner function
def decorator(func):
-
@functools.wraps(func)
def wrapper(*args, **kwargs):
if is_master:
diff --git a/colossalai/cluster/process_group_manager.py b/colossalai/cluster/process_group_manager.py
index e52661846f3ed6d25602252886401837613a75e3..68106b5031265fa1edaf5b9d70d510adaa5318cb 100644
--- a/colossalai/cluster/process_group_manager.py
+++ b/colossalai/cluster/process_group_manager.py
@@ -19,7 +19,7 @@ class ProcessGroupManager:
def __init__(self):
self.pg_store = dict()
- def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup:
+ def create_process_group(self, name: str, ranks: List[int], backend: str = "nccl") -> ProcessGroup:
"""
Get a process group by name. If the process group does not exist, it will be created.
@@ -36,7 +36,7 @@ class ProcessGroupManager:
self.pg_store[name] = pg
return pg
else:
- raise ValueError(f'Process group {name} already exists.')
+ raise ValueError(f"Process group {name} already exists.")
def get(self, name: str) -> ProcessGroup:
"""
@@ -51,7 +51,7 @@ class ProcessGroupManager:
if name in self.pg_store:
return self.pg_store[name]
else:
- raise ValueError(f'Process group {name} does not exist.')
+ raise ValueError(f"Process group {name} does not exist.")
def destroy(self, name: str) -> None:
"""
@@ -64,7 +64,7 @@ class ProcessGroupManager:
dist.destroy_process_group(self.pg_store[name])
del self.pg_store[name]
else:
- raise ValueError(f'Process group {name} does not exist.')
+ raise ValueError(f"Process group {name} does not exist.")
def destroy_all(self) -> None:
"""
diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..3885bc96256163f29fa197455b3d14fd5c2582a8
--- /dev/null
+++ b/colossalai/cluster/process_group_mesh.py
@@ -0,0 +1,208 @@
+import itertools
+from functools import reduce
+from operator import mul
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+
+def prod(nums: List[int]) -> int:
+ """Product of a list of numbers.
+
+ Args:
+ nums (List[int]): A list of numbers.
+
+ Returns:
+ int: The product of the numbers.
+ """
+ return reduce(mul, nums)
+
+
+class ProcessGroupMesh:
+ """A helper class to manage the process group mesh. It only describes how to organize process groups, and it's decoupled with parallel method.
+ It just initialize process groups and cache them. The parallel method should manage them and use them to do the parallel computation.
+
+ We use a ND-tuple to represent the process group mesh. And a ND-coordinate is to represent each process.
+ For example, ``(0, 1, 0)`` represents the process whose rank is 2 in a 3D process group mesh with size ``(2, 2, 2)``.
+
+ Args:
+ *size (int): The size of each dimension of the process group mesh. The product of the size must be equal to the world size.
+
+ Attributes:
+ shape (Tuple[int, ...]): The shape of the process group mesh.
+ rank (int): The rank of the current process.
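+
+ A minimal usage sketch (assumes ``torch.distributed`` is already initialized with a world size of 8; the group names are illustrative):
+
+ ```python
+ mesh = ProcessGroupMesh(2, 4)  # 2 * 4 must equal the world size
+ dp_group = mesh.get_group_along_axis(0)  # group spanning the first axis
+ tp_group = mesh.get_group_along_axis(1)  # group spanning the second axis
+ ```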
+ """
+
+ def __init__(self, *size: int) -> None:
+ assert dist.is_initialized(), "Please initialize torch.distributed first."
+ assert prod(size) == dist.get_world_size(), "The product of the size must be equal to the world size."
+ self._shape = size
+ self._rank = dist.get_rank()
+ self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)
+ self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
+ self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}
+
+ @property
+ def shape(self) -> Tuple[int, ...]:
+ return self._shape
+
+ @property
+ def rank(self) -> int:
+ return self._rank
+
+ def size(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]:
+ """Get the size of the process group mesh.
+
+ Args:
+ dim (Optional[int], optional): Dimension of the process group mesh. `None` means all dimensions. Defaults to None.
+
+ Returns:
+ Union[int, Tuple[int, ...]]: Size of the target dimension or the whole process group mesh.
+ """
+ if dim is None:
+ return self._shape
+ else:
+ return self._shape[dim]
+
+ def coordinate(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]:
+ """Get the coordinate of the process group mesh.
+
+ Args:
+ dim (Optional[int], optional): Dimension of the process group mesh. `None` means all dimensions. Defaults to None.
+
+ Returns:
+ Union[int, Tuple[int, ...]]: Coordinate of the target dimension or the whole process group mesh.
+ """
+ if dim is None:
+ return self._coord
+ else:
+ return self._coord[dim]
+
+ @staticmethod
+ def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
+ """Convert a rank to a coordinate.
+
+ Args:
+ rank (int): Rank to be converted.
+ shape (Tuple[int, ...]): Shape of the process group mesh.
+
+ Returns:
+ Tuple[int, ...]: Coordinate of the rank.
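+
+ For example, ``unravel(2, (2, 2, 2))`` returns ``(0, 1, 0)``.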
+ """
+ return np.unravel_index(rank, shape)
+
+ @staticmethod
+ def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = "raise") -> int:
+ """Convert a coordinate to a rank.
+ mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html.
+ With 'wrap', out-of-range indices are wrapped around;
+ for instance, ``ravel((0, i, 0), (1, 2, 1), 'wrap')`` returns ``i % 2``.
+
+ Args:
+ coord (Tuple[int, ...]): Coordinate to be converted.
+ shape (Tuple[int, ...]): Shape of the process group mesh.
+ mode (Optional[str]): The mode for numpy.ravel_multi_index.
+
+ Returns:
+ int: Rank of the coordinate.
+ """
+
+ assert mode in ["raise", "wrap", "clip"]
+ return np.ravel_multi_index(coord, shape, mode)
+
+ def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
+ """Get the process group with the given ranks. It the process group doesn't exist, it will be created.
+
+ Args:
+ ranks_in_group (List[int]): Ranks in the process group.
+ backend (Optional[str], optional): Backend of the process group. Defaults to None.
+
+ Returns:
+ ProcessGroup: The process group with the given ranks.
+ """
+ ranks_in_group = sorted(ranks_in_group)
+ if tuple(ranks_in_group) not in self._ranks_to_group:
+ group = dist.new_group(ranks_in_group, backend=backend)
+ self._ranks_to_group[tuple(ranks_in_group)] = group
+ self._group_to_ranks[group] = tuple(ranks_in_group)
+ return self._ranks_to_group[tuple(ranks_in_group)]
+
+ def get_ranks_in_group(self, group: ProcessGroup) -> List[int]:
+ """Get the ranks in the given process group. The process group must be created by this class.
+
+ Args:
+ group (ProcessGroup): The process group.
+
+ Returns:
+ List[int]: Ranks in the process group.
+ """
+ return list(self._group_to_ranks[group])
+
+ @staticmethod
+ def get_coords_along_axis(
+ base_coord: Tuple[int, ...], axis: int, indices_at_axis: List[int]
+ ) -> List[Tuple[int, ...]]:
+ """Get coordinates along the given axis.
+
+ Args:
+ base_coord (Tuple[int, ...]): Base coordinate which the coordinates along the axis are based on.
+ axis (int): Axis along which the coordinates are generated.
+ indices_at_axis (List[int]): Indices at the axis.
+
+ Returns:
+ List[Tuple[int, ...]]: Coordinates along the axis.
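+
+ For example, ``get_coords_along_axis((0, 1, 0), 0, [0, 1])`` returns ``[(0, 1, 0), (1, 1, 0)]``.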
+ """
+ coords_in_group = []
+ for idx in indices_at_axis:
+ coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
+ return coords_in_group
+
+ def create_group_along_axis(
+ self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
+ ) -> ProcessGroup:
+ """Create all process groups along the given axis, and return the one which the current process belongs to.
+
+ Args:
+ axis (int): Axis along which the process groups are created.
+ indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
+ backend (Optional[str], optional): Backend of the process group. Defaults to None.
+
+ Returns:
+ ProcessGroup: The process group along the given axis which the current process belongs to.
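+
+ For example, with a mesh of shape ``(2, 2)`` and ``axis=1``, the groups with ranks ``[0, 1]`` and ``[2, 3]`` are created on every rank (``dist.new_group`` is a collective call), and the one containing the current rank is returned.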
+ """
+ indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
+ reduced_shape = list(self._shape)
+ # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
+ reduced_shape[axis] = 1
+ target_group = None
+ # use Cartesian product to generate all combinations of coordinates
+ for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
+ coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
+ ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
+ group = self.get_group(ranks_in_group, backend=backend)
+ if self._rank in ranks_in_group:
+ target_group = group
+ return target_group
+
+ def get_group_along_axis(
+ self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
+ ) -> ProcessGroup:
+ """Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
+
+ Args:
+ axis (int): Axis along which the process groups are created.
+ indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
+ backend (Optional[str], optional): Backend of the process group. Defaults to None.
+
+ Returns:
+ ProcessGroup: The process group along the given axis which the current process belongs to.
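+
+ For example, with a mesh of shape ``(2, 2)``, rank 0 calling ``get_group_along_axis(1)`` returns the group of ranks ``[0, 1]``, creating it first if it is not cached yet.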
+ """
+ indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
+ coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
+ ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
+ if ranks_in_group not in self._ranks_to_group:
+ # no need to cache it explicitly, since it will be cached in `create_group_along_axis`
+ return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
+ return self._ranks_to_group[ranks_in_group]
diff --git a/colossalai/communication/__init__.py b/colossalai/communication/__init__.py
deleted file mode 100644
index 220481b7af15bcada443ed7c9f8c91350a5f76b1..0000000000000000000000000000000000000000
--- a/colossalai/communication/__init__.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from .collective import all_gather, reduce_scatter, all_reduce, broadcast, reduce
-from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward, send_backward,
- send_backward_recv_backward, send_forward_recv_backward, send_forward_backward_recv_forward_backward,
- recv_forward, recv_backward)
-from .ring import ring_forward
-from .utils import send_obj_meta, recv_obj_meta
-
-__all__ = [
- 'all_gather',
- 'reduce_scatter',
- 'all_reduce',
- 'broadcast',
- 'reduce',
- 'send_forward',
- 'send_forward_recv_forward',
- 'send_forward_backward_recv_forward_backward',
- 'send_backward',
- 'send_backward_recv_backward',
- 'send_backward_recv_forward',
- 'send_forward_recv_backward',
- 'recv_backward',
- 'recv_forward',
- 'ring_forward',
- 'send_obj_meta',
- 'recv_obj_meta',
-]
diff --git a/colossalai/communication/p2p.py b/colossalai/communication/p2p.py
deleted file mode 100644
index 0200cd3c6553dc8e2b3bbaa60ffb1d416c699370..0000000000000000000000000000000000000000
--- a/colossalai/communication/p2p.py
+++ /dev/null
@@ -1,405 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from typing import List, Tuple, Union
-import torch
-import torch.distributed as dist
-
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.utils import get_current_device
-from functools import reduce
-import operator
-from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor
-
-TensorShape = Union[torch.Size, List[int], Tuple[int]]
-
-
-def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]:
- """get the exact tensor shape when communicating and return whether the tensor is a chunk
-
- Args:
- tensor_shape (:class:`torch.Size`): shape of tensor
- chunk_tensor (bool, optional): whether to chunk tensor, defaults to False
-
- Returns:
- Tuple[Union[:class:`torch.Size`, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor
- """
- if chunk_tensor:
- tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
- tensor_parallel_world_size = gpc.get_world_size(ParallelMode.TENSOR)
- if tensor_chunk_shape % tensor_parallel_world_size == 0:
- tensor_chunk_shape = tensor_chunk_shape // tensor_parallel_world_size
- else:
- tensor_chunk_shape = tensor_shape
- chunk_tensor = False
- else:
- tensor_chunk_shape = tensor_shape
- return tensor_chunk_shape, chunk_tensor
-
-
-def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors):
- if isinstance(recv_shapes, torch.Size):
- recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors)
- buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
- return buffer_recv, recv_split
- buffer_recv = []
- for recv_shape in recv_shapes:
- recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors)
- tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
- buffer_recv.append(tensor_recv)
- return buffer_recv, recv_split
-
-
-def process_object_to_send(object_send, scatter_gather_tensors):
- if isinstance(object_send, torch.Tensor):
- send_split = _get_tensor_shape(object_send.shape, scatter_gather_tensors)[1]
- if send_split:
- object_send = split_tensor_into_1d_equal_chunks(object_send)
- return object_send
-
- object_send_list = []
- for tensor_send in object_send:
- send_split = _get_tensor_shape(tensor_send.shape, scatter_gather_tensors)[1]
- if send_split:
- object_send_list.append(split_tensor_into_1d_equal_chunks(tensor_send))
- else:
- object_send_list.append(tensor_send)
- object_send = tuple(object_send_list)
-
- return object_send
-
-
-def filling_ops_queue(obj, comm_op, comm_rank, ops_queue):
- if isinstance(obj, torch.Tensor):
- op_to_add = dist.P2POp(comm_op, obj, comm_rank)
- ops_queue.append(op_to_add)
- else:
- for tensor_to_comm in obj:
- op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank)
- ops_queue.append(op_to_add)
-
-
-def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None,
- object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None,
- recv_prev: bool = False,
- recv_next: bool = False,
- recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
- recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
- prev_rank: int = None,
- next_rank: int = None,
- dtype: torch.dtype = None,
- scatter_gather_tensors: bool = False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
- """
- Adapted from megatron.p2p_communication.
- Communicate tensors between stages. Used as helper method in other
- communication methods that are used in pipeline schedule.
- Takes the following arguments:
- object_send_next (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to next rank (no tensor sent if
- set to None).
- object_send_prev (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to prev rank (no tensor sent if
- set to None).
- recv_prev (bool): boolean for whether tensor should be received from
- previous rank.
- recv_next (bool): boolean for whether tensor should be received from
- next rank.
- recv_prev_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received from the previous stage, defaults to None.
- recv_next_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received from the next stage, defaults to None.
- prev_rank (int): the rank of the previous pipeline stage, defaults to None,
- next_rank (int): the rank of the next pipeline stage, defaults to None,
- dtype (torch.dtype): data type of intermediate buffers, defaults to None
- scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False
-
- Returns:
- Tuple[Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]]: returns tensor_recv_prev, tensor_recv_next
- """
-
- # Create placeholder tensors for receive in forward and backward directions
- # if needed.
- tensor_recv_prev = None
- tensor_recv_next = None
-
- if recv_prev:
- assert recv_prev_shape is not None
- tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(recv_prev_shape, dtype,
- scatter_gather_tensors)
-
- if recv_next:
- assert recv_next_shape is not None
- tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(recv_next_shape, dtype,
- scatter_gather_tensors)
-
- if object_send_prev is not None or recv_prev:
- if prev_rank is None:
- prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
-
- if object_send_next is not None or recv_next:
- if next_rank is None:
- next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
-
- if object_send_prev is not None:
- object_send_prev = process_object_to_send(object_send_prev, scatter_gather_tensors)
-
- if object_send_next is not None:
- object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors)
-
- ops = []
- if object_send_prev is not None:
- filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)
-
- if tensor_recv_prev is not None:
- filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)
-
- if tensor_recv_next is not None:
- filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
-
- if object_send_next is not None:
- filling_ops_queue(object_send_next, dist.isend, next_rank, ops)
-
- if len(ops) > 0:
- reqs = dist.batch_isend_irecv(ops)
- for req in reqs:
- req.wait()
- # To protect against race condition when using batch_isend_irecv().
- torch.cuda.synchronize()
-
- if recv_prev and recv_prev_split:
- if isinstance(tensor_recv_prev, torch.Tensor):
- tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
- else:
- for index in range(len(tensor_recv_prev)):
- tensor_recv_prev[index] = gather_split_1d_tensor(tensor_recv_prev[index]).view(
- recv_prev_shape[index]).requires_grad_()
-
- if recv_next and recv_next_split:
- if isinstance(tensor_recv_next, torch.Tensor):
- tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
- else:
- for index in range(len(tensor_recv_next)):
- tensor_recv_next[index] = gather_split_1d_tensor(tensor_recv_next[index]).view(
- recv_next_shape[index]).requires_grad_()
-
- return tensor_recv_prev, tensor_recv_next
-
-
-def recv_forward(input_tensor_shape,
- prev_rank=None,
- dtype=torch.float,
- scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
- """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
-
- Args:
- input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
- prev_rank (int, optional): The rank of the source of the tensor.
-
- Returns:
- Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor or input tensor list.
- """
- if gpc.is_pipeline_first_stage():
- input_tensor = None
- else:
- input_tensor, _ = _communicate(recv_prev=True,
- recv_prev_shape=input_tensor_shape,
- prev_rank=prev_rank,
- dtype=dtype,
- scatter_gather_tensors=scatter_gather_tensors)
- return input_tensor
-
-
-def recv_backward(output_grad_shape,
- next_rank=None,
- dtype=torch.float,
- scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
- """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
-
- Args:
- output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
- next_rank (int, optional): The rank of the source of the tensor.
-
- Returns:
- Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor or gradident tensor list.
- """
- if gpc.is_pipeline_last_stage():
- output_tensor_grad = None
- else:
- _, output_tensor_grad = _communicate(recv_next=True,
- recv_next_shape=output_grad_shape,
- next_rank=next_rank,
- dtype=dtype,
- scatter_gather_tensors=scatter_gather_tensors)
- return output_tensor_grad
-
-
-def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) -> None:
- """Sends the input tensor to the next stage in pipeline.
-
- Args:
- output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
- next_rank (int, optional): The rank of the recipient of the tensor.
- """
- if not gpc.is_pipeline_last_stage():
- _communicate(object_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
-
-
-def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None:
- """Sends the gradient tensor to the previous stage in pipeline.
-
- Args:
- input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent
- prev_rank (int, optional): The rank of the recipient of the tensor
- """
- if not gpc.is_pipeline_first_stage():
- _communicate(object_send_prev=input_tensor_grad,
- prev_rank=prev_rank,
- scatter_gather_tensors=scatter_gather_tensors)
-
-
-def send_forward_recv_backward(output_tensor,
- output_grad_shape,
- recv_next=True,
- next_rank=None,
- dtype=torch.float,
- scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
- """Batched communication operation. Sends the input tensor to the
- next stage in pipeline, while receives the gradient tensor from the
- next stage in pipeline as the input gradient tensor of this stage.
-
- Args:
- output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
- output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
-
- Returns:
- Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
- """
- if gpc.is_pipeline_last_stage():
- output_tensor_grad = None
- else:
- _, output_tensor_grad = _communicate(object_send_next=output_tensor,
- recv_next=recv_next,
- recv_next_shape=output_grad_shape,
- next_rank=next_rank,
- dtype=dtype,
- scatter_gather_tensors=scatter_gather_tensors)
- return output_tensor_grad
-
-
-def send_backward_recv_forward(input_tensor_grad,
- input_tensor_shape,
- recv_prev=True,
- prev_rank=None,
- dtype=torch.float,
- scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
- """Batched communication operation. Sends the gradient tensor to the
- previous stage in pipeline, while receives the output tensor from the
- previous stage in pipeline as the input of this stage.
-
- Args:
- input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
- input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
-
- Returns:
- Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
- """
- if gpc.is_pipeline_first_stage():
- input_tensor = None
- else:
- input_tensor, _ = _communicate(object_send_prev=input_tensor_grad,
- recv_prev=recv_prev,
- recv_prev_shape=input_tensor_shape,
- prev_rank=prev_rank,
- dtype=dtype,
- scatter_gather_tensors=scatter_gather_tensors)
- return input_tensor
-
-
-def send_forward_recv_forward(output_tensor,
- input_tensor_shape,
- recv_prev=True,
- prev_rank=None,
- next_rank=None,
- dtype=torch.float,
- scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
- """Batched communication operation. Sends the input tensor to the
- next stage in pipeline, while receives the output tensor from the
- previous stage in pipeline as the input of this stage.
-
- Args:
- output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
- input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
-
- Returns:
- Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
- """
- input_tensor, _ = _communicate(object_send_next=output_tensor,
- recv_prev=recv_prev,
- recv_prev_shape=input_tensor_shape,
- prev_rank=prev_rank,
- next_rank=next_rank,
- dtype=dtype,
- scatter_gather_tensors=scatter_gather_tensors)
- return input_tensor
-
-
-def send_backward_recv_backward(input_tensor_grad,
- output_grad_shape,
- recv_next=True,
- prev_rank=None,
- next_rank=None,
- dtype=torch.float,
- scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
- """Batched communication operation. Sends the gradient tensor to the
- previous stage in pipeline, while receives the gradient tensor from the
- next member in pipeline as the input of this stage.
-
- Args:
- input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
- output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
-
- Returns:
- Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
- """
- _, output_tensor_grad = _communicate(object_send_prev=input_tensor_grad,
- recv_next=recv_next,
- recv_next_shape=output_grad_shape,
- prev_rank=prev_rank,
- next_rank=next_rank,
- dtype=dtype,
- scatter_gather_tensors=scatter_gather_tensors)
- return output_tensor_grad
-
-
-def send_forward_backward_recv_forward_backward(
- output_tensor,
- input_tensor_grad,
- input_tensor_shape,
- output_grad_shape,
- recv_prev=True,
- recv_next=True,
- prev_rank=None,
- next_rank=None,
- dtype=torch.float,
- scatter_gather_tensors=False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
- """Batched communication operation. Sends the input tensor to the next stage in pipeline and
- the gradient tensor to the previous stage, while receives the input gradient tensor from the
- next stage and the input tensor from the previous stage.
-
- Args:
- output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the next.
- input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the previous.
- input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received from the previous.
- output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received from the next.
-
- Returns:
- Tuple(Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]], Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): (the input tensor, the input gradient tensor)
- """
- input_tensor, output_tensor_grad = _communicate(object_send_next=output_tensor,
- object_send_prev=input_tensor_grad,
- recv_prev=recv_prev,
- recv_next=recv_next,
- recv_prev_shape=input_tensor_shape,
- recv_next_shape=output_grad_shape,
- prev_rank=prev_rank,
- next_rank=next_rank,
- dtype=dtype,
- scatter_gather_tensors=scatter_gather_tensors)
- return input_tensor, output_tensor_grad
diff --git a/colossalai/communication/ring.py b/colossalai/communication/ring.py
deleted file mode 100644
index aece7574b7c41cac3b16cd5891b1e26d0ede9c36..0000000000000000000000000000000000000000
--- a/colossalai/communication/ring.py
+++ /dev/null
@@ -1,56 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import torch
-
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.utils import get_current_device, synchronize
-
-
-def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor:
- """Sends a tensor to the next member and receives a tensor from the previous member.
- This function returns the received tensor from the previous member.
-
- Args:
- tensor_send_next (:class:`torch.Tensor`): Tensor sent to next member
- parallel_mode (ParallelMode): Parallel group mode used in this communication
-
- Returns:
- :class:`torch.Tensor`: The tensor received from the previous.
-
- Note:
- The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
- in `parallel_mode `_.
- """
- buffer_shape = tensor_send_next.size()
-
- ops = []
- current_rank = gpc.get_global_rank()
-
- tensor_recv_prev = torch.empty(buffer_shape,
- requires_grad=True,
- device=get_current_device(),
- dtype=tensor_send_next.dtype)
-
- # send to next rank
- send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next,
- gpc.get_next_global_rank(parallel_mode))
- ops.append(send_next_op)
-
- # receive from prev rank
- recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev,
- gpc.get_prev_global_rank(parallel_mode))
- ops.append(recv_prev_op)
-
- if current_rank % 2 == 0:
- ops = ops[::-1]
-
- reqs = torch.distributed.batch_isend_irecv(ops)
- for req in reqs:
- req.wait()
-
- # To protect against race condition when using batch_isend_irecv().
- synchronize()
-
- return tensor_recv_prev
diff --git a/colossalai/communication/utils.py b/colossalai/communication/utils.py
deleted file mode 100644
index ef9eceea847dd3d6cb036e87e369529dcbe0db41..0000000000000000000000000000000000000000
--- a/colossalai/communication/utils.py
+++ /dev/null
@@ -1,126 +0,0 @@
-import torch
-import torch.distributed as dist
-
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.utils import get_current_device
-from typing import Union, List, Tuple
-
-TensorShape = Union[torch.Size, List[int], Tuple[int]]
-
-
-def send_meta_helper(obj, next_rank, tensor_kwargs):
- send_shape = torch.tensor(obj.size(), **tensor_kwargs)
- send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs)
- dist.send(send_ndims, next_rank)
- dist.send(send_shape, next_rank)
-
-
-def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
- """Sends obj meta information before sending a specific obj.
- Since the recipient must know the shape of the obj in p2p communications,
- meta information of the obj should be sent before communications. This function
- synchronizes with :func:`recv_obj_meta`.
-
- Args:
- obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent.
- need_meta (bool, optional): If False, meta information won't be sent.
- next_rank (int): The rank of the next member in pipeline parallel group.
-
- Returns:
- bool: False
- """
- if need_meta:
- if next_rank is None:
- next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
-
- tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
- if isinstance(obj, torch.Tensor):
- send_obj_nums = torch.tensor(1, **tensor_kwargs)
- dist.send(send_obj_nums, next_rank)
- send_meta_helper(obj, next_rank, tensor_kwargs)
- else:
- send_obj_nums = torch.tensor(len(obj), **tensor_kwargs)
- dist.send(send_obj_nums, next_rank)
- for tensor_to_send in obj:
- send_meta_helper(tensor_to_send, next_rank, tensor_kwargs)
-
- return False
-
-
-def recv_meta_helper(prev_rank, tensor_kwargs):
- recv_ndims = torch.empty((), **tensor_kwargs)
- dist.recv(recv_ndims, prev_rank)
- recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
- dist.recv(recv_shape, prev_rank)
- return recv_shape
-
-
-def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
- """Receives obj meta information before receiving a specific obj.
- Since the recipient must know the shape of the obj in p2p communications,
- meta information of the obj should be received before communications. This function
- synchronizes with :func:`send_obj_meta`.
-
- Args:
- obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received.
- prev_rank (int): The rank of the source of the obj.
-
- Returns:
- Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received.
- """
- if obj_shape is None:
- if prev_rank is None:
- prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
-
- tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
- recv_obj_nums = torch.empty((), **tensor_kwargs)
- dist.recv(recv_obj_nums, prev_rank)
- if recv_obj_nums.item() == 1:
- recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
- obj_shape = torch.Size(recv_shape)
- else:
- obj_shape = []
- for i in range(recv_obj_nums.item()):
- recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
- obj_shape.append(torch.Size(recv_shape))
-
- return obj_shape
-
-
-def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:
- """Break a tensor into equal 1D chunks.
-
- Args:
- tensor (:class:`torch.Tensor`): Tensor to be split before communication.
- new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor.
-
- Returns:
- :class:`torch.Tensor`: The split tensor
- """
- partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D)
- start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)
- end_index = start_index + partition_size
- if new_buffer:
- data = torch.empty(partition_size, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
- data.copy_(tensor.view(-1)[start_index:end_index])
- else:
- data = tensor.view(-1)[start_index:end_index]
- return data
-
-
-def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor:
- """Opposite of above function, gather values from model parallel ranks.
-
- Args:
- tensor (:class:`torch.Tensor`): Tensor to be gathered after communication.
- Returns:
- :class:`torch.Tensor`: The gathered tensor.
- """
- world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
- numel = torch.numel(tensor)
- numel_gathered = world_size * numel
- gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
- chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
- dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
- return gathered
diff --git a/colossalai/constants.py b/colossalai/constants.py
deleted file mode 100644
index 6cf9085f9fbb63ea18d2712f99c08f24b539245d..0000000000000000000000000000000000000000
--- a/colossalai/constants.py
+++ /dev/null
@@ -1,32 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence']
-TENSOR_PARALLEL_MODE = 'tensor_parallel_mode'
-
-# initializer
-INITIALIZER_MAPPING = {
- 'data': 'Initializer_Data',
- 'tensor': 'Initializer_Tensor',
- 'pipeline': 'Initializer_Pipeline',
- 'embedding': 'Initializer_Embedding',
- '1d': 'Initializer_1D',
- '2d': 'Initializer_2D',
- '2.5d': 'Initializer_2p5D',
- '3d': 'Initializer_3D',
- 'sequence': 'Initializer_Sequence',
- 'model': 'Initializer_Model',
- 'moe': 'Initializer_Moe'
-}
-
-# 3D parallelism groups
-INPUT_GROUP_3D = 'input_group_3d'
-WEIGHT_GROUP_3D = 'weight_group_3d'
-OUTPUT_GROUP_3D = 'output_group_3d'
-INPUT_X_WEIGHT_3D = 'input_x_weight_group_3d'
-OUTPUT_X_WEIGHT_3D = 'output_x_weight_group_3d'
-
-# Attributes of tensor parallel parameters
-IS_TENSOR_PARALLEL = 'is_tensor_parallel'
-NUM_PARTITIONS = 'num_partitions'
-TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS]
diff --git a/colossalai/context/__init__.py b/colossalai/context/__init__.py
index 50178b5fa850777f8455798cc6ab9d7254c5a9fe..ab57301bb9108a72ebe59cbdbba07caa35f46bb1 100644
--- a/colossalai/context/__init__.py
+++ b/colossalai/context/__init__.py
@@ -1,6 +1,8 @@
from .config import Config, ConfigException
-from .parallel_context import ParallelContext
-from .parallel_mode import ParallelMode
-from .moe_context import MOE_CONTEXT
-from .process_group_initializer import *
-from .random import *
+
+# from .moe_context import MOE_CONTEXT
+
+__all__ = [
+ "Config",
+ "ConfigException",
+]
diff --git a/colossalai/context/config.py b/colossalai/context/config.py
index 8903707708df96eac7a0a70343e37e984e6fabed..05a2e4bf044a0561012f62d521cf1fb5c8500a9b 100644
--- a/colossalai/context/config.py
+++ b/colossalai/context/config.py
@@ -5,6 +5,7 @@ import inspect
import sys
from importlib.machinery import SourceFileLoader
from pathlib import Path
+
from colossalai.logging import get_dist_logger
@@ -41,7 +42,7 @@ class Config(dict):
self.__setattr__(key, value)
def update(self, config):
- assert isinstance(config, (Config, dict)), 'can only update dictionary or Config objects.'
+ assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects."
for k, v in config.items():
self._add_item(k, v)
return self
@@ -66,11 +67,11 @@ class Config(dict):
elif isinstance(filename, Path):
filepath = filename.absolute()
- assert filepath.exists(), f'{filename} is not found, please check your configuration path'
+ assert filepath.exists(), f"{filename} is not found, please check your configuration path"
# check extension
extension = filepath.suffix
- assert extension == '.py', 'only .py files are supported'
+ assert extension == ".py", "only .py files are supported"
# import the config as module
remove_path = False
@@ -86,13 +87,13 @@ class Config(dict):
config = Config()
for k, v in module.__dict__.items():
- if k.startswith('__') or inspect.ismodule(v) or inspect.isclass(v):
+ if k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v):
continue
else:
config._add_item(k, v)
logger = get_dist_logger()
- logger.debug('variables which starts with __, is a module or class declaration are omitted in config file')
+ logger.debug("variables that start with __, or are module or class declarations, are omitted in the config file")
# remove module
del sys.modules[module_name]
diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py
index b41f4072a4052113b3e3a79a20c0278b9fed8295..066dfc7222e116345ae9c3af11438f0fef46db04 100644
--- a/colossalai/context/moe_context.py
+++ b/colossalai/context/moe_context.py
@@ -3,21 +3,19 @@ from typing import Tuple
import torch
import torch.distributed as dist
-from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.singleton_meta import SingletonMeta
-from colossalai.tensor import ProcessGroup
+from colossalai.legacy.tensor import ProcessGroup
def _check_sanity():
- from colossalai.core import global_context as gpc
+ from colossalai.legacy.core import global_context as gpc
+
if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
- raise NotImplementedError("Moe is not compatible with tensor or "
- "pipeline parallel at present.")
+ raise NotImplementedError("Moe is not compatible with tensor or pipeline parallel at present.")
class MoeParallelInfo:
- """Moe parallelism information, storing parallel sizes and groups.
- """
+ """Moe parallelism information, storing parallel sizes and groups."""
def __init__(self, ep_size: int, dp_size: int):
_check_sanity()
@@ -61,10 +59,12 @@ class MoeContext(metaclass=SingletonMeta):
self.world_size = dist.get_world_size()
- from colossalai.core import global_context as gpc
- self.max_ep_size = gpc.config.get('max_ep_size', self.world_size)
- assert self.world_size % self.max_ep_size == 0, \
- "Maximum expert parallel size must be a factor of the number of GPUs"
+ from colossalai.legacy.core import global_context as gpc
+
+ self.max_ep_size = gpc.config.get("max_ep_size", self.world_size)
+ assert (
+ self.world_size % self.max_ep_size == 0
+ ), "Maximum expert parallel size must be a factor of the number of GPUs"
self.min_dp_size = self.world_size // self.max_ep_size
# Enabling kernel optimization may raise error in some cases
@@ -72,6 +72,7 @@ class MoeContext(metaclass=SingletonMeta):
self.use_kernel_optim = use_kernel_optim
from .random import moe_set_seed
+
moe_set_seed(seed)
self.has_setup = True
@@ -89,11 +90,13 @@ class MoeContext(metaclass=SingletonMeta):
number of local experts, the MoeParallelInfo of the current ep_size
"""
- gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
- lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less
+ gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
+ lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less
- assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \
- " is not a multiple of ep size or vice versa."
+ assert gt_flag or lt_flag, (
+ "Automatic experts placement does not support an expert number"
+ " that is not a multiple of ep size or vice versa."
+ )
# If the number of experts is greater than the maximum expert parallel size, a.k.a. ep_size,
# there are multiple experts on each GPU and each GPU has different experts
diff --git a/colossalai/context/parallel_mode.py b/colossalai/context/parallel_mode.py
deleted file mode 100644
index 1cf6fa53dc1e5c31fbaf1c9140e0915419af704c..0000000000000000000000000000000000000000
--- a/colossalai/context/parallel_mode.py
+++ /dev/null
@@ -1,49 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from enum import Enum
-
-
-# parallel modes
-class ParallelMode(Enum):
- """This is an enumeration class containing all possible parallel modes.
- """
-
- GLOBAL = 'global'
-
- # common parallel
- DATA = 'data'
-
- # model parallel - containing tensor and pipeline parallel groups
- # this is added to facilitate amp and grad clipping in hybrid parallel
- MODEL = 'model'
-
- # pipeline parallel
- PIPELINE = 'pipe'
-
- # containing all ranks in tensor parallel
- TENSOR = 'tensor'
-
- # sequence parallel
- SEQUENCE = 'sequence'
- SEQUENCE_DP = 'sequence_dp'
-
- # 1D Parallel
- PARALLEL_1D = '1d'
-
- # 2D parallel
- PARALLEL_2D_ROW = '2d_row'
- PARALLEL_2D_COL = '2d_col'
-
- # 3D parallel
- PARALLEL_3D_INPUT = '3d_input'
- PARALLEL_3D_WEIGHT = '3d_weight'
- PARALLEL_3D_OUTPUT = '3d_output'
- PARALLEL_3D_INPUT_X_WEIGHT = "3d_input_x_weight"
- PARALLEL_3D_OUTPUT_X_WEIGHT = "3d_output_x_weight"
-
- # 2.5D parallel
- PARALLEL_2P5D_ROW = '2p5d_row'
- PARALLEL_2P5D_COL = '2p5d_col'
- PARALLEL_2P5D_DEP = '2p5d_dep'
- PARALLEL_2P5D_XZ = '2p5d_xz'
diff --git a/colossalai/context/process_group_initializer/__init__.py b/colossalai/context/process_group_initializer/__init__.py
deleted file mode 100644
index d3937a9474376f0ecb7af612121bb4c3e5f5a497..0000000000000000000000000000000000000000
--- a/colossalai/context/process_group_initializer/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from .initializer_1d import Initializer_1D
-from .initializer_2d import Initializer_2D
-from .initializer_2p5d import Initializer_2p5D
-from .initializer_3d import Initializer_3D
-from .initializer_data import Initializer_Data
-from .initializer_pipeline import Initializer_Pipeline
-from .initializer_sequence import Initializer_Sequence
-from .initializer_tensor import Initializer_Tensor
-from .initializer_model import Initializer_Model
-from .process_group_initializer import ProcessGroupInitializer
-
-__all__ = [
- 'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline', 'Initializer_Data', 'Initializer_2p5D',
- 'Initializer_2D', 'Initializer_3D', 'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model'
-]
diff --git a/colossalai/context/process_group_initializer/initializer_pipeline.py b/colossalai/context/process_group_initializer/initializer_pipeline.py
deleted file mode 100644
index 0ddb52f63e22f29aff9920d5cdd2aba1748e1eb6..0000000000000000000000000000000000000000
--- a/colossalai/context/process_group_initializer/initializer_pipeline.py
+++ /dev/null
@@ -1,56 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from torch import distributed as dist
-
-from colossalai.registry import DIST_GROUP_INITIALIZER
-
-from ..parallel_mode import ParallelMode
-from .process_group_initializer import ProcessGroupInitializer
-
-
-@DIST_GROUP_INITIALIZER.register_module
-class Initializer_Pipeline(ProcessGroupInitializer):
- """A ProcessGroupInitializer for pipeline parallelism.
-
- Args:
- rank (int): The rank of current process
- world_size (int): Size of whole communication world
- config (Config): Running configuration
- data_parallel_size (int): Size of data parallel
- pipeline_parallel_size (int): Size of pipeline parallel
- tensor_parallel_size (int): Size of tensor parallel
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.data_group_size = self.world_size // self.data_parallel_size
- self.pipeline_stage_size = self.data_group_size // self.pipeline_parallel_size
-
- def init_dist_group(self):
- """Initialize pipeline parallel groups, and assign local_ranks and groups to each gpu.
-
- Returns:
- List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]:
- A Pipeline parallelism's information in list of tuples.
- """
- dist_settings = list()
- for i in range(self.data_parallel_size):
- for j in range(self.pipeline_stage_size):
- pipe_ranks = list(
- range(i * self.data_group_size + j, (i + 1) * self.data_group_size, self.pipeline_stage_size))
- pipe_group_size = len(pipe_ranks)
- pipe_group = dist.new_group(pipe_ranks)
- group_cpu = dist.new_group(pipe_ranks, backend='gloo') if dist.get_backend() != 'gloo' else pipe_group
-
- if self.rank in pipe_ranks:
- local_rank = pipe_ranks.index(self.rank)
- group_world_size = pipe_group_size
- process_group = pipe_group
- cpu_group = group_cpu
- ranks_in_group = pipe_ranks
- dist_settings.append(
- tuple((local_rank, group_world_size, process_group, cpu_group, ranks_in_group,
- ParallelMode.PIPELINE)))
-
- return dist_settings
diff --git a/colossalai/context/random/__init__.py b/colossalai/context/random/__init__.py
deleted file mode 100644
index d64b993257c1574706ee5028224692b4e666fc19..0000000000000000000000000000000000000000
--- a/colossalai/context/random/__init__.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from ._helper import (
- add_seed,
- get_current_mode,
- get_seeds,
- get_states,
- moe_set_seed,
- reset_seeds,
- seed,
- set_mode,
- set_seed_states,
- sync_states,
- with_seed,
-)
-
-__all__ = [
- 'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds', 'get_states', 'get_current_mode', 'set_seed_states',
- 'sync_states', 'moe_set_seed', 'reset_seeds'
-]
diff --git a/colossalai/context/singleton_meta.py b/colossalai/context/singleton_meta.py
index 8ca335119d52ad2a212b1e0c578202b2fc6bb60f..3088b0dffaace5da66de7525e8598a58e545b80d 100644
--- a/colossalai/context/singleton_meta.py
+++ b/colossalai/context/singleton_meta.py
@@ -16,6 +16,7 @@ class SingletonMeta(type):
instance = super().__call__(*args, **kwargs)
cls._instances[cls] = instance
else:
- assert len(args) == 0 and len(
- kwargs) == 0, f'{cls.__name__} is a singleton class and a instance has been created.'
+ assert (
+ len(args) == 0 and len(kwargs) == 0
+ ), f"{cls.__name__} is a singleton class and an instance has been created."
return cls._instances[cls]
diff --git a/colossalai/core.py b/colossalai/core.py
deleted file mode 100644
index 153247bbed9c65db0b2255247137fa9a64a693fa..0000000000000000000000000000000000000000
--- a/colossalai/core.py
+++ /dev/null
@@ -1,6 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from colossalai.context.parallel_context import global_context
-
-__all__ = ['global_context']
\ No newline at end of file
diff --git a/colossalai/device/__init__.py b/colossalai/device/__init__.py
index 689189998c3f6490145ba2648c522570c6f40b4c..34a7d2526fdad1fa1e3622d7b5f0a6bf419d9b50 100644
--- a/colossalai/device/__init__.py
+++ b/colossalai/device/__init__.py
@@ -1,4 +1,4 @@
from .alpha_beta_profiler import AlphaBetaProfiler
from .calc_pipeline_strategy import alpa_dp
-__all__ = ['AlphaBetaProfiler', 'alpa_dp']
+__all__ = ["AlphaBetaProfiler", "alpa_dp"]
diff --git a/colossalai/device/alpha_beta_profiler.py b/colossalai/device/alpha_beta_profiler.py
index af2b10928c6f2f99a429aaa413d527d77d52faf0..88520b2a14d08cb379e255485dfb9a228bf0da63 100644
--- a/colossalai/device/alpha_beta_profiler.py
+++ b/colossalai/device/alpha_beta_profiler.py
@@ -13,7 +13,7 @@ FRAMEWORK_LATENCY = 0
class AlphaBetaProfiler:
- '''
+ """
Profile alpha and beta value for a given device list.
Usage:
@@ -27,17 +27,19 @@ class AlphaBetaProfiler:
(1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
(1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11),
(4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)}
- '''
-
- def __init__(self,
- physical_devices: List[int],
- alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
- ctype: str = 'a',
- warmup: int = 5,
- repeat: int = 25,
- latency_iters: int = 5,
- homogeneous_tolerance: float = 0.1):
- '''
+ """
+
+ def __init__(
+ self,
+ physical_devices: List[int],
+ alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
+ ctype: str = "a",
+ warmup: int = 5,
+ repeat: int = 25,
+ latency_iters: int = 5,
+ homogeneous_tolerance: float = 0.1,
+ ):
+ """
Args:
physical_devices: A list of device id, each element inside it is the global rank of that device.
alpha_beta_dict: A dict which maps a process group to alpha-beta value pairs.
@@ -45,7 +47,7 @@ class AlphaBetaProfiler:
warmup: Number of warmup iterations.
repeat: Number of iterations to measure.
latency_iters: Number of iterations to measure latency.
- '''
+ """
self.physical_devices = physical_devices
self.ctype = ctype
self.world_size = len(physical_devices)
@@ -123,7 +125,7 @@ class AlphaBetaProfiler:
return (None, None)
def profile_latency(self, process_group, pg_handler):
- '''
+ """
This function is used to profile the latency of the given process group with a series of bytes.
Args:
@@ -132,7 +134,7 @@ class AlphaBetaProfiler:
Returns:
latency: None if the latency is not measured, otherwise the median of the latency_list.
- '''
+ """
latency_list = []
for i in range(self.latency_iters):
nbytes = int(BYTE << i)
@@ -148,26 +150,26 @@ class AlphaBetaProfiler:
return latency
def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)):
- '''
+ """
This function is used to profile the bandwidth of the given process group.
Args:
process_group: A tuple of global rank of the process group.
pg_handler: The handler of the process group.
- '''
+ """
(_, bandwidth) = self._profile(process_group, pg_handler, maxbytes)
return bandwidth
def profile_ab(self):
- '''
+ """
This method is used to profile the alpha and beta values for a given device list.
Returns:
alpha_beta_dict: A dict which maps process group to its alpha and beta value.
- '''
+ """
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
rank = dist.get_rank()
- global_pg_handler = dist.new_group(self.physical_devices)
+ dist.new_group(self.physical_devices)
def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
assert rank in process_group
@@ -197,7 +199,7 @@ class AlphaBetaProfiler:
dist.broadcast_object_list(broadcast_list, src=process_group[0])
alpha_beta_dict[process_group] = tuple(broadcast_list)
- # add symmetry pair to the apha_beta_dict
+ # add symmetry pair to the alpha_beta_dict
symmetry_ab_dict = {}
for process_group, alpha_beta_pair in alpha_beta_dict.items():
symmetry_process_group = (process_group[1], process_group[0])
@@ -208,7 +210,7 @@ class AlphaBetaProfiler:
return alpha_beta_dict
def search_best_logical_mesh(self):
- '''
+ """
         This method is used to search for the best logical mesh for the given device list.
         The best logical mesh is determined in the following steps:
@@ -232,19 +234,19 @@ class AlphaBetaProfiler:
>>> best_logical_mesh = profiler.search_best_logical_mesh()
>>> print(best_logical_mesh)
[[0, 1], [2, 3]]
- '''
+ """
def _power_of_two(integer):
return integer & (integer - 1) == 0
def _detect_homogeneous_device(alpha_beta_dict):
- '''
+ """
This function is used to detect whether the devices in the alpha_beta_dict are homogeneous.
             Note: we assume that the devices in the alpha_beta_dict are homogeneous if the beta values
             of the devices fall within the range [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)]
             * base_beta.
- '''
+ """
homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {}
for process_group, (_, beta) in alpha_beta_dict.items():
if homogeneous_device_dict is None:
@@ -254,7 +256,8 @@ class AlphaBetaProfiler:
match_beta = None
for beta_value in homogeneous_device_dict.keys():
if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * (
- 1 - self.homogeneous_tolerance):
+ 1 - self.homogeneous_tolerance
+ ):
match_beta = beta_value
break
@@ -267,9 +270,9 @@ class AlphaBetaProfiler:
return homogeneous_device_dict
def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]):
- '''
+ """
This function is used to check whether the homogeneous_group contains all physical devices.
- '''
+ """
flatten_mesh = []
for process_group in homogeneous_group:
flatten_mesh.extend(process_group)
@@ -277,9 +280,9 @@ class AlphaBetaProfiler:
return len(non_duplicated_flatten_mesh) == len(self.physical_devices)
def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
- '''
+ """
This function is used to construct the largest ring in the homogeneous_group for each rank.
- '''
+ """
# Construct the ring
ring = []
ranks_in_ring = []
@@ -300,7 +303,9 @@ class AlphaBetaProfiler:
check_rank = check_rank_list.pop()
for process_group in homogeneous_group:
if check_rank in process_group:
- rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1]
+ rank_to_append = (
+ process_group[0] if process_group[1] == check_rank else process_group[1]
+ )
if rank_to_append not in ring_for_rank:
stable_status = False
rank_to_check_list.append(rank_to_append)
@@ -314,7 +319,7 @@ class AlphaBetaProfiler:
assert _power_of_two(self.world_size)
power_of_two = int(math.log2(self.world_size))
median = power_of_two // 2
- balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median))
+ balanced_logical_mesh_shape = (2**median, 2 ** (power_of_two - median))
row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1]
balanced_logical_mesh = []
for row_index in range(row_size):
@@ -348,7 +353,7 @@ class AlphaBetaProfiler:
return best_logical_mesh
def extract_alpha_beta_for_device_mesh(self):
- '''
+ """
Extract the mesh_alpha list and mesh_beta list based on the
best logical mesh, which will be used to initialize the device mesh.
@@ -360,7 +365,7 @@ class AlphaBetaProfiler:
[2.5917552411556242e-05, 0.00010312341153621673]
>>> print(mesh_beta)
[5.875573704655635e-11, 4.7361584445959614e-12]
- '''
+ """
best_logical_mesh = self.search_best_logical_mesh()
first_axis = [row[0] for row in best_logical_mesh]
@@ -381,7 +386,7 @@ class AlphaBetaProfiler:
first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group)
second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group)
mesh_alpha = [first_latency, second_latency]
- # The beta values have been enlarged by 1e10 times temporarilly because the computation cost
+ # The beta values have been enlarged by 1e10 times temporarily because the computation cost
         # is still estimated in the unit of TFLOPs instead of time. We will remove this factor in the future.
mesh_beta = [1e10 / first_bandwidth, 1e10 / second_bandwidth]
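For orientation, the profiler above is typically driven as follows. This is a minimal sketch, not part of the patch, assuming a single node with 4 GPUs launched via `torchrun --nproc_per_node=4`; the `(2, 2)` mesh shape is purely illustrative.

```python
# Hypothetical usage of AlphaBetaProfiler (sketch, not part of this patch).
import torch
import torch.distributed as dist

from colossalai.device import AlphaBetaProfiler
from colossalai.device.device_mesh import DeviceMesh

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

profiler = AlphaBetaProfiler(physical_devices=[0, 1, 2, 3])

# {(rank_i, rank_j): (alpha, beta)} for every profiled device pair
ab_dict = profiler.profile_ab()

# e.g. [[0, 1], [2, 3]] when (0, 1) and (2, 3) are the fast NVLink pairs
best_logical_mesh = profiler.search_best_logical_mesh()

# one (alpha, beta) entry per mesh axis, ready to be fed into DeviceMesh
mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh()
device_mesh = DeviceMesh(
    physical_mesh_id=torch.arange(4),
    mesh_shape=(2, 2),
    mesh_alpha=mesh_alpha,
    mesh_beta=mesh_beta,
    init_process_group=True,
)
```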
diff --git a/colossalai/device/calc_pipeline_strategy.py b/colossalai/device/calc_pipeline_strategy.py
index 4ab72dfe60f0c73f0e4f5186ed54205b68513bc0..72d432701ada2f9c1c8c541404cab10a07ea22b7 100644
--- a/colossalai/device/calc_pipeline_strategy.py
+++ b/colossalai/device/calc_pipeline_strategy.py
@@ -10,8 +10,10 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
while i <= num_devices_per_host:
i *= 2
p += 1
- assert pow(2, p) == num_devices_per_host, ("Only supports the cases where num_devices_per_host is power of two, "
- f"while now num_devices_per_host = {num_devices_per_host}")
+ assert pow(2, p) == num_devices_per_host, (
+ "Only supports the cases where num_devices_per_host is power of two, "
+ f"while now num_devices_per_host = {num_devices_per_host}"
+ )
if mode == "alpa":
for i in range(p + 1):
submesh_choices.append((1, pow(2, i)))
@@ -24,18 +26,19 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
return submesh_choices
-def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost,
- best_configs):
+def alpa_dp_impl(
+ num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost, best_configs
+):
"""Implementation of Alpa DP for pipeline strategy
- Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf
+ Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf
- Arguments:
- num_layers: K
- num_devices: N*M
- num_microbatches: B
- submesh_choices: List[(n_i,m_i)]
- compute_cost: t_intra
- """
+ Arguments:
+ num_layers: K
+ num_devices: N*M
+ num_microbatches: B
+ submesh_choices: List[(n_i,m_i)]
+ compute_cost: t_intra
+ """
# For f, layer ID start from 0
# f[#pipeline stages, layer id that is currently being considered, number of devices used]
f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32)
@@ -54,7 +57,7 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com
for i in range(num_layers, k, -1):
stage_cost = compute_cost[k, i, m]
new_cost = f[s - 1, k, d - n_submesh_devices] + stage_cost
- if (stage_cost <= max_stage_cost and new_cost < f[s, k, d]):
+ if stage_cost <= max_stage_cost and new_cost < f[s, k, d]:
f[s, k, d] = new_cost
f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices])
f_argmin[s, k, d] = (i, m, best_configs[k, i, m])
@@ -75,34 +78,34 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com
res = []
while current_s > 0 and current_layer < num_layers and current_devices > 0:
- next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices])
+ next_start_layer, submesh_choice, autosharding_choice = f_argmin[current_s, current_layer, current_devices]
assert next_start_layer != -1 and current_devices != -1
res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice))
current_s -= 1
current_layer = next_start_layer
current_devices -= np.prod(np.array(submesh_choices[submesh_choice]))
- assert (current_s == 0 and current_layer == num_layers and current_devices == 0)
+ assert current_s == 0 and current_layer == num_layers and current_devices == 0
return total_cost, res
-def alpa_dp(num_layers,
- num_devices,
- num_microbatches,
- submesh_choices,
- num_autosharding_configs,
- compute_cost,
- gap=1e-6):
+def alpa_dp(
+ num_layers, num_devices, num_microbatches, submesh_choices, num_autosharding_configs, compute_cost, gap=1e-6
+):
"""Alpa auto stage dynamic programming.
- Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py
+ Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py
Arguments:
submesh_choices: List[(int,int)]
num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh)
compute_cost: np.array(num_layers,num_layers,num_submesh_choices,num_autosharding_configs)
"""
- assert np.shape(compute_cost) == (num_layers, num_layers, len(submesh_choices),
- num_autosharding_configs), "Cost shape wrong."
+ assert np.shape(compute_cost) == (
+ num_layers,
+ num_layers,
+ len(submesh_choices),
+ num_autosharding_configs,
+ ), "Cost shape wrong."
all_possible_stage_costs = np.sort(np.unique(compute_cost))
best_cost = np.inf
best_solution = None
@@ -117,8 +120,9 @@ def alpa_dp(num_layers,
break
if max_stage_cost - last_max_stage_cost < gap:
continue
- cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost,
- max_stage_cost, best_configs)
+ cost, solution = alpa_dp_impl(
+ num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost, max_stage_cost, best_configs
+ )
if cost < best_cost:
best_cost = cost
best_solution = solution
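The reformatted `alpa_dp` entry point expects a 4D cost tensor; a toy invocation may help clarify the shapes. This is a sketch with synthetic random costs (so the returned partition is illustrative only), and it assumes the hidden internals of `alpa_dp` accept any cost tensor of the asserted shape.

```python
# Smoke-test sketch for alpa_dp with synthetic costs (not part of this patch).
import numpy as np

from colossalai.device.calc_pipeline_strategy import alpa_dp, get_submesh_choices

num_hosts, num_devices_per_host = 1, 4
num_layers, num_microbatches = 4, 8
submesh_choices = get_submesh_choices(num_hosts, num_devices_per_host, mode="alpa")
num_autosharding_configs = 2

# compute_cost[k, i, s, c]: cost of running layers k..i on submesh choice s
# under autosharding config c (t_intra in the docstring above)
rng = np.random.default_rng(0)
compute_cost = rng.random((num_layers, num_layers, len(submesh_choices), num_autosharding_configs))

total_cost, solution = alpa_dp(
    num_layers,
    num_hosts * num_devices_per_host,
    num_microbatches,
    submesh_choices,
    num_autosharding_configs,
    compute_cost,
)
# solution: one ((start_layer, next_start_layer), submesh_choice, autosharding_choice) per stage
print(total_cost, solution)
```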
diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py
index 2a5f747fbc238c7799b3076e695030020f491d5b..72f199203a9d12c9dd75869a96065715811afb75 100644
--- a/colossalai/device/device_mesh.py
+++ b/colossalai/device/device_mesh.py
@@ -3,11 +3,19 @@
with some changes. """
import operator
+from dataclasses import dataclass
from functools import reduce
-from typing import List, Tuple
+from typing import Dict, List, Union
import torch
import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+
+@dataclass
+class ProcessGroupContainer:
+ process_group: ProcessGroup
+ ranks: List[int]
# modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py)
@@ -27,223 +35,491 @@ class DeviceMesh:
         during the initialization of the DeviceMesh instance if init_process_group is set to True.
         Otherwise, users need to call init_logical_process_group manually to initialize the logical process group.
(default: False)
- need_flatten(bool, optional): initialize flatten_device_mesh during initializing the DeviceMesh instance if the need_flatten set to True.
+ device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda')
"""
- def __init__(self,
- physical_mesh_id: torch.Tensor,
- mesh_shape: torch.Size = None,
- logical_mesh_id: torch.Tensor = None,
- mesh_alpha: List[float] = None,
- mesh_beta: List[float] = None,
- init_process_group: bool = False,
- need_flatten: bool = True):
- self.physical_mesh_id = physical_mesh_id
+ _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}
+
+ def __init__(
+ self,
+ physical_mesh_id: torch.Tensor,
+ mesh_shape: torch.Size = None,
+ logical_mesh_id: torch.Tensor = None,
+ mesh_alpha: List[float] = None,
+ mesh_beta: List[float] = None,
+ init_process_group: bool = False,
+ device: str = "cuda",
+ ):
+ # ============================
+ # Physical & Logical Mesh IDs
+ # ============================
+ self._physical_mesh_id = physical_mesh_id
+ assert physical_mesh_id.dim() == 1, "physical_mesh_id should be a 1D tensor."
+
+ # logical mesh ids can be obtained via two ways
+ # 1. provide physical mesh id and provide mesh shape
+ # 2. directly supply the logical mesh id
+ assert mesh_shape is None or logical_mesh_id is None, (
+ "Only one of mesh_shape and logical_mesh_id can be specified."
+ "Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id"
+ )
+
if logical_mesh_id is None:
- self.mesh_shape = mesh_shape
- self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
+ self._mesh_shape = mesh_shape
+ self._logical_mesh_id = self._physical_mesh_id.reshape(self._mesh_shape)
else:
self._logical_mesh_id = logical_mesh_id
- self.mesh_shape = self._logical_mesh_id.shape
+ self._mesh_shape = self._logical_mesh_id.shape
+
+ # ensure two things:
+ # 1. logical and physical mesh IDs should contain the same elements
+ # 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed
+ assert torch.equal(
+ torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)
+ ), "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
+ assert (
+ torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel()
+ ), "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
+ assert (
+ torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel()
+ ), "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."
- # map global rank into logical rank
- self.convert_map = {}
- self._global_rank_to_logical_rank_map(self._logical_mesh_id, [])
+ # ===============================================
# coefficient for alpha-beta communication model
+ # alpha is latency and beta is bandwidth
+ # ===============================================
+ # if the values are not provided, we assume they are 1 for simplicity
if mesh_alpha is None:
- mesh_alpha = [1] * len(self.mesh_shape)
+ mesh_alpha = [1] * len(self._mesh_shape)
if mesh_beta is None:
- mesh_beta = [1] * len(self.mesh_shape)
+ mesh_beta = [1] * len(self._mesh_shape)
+
self.mesh_alpha = tuple(mesh_alpha)
self.mesh_beta = tuple(mesh_beta)
- self.init_process_group = init_process_group
- self.need_flatten = need_flatten
- if self.init_process_group:
- self.process_groups_dict = self.create_process_groups_for_logical_mesh()
- if self.need_flatten and self._logical_mesh_id.dim() > 1:
- self.flatten_device_mesh = self.flatten()
- # Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten())
- # self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
- # self.mesh_beta)
+
+ # ensure the alpha and beta have the same shape
+ assert len(self.mesh_alpha) == len(
+ self.mesh_beta
+ ), "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."
+
+ # =========================
+ # Device for Process Group
+ # =========================
+ self._device = device
+ self._dist_backend = self._DIST_BACKEND[device]
+
+ # =========================
+ # Process Group Management
+ # =========================
+ # the _global_to_local_rank_mapping is structured as follows
+ # {
+        #    <global rank>: [<local rank at axis 0>, <local rank at axis 1>, <local rank at axis 2>, ...]
+ # }
+ self._global_to_local_rank_mapping = dict()
+ self._init_global_to_logical_rank_mapping(
+ mapping=self._global_to_local_rank_mapping, tensor=self.logical_mesh_id
+ )
+
+ # create process group
+ self._process_group_dict = {}
+ self._ranks_in_the_process_group = {}
+ self._global_rank_of_current_process = None
+ self._is_initialized = False
+
+ # attribute used to indicate whether this object
+ # is created using DeviceMesh.from_process_group
+        # this attribute is used to guard methods such as
+        # get_process_group, since no global rank information
+        # is known if the mesh is created with from_process_group
+ self._is_init_from_process_group = False
+
+ # initialize process group if specified
+ self._init_ranks_in_the_same_group()
+ self._init_process_group = init_process_group
+ if init_process_group:
+ self.init_logical_process_group()
@property
- def shape(self):
- return self.mesh_shape
+ def shape(self) -> torch.Size:
+ """
+ Return the shape of the logical mesh.
+ """
+ return self._mesh_shape
@property
- def num_devices(self):
- return reduce(operator.mul, self.physical_mesh_id.shape, 1)
+ def num_devices(self) -> int:
+ """
+ Return the number of devices contained in the device mesh.
+ """
+ return reduce(operator.mul, self._physical_mesh_id.shape, 1)
@property
- def logical_mesh_id(self):
+ def logical_mesh_id(self) -> torch.Tensor:
+ """
+ Return the logical mesh id.
+ """
return self._logical_mesh_id
- def __deepcopy__(self, memo):
+ @property
+ def is_initialized(self) -> bool:
+ """
+ Return whether the process group is initialized.
+ """
+ return self._is_initialized
+
+ @staticmethod
+ def from_process_group(process_group: Union[ProcessGroup, List[ProcessGroup]]) -> "DeviceMesh":
+ """
+ Create a DeviceMesh instance from the current process group. Please note that the DeviceMesh object created with this method
+ will not have information about the physical mesh id, and thus will not be able to query for other ranks and perform alpha-beta communication.
+
+ Args:
+ process_group (Union[ProcessGroup, List[ProcessGroup]]): the process group or a list of process groups for the device mesh.
+ If the input is a ProcessGroup object, a 1D DeviceMesh object will be created. If the input is a list of ProcessGroup objects,
+ the ProcessGroup at the ith index will correspond to the process group in the ith axis of the device mesh.
+
+ Returns:
+ DeviceMesh: the device mesh instance.
+ """
+
+ def _get_device_by_backend(process_group):
+ """
+ Get the device type given a process group's backend.
+ """
+ backend = dist.get_backend(process_group)
+ for _device, _backend in DeviceMesh._DIST_BACKEND.items():
+ if _backend == backend:
+ return _device
+ return None
+
+ if isinstance(process_group, ProcessGroup):
+ process_group = [process_group]
+
+ # get mesh shape
+ mesh_shape = [dist.get_world_size(pg) for pg in process_group]
+
+ # get device
+ device_list = [_get_device_by_backend(pg) for pg in process_group]
+
+ # make sure all devices are the same
+ assert all(
+ [device == device_list[0] for device in device_list]
+ ), "All devices should be the same, please check your input process groups are created with the same distributed backend."
+
+ # create a fake physical mesh id
+ # as we only get the process group associated with the current process,
+ # we cannot get the global ranks for all processes in the mesh
+ # therefore, we only use this fake physical mesh id to create the device mesh
+ # and will remove this fake physical mesh id later
+ fake_physical_mesh_id = torch.arange(reduce(operator.mul, mesh_shape, 1))
+
+ # create the device mesh
+ device_mesh = DeviceMesh(physical_mesh_id=fake_physical_mesh_id, mesh_shape=mesh_shape, device=device_list[0])
+
+ # hack the device attribute
+ device_mesh._physical_mesh_id = None
+ device_mesh._logical_mesh_id = None
+ device_mesh._global_rank_of_current_process = dist.get_rank()
+        # mark that this mesh was created from existing process groups,
+        # so that rank-query methods can reject unsupported calls
+        device_mesh._is_init_from_process_group = True
+        device_mesh._is_initialized = False
+ device_mesh._process_group_dict = {
+ device_mesh._global_rank_of_current_process: {axis: pg for axis, pg in enumerate(process_group)}
+ }
+
+ return device_mesh
+
+ def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup:
+ """
+ Return the process group on the specified axis.
+
+ Args:
+ axis (int): the axis of the process group.
+ global_rank (int, optional): the global rank of the process group. If not specified, the current process is used. (default: None)
+ """
+ if global_rank is None:
+ global_rank = self._global_rank_of_current_process
+ elif self._is_init_from_process_group:
+ raise RuntimeError(
+ "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
+ )
+ return self._process_group_dict[global_rank][axis]
+
+ def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]:
+ """
+ Return the process groups for all axes.
+
+ Args:
+ global_rank (int, optional): the global rank of the process
+ """
+ if global_rank is None:
+ global_rank = self._global_rank_of_current_process
+ elif self._is_init_from_process_group:
+ raise RuntimeError(
+ "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
+ )
+ return self._process_group_dict[global_rank]
+
+ def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]:
+ """
+ Return the ranks in the process group on the specified axis.
+
+ Args:
+ axis (int): the axis of the process group.
+ global_rank (int, optional): the global rank of the process
+ """
+ if global_rank is None:
+ global_rank = self._global_rank_of_current_process
+ elif self._is_init_from_process_group:
+ raise RuntimeError(
+ "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
+ )
+ return self._ranks_in_the_process_group[global_rank][axis]
+
+ def __deepcopy__(self, memo) -> "DeviceMesh":
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
- if k != 'process_groups_dict':
+ if k != "_process_group_dict":
setattr(result, k, __import__("copy").deepcopy(v, memo))
else:
+ # process group cannot be copied
+ # thus, we share them directly
setattr(result, k, v)
-
return result
- def flatten(self):
+ def _init_global_to_logical_rank_mapping(
+ self, mapping: Dict, tensor: torch.Tensor, index_list: List[int] = []
+ ) -> Dict[int, List[int]]:
"""
- Flatten the logical mesh into an effective 1d logical mesh,
+        Build a global rank to local rank mapping for the process groups along each axis of the logical device mesh.
+
+ Args:
+ mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
+ tensor (torch.Tensor): the tensor that contains the logical mesh ids.
+            index_list (List[int]): the local ranks accumulated along the axes traversed so far, used internally during the recursion.
+
+ Returns:
+ mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
+ The value is a list of integers and each integer represents the local rank in the indexed axis.
"""
- flatten_mesh_shape_size = len(self.mesh_shape)
- flatten_mesh_shape = [self.num_devices]
- return DeviceMesh(self.physical_mesh_id,
- tuple(flatten_mesh_shape),
- mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
- mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
- init_process_group=self.init_process_group,
- need_flatten=False)
-
- def _global_rank_to_logical_rank_map(self, tensor, index_list):
- '''
- This method is a helper function to build convert_map recursively.
- '''
for index, inner_tensor in enumerate(tensor):
+ # index means the local rank in the current axis
+ # inner_tensor refers to the processes with the same local rank
+
if inner_tensor.numel() == 1:
- self.convert_map[int(inner_tensor)] = index_list + [index]
+ # if the inner_tensor only has one element, it means that
+ # it already reaches the last axis
+ # we append its local_rank in the last axis to the index_list
+ # and assign to the mapping
+                # the value of the mapping is the local rank at the indexed axis of the device mesh
+ mapping[int(inner_tensor)] = index_list + [index]
else:
- self._global_rank_to_logical_rank_map(inner_tensor, index_list + [index])
+ # we recursively go into the function until we reach the last axis
+                # meanwhile, we should add the local rank in the current axis to the index_list
+ self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index])
- def create_process_groups_for_logical_mesh(self):
- '''
+ def init_logical_process_group(self):
+ """
This method is used to initialize the logical process groups which will be used in communications
         within the logical device mesh.
         Note: if init_process_group is set to False, you have to call this method manually. Otherwise,
         communication-related functions, such as ShapeConsistencyManager.apply, will raise errors.
- '''
- process_groups_dict = {}
- check_duplicate_list = []
- global_rank_flatten_list = self.physical_mesh_id.view(-1).tolist()
+ """
+ # sanity check
+ assert (
+            dist.is_initialized()
+        ), "torch.distributed should be initialized before calling init_logical_process_group"
+ assert (
+ not self._is_initialized
+ ), "The logical process group has been initialized, do not call init_logical_process_group twice"
+
+ # update the global rank of the current process
+ self._global_rank_of_current_process = dist.get_rank()
+ duplicate_check_list = []
+
+ # flatten the global ranks to 1D list
+ global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()
+
for global_rank in global_rank_flatten_list:
- process_groups = self.global_rank_to_process_groups_with_global_rank(global_rank)
- for axis, process_group in process_groups.items():
- if axis not in process_groups_dict:
- process_groups_dict[axis] = []
- if process_group not in check_duplicate_list:
- check_duplicate_list.append(process_group)
- process_group_handler = dist.new_group(process_group)
- process_groups_dict[axis].append((process_group, process_group_handler))
-
- return process_groups_dict
-
- def global_rank_to_logical_rank(self, rank):
- return self.convert_map[rank]
-
- def global_rank_to_process_groups_with_logical_rank(self, rank):
- '''
- Give a global rank and return all logical process groups of this rank.
- for example:
- physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
- mesh_shape = (4, 4)
- # [[0, 1, 2, 3],
- # [4, 5, 6, 7],
- # [8, 9, 10,11],
- # [12,13,14,15]]
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
- print(device_mesh.global_rank_to_process_groups_with_logical_rank(0))
- output:
- # key is axis name
- # value is a list of logical ranks in same axis with rank 0
- {0: [[0, 0], [1, 0], [2, 0], [3, 0]], 1: [[0, 0], [0, 1], [0, 2], [0, 3]]}
- '''
- process_groups = {}
- for d in range(self.logical_mesh_id.dim()):
- for replacer in range(self.logical_mesh_id.shape[d]):
- if d not in process_groups:
- process_groups[d] = []
- process_group_member = self.convert_map[rank].copy()
- process_group_member[d] = replacer
- process_groups[d].append(process_group_member)
- return process_groups
-
- def global_rank_to_process_groups_with_global_rank(self, rank):
- '''
- Give a global rank and return all process groups of this rank.
- for example:
- physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
- mesh_shape = (4, 4)
- # [[0, 1, 2, 3],
- # [4, 5, 6, 7],
- # [8, 9, 10,11],
- # [12,13,14,15]]
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
- print(device_mesh.global_rank_to_process_groups_with_global_rank(0))
- output:
- # key is axis name
- # value is a list of global ranks in same axis with rank 0
- {0: [0, 4, 8, 12], 1: [0, 1, 2, 3]}
- '''
- logical_process_groups = self.global_rank_to_process_groups_with_logical_rank(rank)
- process_groups = {}
- for dim, logical_ranks in logical_process_groups.items():
- process_groups[dim] = []
- for logical_rank in logical_ranks:
- for g_rank, l_rank in self.convert_map.items():
- if l_rank == logical_rank:
- process_groups[dim].append(g_rank)
- return process_groups
+ # find the other ranks which are in the same process group as global_rank
+ ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)
+
+ for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
+ # skip duplicated process group creation
+ if ranks_in_same_group in duplicate_check_list:
+ continue
+
+ # create the process group
+ pg_handler = dist.new_group(ranks=ranks_in_same_group, backend=self._dist_backend)
+
+ # keep this process group in the process_groups_dict
+ for rank in ranks_in_same_group:
+ if rank not in self._process_group_dict:
+ self._process_group_dict[rank] = dict()
+ self._process_group_dict[rank][axis] = pg_handler
+
+ # update the init flag
+ # we only allow init for once
+ self._is_initialized = True
+
+ def _init_ranks_in_the_same_group(self):
+ """
+        This method is used to initialize the _ranks_in_the_process_group dictionary.
+ """
+ # flatten the global ranks to 1D list
+ global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()
+
+ for global_rank in global_rank_flatten_list:
+ # find the other ranks which are in the same process group as global_rank
+ ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)
+
+ for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
+ # create dict for each rank
+                if global_rank not in self._ranks_in_the_process_group:
+ self._ranks_in_the_process_group[global_rank] = dict()
+
+ # keep this process group in the process_groups_dict
+ self._ranks_in_the_process_group[global_rank][axis] = ranks_in_same_group
+
+ def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[int], int]:
+ """
+ Return the local rank of the given global rank in the logical device mesh.
+
+ Args:
+ rank (int): the global rank in the logical device mesh.
+            axis (int, optional): the axis of the logical device mesh. If None, the local ranks for all axes are returned. (default: None)
+ """
+ if self._is_init_from_process_group:
+ raise RuntimeError(
+ "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
+ )
+
+ local_ranks = self._global_to_local_rank_mapping[rank]
+        if axis is not None:
+ return local_ranks[axis]
+ else:
+ return local_ranks
+
+ def _collate_global_ranks_in_same_process_group(self, global_rank):
+ """
+        Given a global rank, return all the global ranks involved in its associated process group in each axis.
+
+ Example:
+
+ ```python
+ physical_mesh_id = torch.arange(0, 16)
+ mesh_shape = (4, 4)
+
+ # logical mesh will look like
+ # [[0, 1, 2, 3],
+ # [4, 5, 6, 7],
+ # [8, 9, 10,11],
+ # [12,13,14,15]]
+
+ device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+        print(device_mesh._collate_global_ranks_in_same_process_group(0))
+
+        # key is the axis
+        # value is a list of global ranks in the same axis as rank 0
+        # output will look like
+        # {
+        #    0: [0, 4, 8, 12],
+        #    1: [0, 1, 2, 3]
+        # }
+        ```
+        """
+        # We have initialized the global-rank-to-local-rank mapping by calling _init_global_to_logical_rank_mapping,
+        # which fills self._global_to_local_rank_mapping:
+        # the key is the global rank
+        # the value is the list of local ranks corresponding to the global rank with respect to the different axes
+        # we can see the list of local ranks as the process coordinates for simplicity
+        # the keys and values are all unique, therefore,
+        # we can also use the coordinates to find the global rank
+
+ # =========================================================================
+ # Step 1
+ # find all the process_coordinates for processes in the same process group
+ # as the given global rank
+ # =========================================================================
+
+        # each key is an axis of the logical mesh; each value is a list of process coordinates in that axis
+ processes_in_the_same_process_group = {}
+
+ for dim in range(self.logical_mesh_id.dim()):
+ # iterate over the dimension size so that we can include all processes
+ # in the same process group in the given axis
+            # _local_rank enumerates every local rank along the current axis
+ for _local_rank in range(self.logical_mesh_id.shape[dim]):
+ # if this dimension is not initialized yet,
+ # initialize it with an empty array
+ if dim not in processes_in_the_same_process_group:
+ processes_in_the_same_process_group[dim] = []
+
+ # get the local rank corresponding to the global rank
+ process_coordinates = self._global_to_local_rank_mapping[global_rank].copy()
+
+                # replace the local rank in the given dimension with the
+                # local rank currently being iterated
+ process_coordinates[dim] = _local_rank
+ processes_in_the_same_process_group[dim].append(process_coordinates)
+
+ # =================================================================
+ # Step 2
+ # Use local rank combination to find its corresponding global rank
+ # =================================================================
+ # the key of the dict is the axis
+ # the value is the list of global ranks which are in the same process group as the given global rank
+ global_pg_ranks = {}
+ for dim, coordinates_of_all_processes in processes_in_the_same_process_group.items():
+ global_pg_ranks[dim] = []
+ for process_coordinates in coordinates_of_all_processes:
+ # find the global rank by local rank combination
+ for _global_rank, _process_coordinates in self._global_to_local_rank_mapping.items():
+ if process_coordinates == _process_coordinates:
+ global_pg_ranks[dim].append(_global_rank)
+ return global_pg_ranks
+
+ def flatten(self):
+ """
+        Flatten the logical mesh into an effective 1d logical mesh.
+ """
+ if self._is_init_from_process_group:
+ raise RuntimeError(
+ "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known."
+ )
+
+ flatten_mesh_shape_size = len(self._mesh_shape)
+ flatten_mesh_shape = [self.num_devices]
+ return DeviceMesh(
+ self._physical_mesh_id,
+ tuple(flatten_mesh_shape),
+ mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
+ mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
+ init_process_group=self._init_process_group,
+ )
def all_gather_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
- return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes +
- 0.1)
+ return self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.1
def all_reduce_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
- return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes +
- 0.01)
+ return (
+ self.mesh_alpha[mesh_dim]
+ + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes
+ + 0.01
+ )
def reduce_scatter_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
- return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes +
- 0.001)
+ return (
+ self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.001
+ )
def all_to_all_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
penalty_factor = num_devices / 2.0
- return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
- (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)
-
-
-class FlattenDeviceMesh(DeviceMesh):
-
- def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None):
- super().__init__(physical_mesh_id,
- mesh_shape,
- mesh_alpha,
- mesh_beta,
- init_process_group=False,
- need_flatten=False)
- # Different from flatten(), mesh_shape leaves unchanged, mesh_alpha and mesh_beta are scalars
- self.mesh_alpha = max(self.mesh_alpha)
- self.mesh_beta = min(self.mesh_beta)
- # Different from original process_groups_dict, rank_list is not stored
- self.process_number_dict = self.create_process_numbers_for_logical_mesh()
-
- def create_process_numbers_for_logical_mesh(self):
- '''
- Build 1d DeviceMesh in column-major(0) and row-major(1)
- for example:
- mesh_shape = (2,4)
- # [[0, 1, 2, 3],
- # [4, 5, 6, 7]]
- # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
- '''
- num_devices = reduce(operator.mul, self.mesh_shape, 1)
- process_numbers_dict = {}
- process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist()
- process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist()
- return process_numbers_dict
-
- def mix_gather_cost(self, num_bytes):
- num_devices = reduce(operator.mul, self.mesh_shape, 1)
- return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1)
+ return (
+ self.mesh_alpha[mesh_dim]
+ + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor
+ + 0.001
+ )
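Putting the refactor together, the two construction paths of the new `DeviceMesh` look roughly as follows. This is a hypothetical sketch, not part of the patch, assuming `torch.distributed` is launched with 4 CUDA ranks.

```python
# Sketch: the two DeviceMesh construction paths introduced by this refactor.
import torch
import torch.distributed as dist

from colossalai.device.device_mesh import DeviceMesh

dist.init_process_group(backend="nccl")

# Path 1: physical mesh id + mesh shape; per-axis process groups are created eagerly.
mesh = DeviceMesh(torch.arange(4), mesh_shape=(2, 2), init_process_group=True)
col_pg = mesh.get_process_group(axis=0)              # group along axis 0
row_ranks = mesh.get_ranks_in_process_group(axis=1)  # e.g. [0, 1] on rank 0

# The cost model uses the per-axis (alpha, beta) pairs, defaulting to (1, 1):
# all_reduce_cost = alpha + beta * 2 * (n - 1) / n * num_bytes + 0.01
cost = mesh.all_reduce_cost(num_bytes=1 << 20, mesh_dim=0)

# Path 2: wrap existing process groups; no global mesh information is kept,
# so helpers such as global_rank_to_local_rank raise RuntimeError.
pg = dist.new_group(ranks=list(range(dist.get_world_size())))
mesh_1d = DeviceMesh.from_process_group(pg)
my_pg = mesh_1d.get_process_group(axis=0)
```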
diff --git a/colossalai/engine/__init__.py b/colossalai/engine/__init__.py
deleted file mode 100644
index 158796befb312755ed92f77f7828557f55800e4c..0000000000000000000000000000000000000000
--- a/colossalai/engine/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from ._base_engine import Engine
-from .gradient_handler import *
-
-__all__ = ['Engine']
diff --git a/colossalai/engine/gradient_accumulation/__init__.py b/colossalai/engine/gradient_accumulation/__init__.py
deleted file mode 100644
index 4cb6f4ad7384dda6136d98a0a73521e37d4027ba..0000000000000000000000000000000000000000
--- a/colossalai/engine/gradient_accumulation/__init__.py
+++ /dev/null
@@ -1,57 +0,0 @@
-from typing import Iterable, List
-
-import torch.nn as nn
-from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
-
-from colossalai.engine import BaseGradientHandler
-
-from ._gradient_accumulation import (
- GradAccumDataloader,
- GradAccumGradientHandler,
- GradAccumLrSchedulerByStep,
- GradAccumOptimizer,
-)
-
-__all__ = [
- 'accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer', 'GradAccumLrSchedulerByStep',
- 'GradAccumGradientHandler'
-]
-
-
-def accumulate_gradient(model: nn.Module,
- optimizer: Optimizer,
- dataloader: Iterable,
- accumulate_size: int,
- gradient_handlers: List[BaseGradientHandler] = None,
- lr_scheduler: _LRScheduler = None):
- r"""Turning model, optimizer, dataloader into corresponding object for gradient accumulation.
-
- Args:
- model (:class:`torch.nn.Module`): your model object for gradient accumulation.
- optimizer (:class:`torch.optim.Optimizer`): your optimizer object for gradient accumulation.
- dataloader (:class:`torch.utils.data.DataLoader` or iterable objects):
- your dataloader object, would be called like iter(dataloader)
- accumulate_size (int): the number of steps to accumulate gradients
- gradient_handlers (List[:class:`colossalai.engine.BaseGradientHandler`]):
- list of gradient handler objects. Default is None.
- lr_scheduler (`torch.optim.lr_scheduler` or `colossalai.nn.lr_scheduler`):
- your ``lr_scheduler`` object for gradient accumulation. Defaults to None.
-
- More details about `gradient_handlers` could be found in
- `Gradient_handler `_.
-
- More details about `lr_scheduler` could be found
- `lr_scheduler `_. and
- `how to adjust learning rate `_.
- """
- optimizer = GradAccumOptimizer(optimizer, accumulate_size=accumulate_size, model=model)
- dataloader = GradAccumDataloader(dataloader, accumulate_size=accumulate_size)
-
- if gradient_handlers is not None:
- gradient_handlers = [GradAccumGradientHandler(handler, accumulate_size) for handler in gradient_handlers]
-
- if lr_scheduler is not None:
- lr_scheduler = GradAccumLrSchedulerByStep(lr_scheduler, accumulate_size=accumulate_size)
-
- return optimizer, dataloader, gradient_handlers, lr_scheduler
diff --git a/colossalai/engine/gradient_handler/__init__.py b/colossalai/engine/gradient_handler/__init__.py
deleted file mode 100644
index 2dea768bad7ecf1feed8bae69f733cda943509b5..0000000000000000000000000000000000000000
--- a/colossalai/engine/gradient_handler/__init__.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from ._base_gradient_handler import BaseGradientHandler
-from ._data_parallel_gradient_handler import DataParallelGradientHandler
-from ._moe_gradient_handler import MoeGradientHandler
-from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler
-from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler
-from ._zero_gradient_handler import ZeROGradientHandler
-
-__all__ = [
- 'BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler',
- 'MoeGradientHandler', 'SequenceParallelGradientHandler'
-]
diff --git a/colossalai/engine/schedule/__init__.py b/colossalai/engine/schedule/__init__.py
deleted file mode 100644
index 0f2c039d7057324676d30938c6ec112279077b61..0000000000000000000000000000000000000000
--- a/colossalai/engine/schedule/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from ._base_schedule import BaseSchedule
-from ._non_pipeline_schedule import NonPipelineSchedule
-from ._pipeline_schedule import InterleavedPipelineSchedule, PipelineSchedule, get_tensor_shape
-
-__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape']
diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py
deleted file mode 100644
index 38175fe0941c1c053bf91fe1df558ee9e763c360..0000000000000000000000000000000000000000
--- a/colossalai/engine/schedule/_pipeline_schedule.py
+++ /dev/null
@@ -1,833 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import inspect
-from typing import Callable, List, Tuple, Union
-
-import torch.cuda
-
-import colossalai.communication as comm
-from colossalai.amp.naive_amp import NaiveAMPModel
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.logging import get_dist_logger
-from colossalai.utils import switch_virtual_pipeline_parallel_rank
-from colossalai.utils.cuda import get_current_device
-
-from ._base_schedule import BaseSchedule
-
-
-def get_tensor_shape():
- if hasattr(gpc.config, 'TENSOR_SHAPE'):
- return gpc.config.TENSOR_SHAPE
-
- if not gpc.is_initialized(ParallelMode.PIPELINE):
- return None
-
- if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(
- gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'HIDDEN_SIZE'):
- if gpc.is_initialized(ParallelMode.DATA):
- dp_size = gpc.get_world_size(ParallelMode.DATA)
- else:
- dp_size = 1
- if gpc.is_initialized(ParallelMode.SEQUENCE):
- seq_size = gpc.get_world_size(ParallelMode.SEQUENCE)
- else:
- seq_size = 1
-
- tensor_shape = (gpc.config.SEQ_LENGTH // seq_size,
- gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES, gpc.config.HIDDEN_SIZE)
- return tensor_shape
- else:
- return None
-
-
-def pack_return_tensors(return_tensors):
- output, label = tuple(zip(*return_tensors))
- if isinstance(output[0], torch.Tensor):
- output = torch.cat(output, dim=0)
- elif isinstance(output[0], (list, tuple)):
- output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output))
- else:
- raise TypeError(f'Output of model must be tensor or list/tuple of tensors')
- if isinstance(label[0], torch.Tensor):
- label = torch.cat(label, dim=0)
- else:
- merged_label = {k: [] for k in label[0].keys()}
- for d in label:
- for k, v in d.items():
- merged_label[k].append(v)
- label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()}
- return output, label
-
-
-class PipelineSchedule(BaseSchedule):
- """A helper schedule class for pipeline parallelism running environment.
- It uses non-interleaved 1F1B strategy. Other properties are similar as
- :class:`NonPipelineSchedule`.
-
- Args:
- num_microbatches (int): The number of microbatches.
- data_process_func (Callable, optional):
- The preprocessing function which receives a batch of data, and it will be executed in `load_batch`.
- tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
- scatter_gather_tensors (bool, optional):
- If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
-
- Example:
-
- # this shows an example of customized data_process_func
- def data_process_func(stage_output, dataloader_output):
- output1, output2 = stage_output
- item1, item2, item3 = dataloader_output
-
- # assume item2 is not needed
- data = (output1, output2, item1)
- label = item3
- return data, label
-
- """
-
- def __init__(self,
- num_microbatches,
- data_process_func: Callable = None,
- tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
- scatter_gather_tensors: bool = False):
-
- # we need to make sure that the signature of the data_process_func is valid
- if data_process_func:
- sig = inspect.signature(data_process_func)
- assert len(sig.parameters) == 2, \
- 'The data_process_func only takes in two parameters for NonPipelineSchedule, ' \
- 'which is the tensors passed by the previous pipeline stage and the dataloader output from this stage, ' \
- 'i.e. data_process_func(stage_output, dataloader_output).'
-
- super().__init__(data_process_func=data_process_func)
-
- assert num_microbatches > 0, f'expected num_microbatches to be larger then 1, but got {num_microbatches}'
-
- self.num_microbatches = num_microbatches
- self.dtype = torch.float
- assert not isinstance(tensor_shape,
- int), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]."
- if tensor_shape is None:
- self.tensor_shape = tensor_shape
- elif isinstance(tensor_shape, torch.Size):
- self.tensor_shape = tensor_shape
- else:
- self.tensor_shape = torch.Size(tensor_shape)
- self.scatter_gather_tensors = False
- if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1:
- self.scatter_gather_tensors = scatter_gather_tensors
- self._logger = get_dist_logger()
-
- # cache for the batch data
- self.batch_data = None
-
- def load_batch(self, data_iter):
- # Pipeline schedule just puts data in memory
- batch_data = super().load_batch(data_iter, to_gpu=False)
- self.microbatch_offset = 0
- assert self.batch_size % self.num_microbatches == 0, \
- "Batch size should divided by the number of microbatches"
- self.microbatch_size = self.batch_size // self.num_microbatches
- self.batch_data = batch_data
-
- def _get_data_slice(self, data, offset):
- if isinstance(data, torch.Tensor):
- return data[offset:offset + self.microbatch_size]
- elif isinstance(data, (list, tuple)):
- data_dict = {}
- for element in data:
- if isinstance(element, dict):
- data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()})
- elif data_dict:
- data_dict['label'] = element[offset:offset + self.microbatch_size]
- if data_dict:
- return data_dict
- return [val[offset:offset + self.microbatch_size] for val in data]
- elif isinstance(data, dict):
- return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()}
- else:
- raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
-
- def load_micro_batch(self):
- mciro_batch_data = self._get_data_slice(self.batch_data, self.microbatch_offset)
- self.microbatch_offset += self.microbatch_size
- return self._move_to_device(mciro_batch_data)
-
- def pre_processing(self, engine):
- from colossalai.zero.legacy import ShardedModelV2
-
- # TODO: remove this after testing new zero with pipeline parallelism
- model = engine.model
- if isinstance(model, NaiveAMPModel):
- self.dtype = torch.half
- model = model.model
- if isinstance(model, ShardedModelV2):
- self.dtype = torch.half
- model = model.module
- # sig = inspect.signature(model.forward)
- # for p in sig.parameters.values():
- # assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
-
- @staticmethod
- def _call_engine(model, data):
- if data is not None:
- if isinstance(data, torch.Tensor):
- return model(data)
- elif isinstance(data, (list, tuple)):
- return model(*data)
- elif isinstance(data, dict):
- stage_output = None
- if 'stage_output' in data:
- stage_output = data.pop('stage_output')
- if stage_output is None:
- return model(**data)
- elif isinstance(stage_output, torch.Tensor):
- return model(stage_output, **data)
- elif isinstance(stage_output, (tuple, list)):
- return model(*stage_output, **data)
- else:
- raise TypeError(
- f"Expected stage_output to be of type torch.Tensor, list, or tuple, but got {type(stage_output)}"
- )
- else:
- raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
-
- def _get_actual_forward_func(self, module):
- if isinstance(module, NaiveAMPModel):
- sig = inspect.signature(module.model.forward)
- elif hasattr(module, 'colo_attr'):
- sig = inspect.signature(module.module.forward)
- else:
- sig = inspect.signature(module.forward)
- return sig
-
- def _get_data_label_for_current_step(self, stage_output, micro_batch_data, criterion, model):
- if self.data_process_func:
- # use customized function to get data and label
- data, label = self.data_process_func(stage_output, micro_batch_data)
- else:
- if isinstance(micro_batch_data, (tuple, list)):
- if gpc.is_first_rank(ParallelMode.PIPELINE):
- # for the first stage, we use the data from the
- # dataloader output by default
- data, label = micro_batch_data
- else:
- # for non-first stage, we use the output passed
- # by the previous as the model input
- data = stage_output
- _, label = micro_batch_data
- elif isinstance(micro_batch_data, dict):
- data = {}
- data['stage_output'] = stage_output
- if 'label' in micro_batch_data:
- label = micro_batch_data.pop('label')
- else:
- label = None
- load_data = micro_batch_data
- data.update(load_data)
- return data, label
-
- def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
- """Forward step for passed-in model. If it is the first stage, the input tensor
- is obtained from data_iterator, otherwise the passed-in input_obj is used.
- Returns output tensor. This is a helper function and can be ignored by users.
-
- Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
- input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
- return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
- return_output_label (bool, optional): Whether returns output labels.
- accum_loss (optional): Where accumulated loss stores.
- Returns:
- Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage.
- """
- micro_batch_data = self.load_micro_batch()
-
- data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion, engine.model)
-
- output_obj = self._call_engine(engine.model, data)
-
- if gpc.is_last_rank(ParallelMode.PIPELINE):
- if return_output_label:
- return_tensors.append((output_obj, label))
- if accum_loss is not None:
- loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches
- accum_loss.add_(loss_reduced.detach())
- return loss_reduced
- else:
- # forward only, it's useless since backward is not needed
- return output_obj
- else:
- if isinstance(output_obj, torch.Tensor):
- self._logger.debug(
- f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}'
- )
- return output_obj
-
- def _backward_step(self, engine, input_obj, output_obj, output_obj_grad):
- """Backward step through the passed-in output tensor. If it is the last stage, the
- output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor.
- Returns the gradients with respect to the input tensor (None if first stage).
- This is a helper function and can be ignored by users.
-
- Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
- input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): input tensor for this pipeline stage.
- output_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): output tensor for this pipeline stage.
- output_obj_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): gradient of output tensor for this pipeline stage.
-
- Returns:
- Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: gradient of input tensor.
- """
-
- # Retain the grad on the input_obj.
- if input_obj is not None:
- if isinstance(input_obj, torch.Tensor):
- input_obj.retain_grad()
- else:
- for in_tensor in input_obj:
- if in_tensor is not None:
- in_tensor.retain_grad()
- # Backward pass.
- if output_obj_grad is None:
- engine.backward(output_obj)
- else:
- engine.backward_by_grad(output_obj, output_obj_grad)
-
- # Collect the grad of the input_obj.
- input_obj_grad = None
- if input_obj is not None:
- if isinstance(input_obj, torch.Tensor):
- input_obj_grad = input_obj.grad
- else:
- input_obj_grad = []
- for in_tensor in input_obj:
- input_obj_grad.append(in_tensor.grad)
-
- return input_obj_grad
-
- def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
- """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
- Returns a tuple with losses if the last stage, an empty tuple otherwise.
-
- Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
- data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
- forward_only (bool, optional):
- Whether run forward step only. Default is false. If true, no backward will be run.
- return_loss (bool, optional): Whether returns the loss value. Default is true.
- return_output_label (bool, optional): If False, the output and label won't be returned.
-
- Returns:
- Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
- """
-
- assert forward_only or return_loss, \
- 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
- self.load_batch(data_iter)
- num_warmup_microbatches = \
- (gpc.get_world_size(ParallelMode.PIPELINE)
- - gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
- num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches)
- num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
-
- # Input, output tensors only need to be saved when doing backward passes
- input_objs = None
- output_objs = None
- if not forward_only:
- input_objs = []
- output_objs = []
- return_tensors = []
- if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
- accum_loss = torch.zeros(1, device=get_current_device())
- else:
- accum_loss = None
- # Used for tensor meta information communication
- ft_shapes = self.tensor_shape
- bt_shapes = None
- fs_checker = self.tensor_shape is None
-
- # Run warmup forward passes.
- for i in range(num_warmup_microbatches):
- if not gpc.is_first_rank(ParallelMode.PIPELINE):
- ft_shapes = comm.recv_obj_meta(ft_shapes)
- input_obj = comm.recv_forward(ft_shapes,
- dtype=self.dtype,
- scatter_gather_tensors=self.scatter_gather_tensors)
- output_obj = self._forward_step(engine,
- input_obj,
- return_tensors,
- return_output_label=return_output_label,
- accum_loss=accum_loss)
- if not gpc.is_last_rank(ParallelMode.PIPELINE):
- if isinstance(output_obj, torch.Tensor):
- bt_shapes = output_obj.shape
- else:
- bt_shapes = []
- for out_tensor in output_obj:
- bt_shapes.append(out_tensor.shape)
- fs_checker = comm.send_obj_meta(output_obj, fs_checker)
- comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
-
- if not forward_only:
- input_objs.append(input_obj)
- output_objs.append(output_obj)
-
- # Before running 1F1B, need to receive first forward tensor.
- # If all microbatches are run in warmup / cooldown phase, then no need to
- # receive this tensor here.
- if num_microbatches_remaining > 0:
- if not gpc.is_first_rank(ParallelMode.PIPELINE):
- ft_shapes = comm.recv_obj_meta(ft_shapes)
- input_obj = comm.recv_forward(ft_shapes,
- dtype=self.dtype,
- scatter_gather_tensors=self.scatter_gather_tensors)
-
- # Run 1F1B in steady state.
- for i in range(num_microbatches_remaining):
- last_iteration = (i == (num_microbatches_remaining - 1))
-
- output_obj = self._forward_step(engine,
- input_obj,
- return_tensors,
- return_output_label=return_output_label,
- accum_loss=accum_loss)
- if forward_only:
- comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
-
- if not last_iteration:
- input_obj = comm.recv_forward(ft_shapes,
- dtype=self.dtype,
- scatter_gather_tensors=self.scatter_gather_tensors)
-
- else:
- output_obj_grad = comm.send_forward_recv_backward(output_obj,
- bt_shapes,
- dtype=self.dtype,
- scatter_gather_tensors=self.scatter_gather_tensors)
-
- # Add input_obj and output_obj to end of list.
- input_objs.append(input_obj)
- output_objs.append(output_obj)
-
- # Pop output_obj and output_obj from the start of the list for
- # the backward pass.
- input_obj = input_objs.pop(0)
- output_obj = output_objs.pop(0)
-
- input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)
-
- if last_iteration:
- input_obj = None
- comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
- else:
- input_obj = comm.send_backward_recv_forward(input_obj_grad,
- ft_shapes,
- dtype=self.dtype,
- scatter_gather_tensors=self.scatter_gather_tensors)
-
- # Run cooldown backward passes.
- if not forward_only:
- for i in range(num_warmup_microbatches):
- input_obj = input_objs.pop(0)
- output_obj = output_objs.pop(0)
-
- output_obj_grad = comm.recv_backward(bt_shapes,
- dtype=self.dtype,
- scatter_gather_tensors=self.scatter_gather_tensors)
-
- input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)
-
- comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
-
- if len(return_tensors) > 0:
- output, label = pack_return_tensors(return_tensors)
- return output, label, accum_loss
- else:
- return None, None, accum_loss
-
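For reference, the warmup/steady split that the deleted `forward_backward_step` computes above can be reproduced in isolation. A minimal sketch, assuming a hypothetical 4-stage pipeline and 8 microbatches (`rank` stands in for `gpc.get_local_rank(ParallelMode.PIPELINE)`):

```python
# Minimal sketch of the 1F1B phase split computed above (hypothetical sizes).
def one_f_one_b_split(world_size: int, rank: int, num_microbatches: int):
    # Earlier stages run more warmup forwards so that later stages can
    # enter the steady 1F1B phase immediately.
    num_warmup = min(world_size - rank - 1, num_microbatches)
    num_steady = num_microbatches - num_warmup
    return num_warmup, num_steady

for rank in range(4):
    print(rank, one_f_one_b_split(world_size=4, rank=rank, num_microbatches=8))
# 0 (3, 5)   1 (2, 6)   2 (1, 7)   3 (0, 8)
```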
-
-class InterleavedPipelineSchedule(PipelineSchedule):
-
- def __init__(self,
- num_microbatches: int,
- num_model_chunks: int,
- data_process_func: Callable = None,
- tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
- scatter_gather_tensors: bool = False):
- """A helper schedule class for pipeline parallelism running environment.
-        It uses the interleaved 1F1B strategy. Other properties are similar to
- :class:`NonPipelineSchedule`.
-
- Args:
- num_microbatches (int): The number of microbatches.
- num_model_chunks (int): The number of model chunks.
- data_process_func (Callable, optional):
-                The preprocessing function, which receives a batch of data; it will be executed in `load_batch`.
- tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
- scatter_gather_tensors (bool, optional):
-                If set to `True`, communication will be reduced over the pipeline when using 1D tensor parallelism.
- """
- assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
- 'num_microbatches must be an integer multiple of pipeline parallel world size'
- assert isinstance(num_model_chunks, int) and num_model_chunks > 0, \
- f'expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}'
- super().__init__(num_microbatches,
- data_process_func=data_process_func,
- tensor_shape=tensor_shape,
- scatter_gather_tensors=scatter_gather_tensors)
- gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
- gpc.set_virtual_pipeline_parallel_rank(0)
- self.num_model_chunks = num_model_chunks
-
- def pre_processing(self, engine):
- from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
- if isinstance(engine.model, ShardedModelV2):
- self.dtype = torch.half
- elif isinstance(engine.model[0], NaiveAMPModel):
- self.dtype = torch.half
- for model in engine.model:
- if isinstance(model, NaiveAMPModel):
- model = model.model
- sig = inspect.signature(model.forward)
- for p in sig.parameters.values():
- assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
-
- def load_batch(self, data_iter):
- super().load_batch(data_iter)
-        # overwrite microbatch_offset, since model chunks load the same microbatch, and should track the offset
- self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
-
- def load_micro_batch(self, model_chunk_id):
- data = self._get_data_slice(self.batch_data, self.microbatch_offset[model_chunk_id])
- self.microbatch_offset[model_chunk_id] += self.microbatch_size
- return self._move_to_device(data)
-
- def _forward_step(self,
- engine,
- model_chunk_id,
- input_obj,
- return_tensors,
- return_output_label=True,
- accum_loss=None):
- """Forward step for passed-in model. If it is the first stage, the input tensor
- is obtained from data_iterator, otherwise the passed-in input_obj is used.
- Returns output tensor. This is a helper function and can be ignored by users.
-
- Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
- model_chunk_id (int): The id of model chunks.
- input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
- return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
-            return_output_label (bool, optional): Whether to return output labels.
-            accum_loss (optional): Where the accumulated loss is stored.
- Returns:
- Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage.
- """
- micro_batch_data = self.load_micro_batch(model_chunk_id)
- data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion,
- engine.model[model_chunk_id])
-
- output_obj = self._call_engine(engine.model[model_chunk_id], data)
-
- if gpc.is_pipeline_last_stage():
- if return_output_label:
- return_tensors.append((output_obj, label))
- if accum_loss is not None:
- loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches
- accum_loss.add_(loss_reduced.detach())
- return loss_reduced
- else:
-                # forward only; the return value is unused since backward is not needed
- return output_obj
- else:
- if isinstance(output_obj, torch.Tensor):
- self._logger.debug(
- f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}'
- )
- return output_obj
-
- def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
- """Run interleaved 1F1B schedule (model split into model chunks), with
- communication between pipeline stages as needed.
-
- Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
-            data_iter (Iterable): Dataloader in the form of an iterator, obtained by calling iter(dataloader).
- forward_only (bool, optional):
-                Whether to run the forward step only. Default is False. If True, no backward pass will be run.
-            return_loss (bool, optional): Whether to return the loss value. Default is True.
- return_output_label (bool, optional): If False, the output and label won't be returned.
-
- Returns:
- Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
- The loss would be returned only in the last stage.
- """
- assert forward_only or return_loss, \
- 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
- self.load_batch(data_iter)
- model = engine.model
- input_objs = [[] for _ in range(len(model))]
- output_objs = [[] for _ in range(len(model))]
- return_tensors = []
- if not forward_only:
- output_obj_grads = [[] for _ in range(len(model))]
- if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
- accum_loss = torch.zeros(1, device=get_current_device())
- else:
- accum_loss = None
-
- # Used for obj meta information communication
- input_obj_shapes = [self.tensor_shape for _ in range(len(model))]
- output_obj_shapes = [None for _ in range(len(model))]
- send_tensor_shape_flags = [self.tensor_shape is None for _ in range(len(model))]
-
- pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
- pipeline_parallel_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
-
- # Compute number of warmup and remaining microbatches.
- num_model_chunks = len(model)
- num_microbatches = self.num_microbatches * num_model_chunks
- all_warmup_microbatches = False
- if forward_only:
- num_warmup_microbatches = num_microbatches
- else:
- # Run all forward passes and then all backward passes if number of
- # microbatches is just the number of pipeline stages.
-            # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size warmup
-            # microbatches on all workers, followed by more microbatches depending on
- # stage ID (more forward passes for earlier stages, later stages can
- # immediately start with 1F1B).
- if self.num_microbatches == pipeline_parallel_size:
- num_warmup_microbatches = num_microbatches
- all_warmup_microbatches = True
- else:
- num_warmup_microbatches = \
- (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
- num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
- num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
- num_microbatches_remaining = \
- num_microbatches - num_warmup_microbatches
-
- def get_model_chunk_id(microbatch_id, forward):
- """Helper method to get the model chunk ID given the iteration number."""
- microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
- model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
- if not forward:
- model_chunk_id = (num_model_chunks - model_chunk_id - 1)
- return model_chunk_id
-
- def _forward_step_helper(microbatch_id):
- """Helper method to run forward step with model split into chunks
- (run set_virtual_pipeline_model_parallel_rank() before calling
- forward_step())."""
- model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
- gpc.set_virtual_pipeline_parallel_rank(model_chunk_id)
-
- # forward step
- if gpc.is_pipeline_first_stage():
- if len(input_objs[model_chunk_id]) == \
- len(output_objs[model_chunk_id]):
- input_objs[model_chunk_id].append(None)
- input_obj = input_objs[model_chunk_id][-1]
- output_obj = self._forward_step(engine,
- model_chunk_id,
- input_obj,
- return_tensors,
- return_output_label=return_output_label,
- accum_loss=accum_loss)
- output_objs[model_chunk_id].append(output_obj)
-
- # if forward-only, no need to save tensors for a backward pass
- if forward_only:
- input_objs[model_chunk_id].pop()
- output_objs[model_chunk_id].pop()
-
- return output_obj
-
- def _backward_step_helper(microbatch_id):
- """Helper method to run backward step with model split into chunks
- (run set_virtual_pipeline_model_parallel_rank() before calling
- backward_step())."""
- model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
- gpc.set_virtual_pipeline_parallel_rank(model_chunk_id)
-
- if gpc.is_pipeline_last_stage():
- if len(output_obj_grads[model_chunk_id]) == 0:
- output_obj_grads[model_chunk_id].append(None)
- input_obj = input_objs[model_chunk_id].pop(0)
- output_obj = output_objs[model_chunk_id].pop(0)
- output_obj_grad = output_obj_grads[model_chunk_id].pop(0)
- input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)
-
- return input_obj_grad
-
- # Run warmup forward passes.
- gpc.set_virtual_pipeline_parallel_rank(0)
- if not gpc.is_pipeline_first_stage():
- input_obj_shapes[0] = comm.recv_obj_meta(input_obj_shapes[0])
- input_objs[0].append(
- comm.recv_forward(input_obj_shapes[0], dtype=self.dtype,
- scatter_gather_tensors=self.scatter_gather_tensors))
-
- for k in range(num_warmup_microbatches):
- model_chunk_id = get_model_chunk_id(k, forward=True)
- output_obj = _forward_step_helper(k)
- if not gpc.is_pipeline_last_stage():
- if isinstance(output_obj, torch.Tensor):
- output_obj_shapes[model_chunk_id] = output_obj.shape
- else:
- output_obj_shapes[model_chunk_id] = []
- for out_tensor in output_obj:
- output_obj_shapes[model_chunk_id].append(out_tensor.shape)
- send_tensor_shape_flags[model_chunk_id] = comm.send_obj_meta(output_obj,
- send_tensor_shape_flags[model_chunk_id])
- # Determine if tensor should be received from previous stage.
- next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
- recv_prev = True
- if gpc.is_pipeline_first_stage(ignore_virtual=True):
- if next_forward_model_chunk_id == 0:
- recv_prev = False
- if k == (num_microbatches - 1):
- recv_prev = False
-
- # Don't send tensor downstream if on last stage.
- if gpc.is_pipeline_last_stage():
- output_obj = None
-
- with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id):
- if not gpc.is_pipeline_first_stage():
- input_obj_shapes[next_forward_model_chunk_id] = comm.recv_obj_meta(
- input_obj_shapes[next_forward_model_chunk_id])
- # Send and receive tensors as appropriate (send tensors computed
- # in this iteration; receive tensors for next iteration).
- input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None
- if k == (num_warmup_microbatches - 1) and not forward_only and \
- not all_warmup_microbatches:
- input_obj_grad = None
- recv_next = True
- if gpc.is_pipeline_last_stage(ignore_virtual=True):
- recv_next = False
- output_shape = output_obj_shapes[num_model_chunks - 1] if recv_next else None
- input_obj, output_obj_grad = \
- comm.send_forward_backward_recv_forward_backward(
- output_obj, input_obj_grad,
- input_shape,
- output_shape,
- recv_prev=recv_prev, recv_next=recv_next,
- dtype=self.dtype,
- scatter_gather_tensors=self.scatter_gather_tensors)
- output_obj_grads[num_model_chunks - 1].append(output_obj_grad)
- else:
- input_obj = \
- comm.send_forward_recv_forward(
- output_obj,
- input_shape,
- recv_prev=recv_prev,
- dtype=self.dtype,
- scatter_gather_tensors=self.scatter_gather_tensors)
- input_objs[next_forward_model_chunk_id].append(input_obj)
-
- # Run 1F1B in steady state.
- for k in range(num_microbatches_remaining):
- # Forward pass.
- forward_k = k + num_warmup_microbatches
- output_obj = _forward_step_helper(forward_k)
-
- # Backward pass.
- backward_k = k
- input_obj_grad = _backward_step_helper(backward_k)
-
- # Send output_obj and input_obj_grad, receive input_obj
- # and output_obj_grad.
-
- # Determine if current stage has anything to send in either direction,
- # otherwise set obj to None.
- forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
- gpc.set_virtual_pipeline_parallel_rank(forward_model_chunk_id)
- if gpc.is_pipeline_last_stage():
- output_obj = None
-
- backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
- gpc.set_virtual_pipeline_parallel_rank(backward_model_chunk_id)
- if gpc.is_pipeline_first_stage():
- input_obj_grad = None
-
- # Determine if peers are sending, and where in data structure to put
- # received tensors.
- recv_prev = True
- if gpc.is_pipeline_first_stage(ignore_virtual=True):
- # First stage is ahead of last stage by (pipeline_parallel_size - 1).
- next_forward_model_chunk_id = get_model_chunk_id(forward_k - (pipeline_parallel_size - 1), forward=True)
- if next_forward_model_chunk_id == (num_model_chunks - 1):
- recv_prev = False
- next_forward_model_chunk_id += 1
- else:
- next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
-
- recv_next = True
- if gpc.is_pipeline_last_stage(ignore_virtual=True):
- # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
- next_backward_model_chunk_id = get_model_chunk_id(backward_k - (pipeline_parallel_size - 1),
- forward=False)
- if next_backward_model_chunk_id == 0:
- recv_next = False
- next_backward_model_chunk_id -= 1
- else:
- next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
-
- # If last iteration, don't receive; we already received one extra
- # before the start of the for loop.
- if k == (num_microbatches_remaining - 1):
- recv_prev = False
-
- input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None
- output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None
- # Communicate objs.
- input_obj, output_obj_grad = \
- comm.send_forward_backward_recv_forward_backward(
- output_obj, input_obj_grad,
- input_shape,
- output_shape,
- recv_prev=recv_prev, recv_next=recv_next,
- dtype=self.dtype,
- scatter_gather_tensors=self.scatter_gather_tensors)
-
- # Put input_obj and output_obj_grad in data structures in the
- # right location.
- if recv_prev:
- input_objs[next_forward_model_chunk_id].append(input_obj)
- if recv_next:
- output_obj_grads[next_backward_model_chunk_id].append(output_obj_grad)
-
- # Run cooldown backward passes (flush out pipeline).
- if not forward_only:
- if all_warmup_microbatches:
- output_obj_grads[num_model_chunks - 1].append(
- comm.recv_backward(output_obj_shapes[num_model_chunks - 1],
- scatter_gather_tensors=self.scatter_gather_tensors))
- for k in range(num_microbatches_remaining, num_microbatches):
- input_obj_grad = _backward_step_helper(k)
- next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
- recv_next = True
- if gpc.is_pipeline_last_stage(ignore_virtual=True):
- if next_backward_model_chunk_id == (num_model_chunks - 1):
- recv_next = False
- if k == (num_microbatches - 1):
- recv_next = False
- output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None
- output_obj_grads[next_backward_model_chunk_id].append(
- comm.send_backward_recv_backward(input_obj_grad,
- output_shape,
- recv_next=recv_next,
- dtype=self.dtype,
- scatter_gather_tensors=self.scatter_gather_tensors))
-
- if len(return_tensors) > 0:
- output, label = pack_return_tensors(return_tensors)
- return output, label, accum_loss
- else:
- return None, None, accum_loss
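To make the interleaved schedule deleted above easier to follow, the microbatch-to-chunk mapping of its `get_model_chunk_id` helper can be run standalone. A minimal sketch with hypothetical sizes (2 pipeline stages, 2 model chunks); note that backward passes traverse the chunks in reverse:

```python
# Sketch of the microbatch -> model-chunk mapping used by the deleted
# InterleavedPipelineSchedule (hypothetical sizes: 2 stages, 2 chunks).
def get_model_chunk_id(microbatch_id: int, forward: bool,
                       pipeline_parallel_size: int = 2, num_model_chunks: int = 2) -> int:
    microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
    if not forward:
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id

print([get_model_chunk_id(i, forward=True) for i in range(8)])    # [0, 0, 1, 1, 0, 0, 1, 1]
print([get_model_chunk_id(i, forward=False) for i in range(8)])   # [1, 1, 0, 0, 1, 1, 0, 0]
```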
diff --git a/colossalai/fx/_compatibility.py b/colossalai/fx/_compatibility.py
index 0444a481627356e202912c491495364134aa65dd..4d40d5badfd0e4d6dd9b1083f21fb5e584c7d78e 100644
--- a/colossalai/fx/_compatibility.py
+++ b/colossalai/fx/_compatibility.py
@@ -2,16 +2,14 @@ from typing import Callable
import torch
-TORCH_MAJOR = int(torch.__version__.split('.')[0])
-TORCH_MINOR = int(torch.__version__.split('.')[1])
+TORCH_MAJOR = int(torch.__version__.split(".")[0])
+TORCH_MINOR = int(torch.__version__.split(".")[1])
if TORCH_MAJOR == 1 and TORCH_MINOR < 12:
META_COMPATIBILITY = False
elif TORCH_MAJOR == 1 and TORCH_MINOR == 12:
- from . import _meta_regist_12
META_COMPATIBILITY = True
elif TORCH_MAJOR == 1 and TORCH_MINOR == 13:
- from . import _meta_regist_13
META_COMPATIBILITY = True
elif TORCH_MAJOR == 2:
META_COMPATIBILITY = True
@@ -36,7 +34,7 @@ def compatibility(is_backward_compatible: bool = False) -> Callable:
else:
def wrapper(*args, **kwargs):
- raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}')
+ raise RuntimeError(f"Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}")
return wrapper
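The gate above now only records compatibility and no longer imports the meta-registration modules at this point. A condensed, roughly equivalent sketch of the version check (assuming a standard `major.minor.patch` version string such as `"1.13.0"` or `"2.0.0+cu117"`; not the library's exact code):

```python
# Condensed sketch of the META_COMPATIBILITY gate.
import torch

TORCH_MAJOR, TORCH_MINOR = (int(v) for v in torch.__version__.split(".")[:2])
# Meta-tensor tracing is supported from PyTorch 1.12 onward.
META_COMPATIBILITY = (TORCH_MAJOR, TORCH_MINOR) >= (1, 12)
print(torch.__version__, META_COMPATIBILITY)
```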
diff --git a/colossalai/fx/_meta_regist_12.py b/colossalai/fx/_meta_regist_12.py
index 52e8d63ae54355a0d3e27dfc6a3347a2304dc0d7..63f88682e85a77b8ec277fdc099c1689a2b7160d 100644
--- a/colossalai/fx/_meta_regist_12.py
+++ b/colossalai/fx/_meta_regist_12.py
@@ -3,7 +3,7 @@
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
# for more meta_registrations
-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Union
import torch
from torch.utils._pytree import tree_map
@@ -16,13 +16,11 @@ meta_table = {}
def register_meta(op, register_dispatcher=True):
-
def wrapper(f):
-
def add_func(op):
meta_table[op] = f
if register_dispatcher:
- name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
+ name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__
try:
meta_lib.impl(name, f)
except:
@@ -48,7 +46,6 @@ def meta_conv(
output_padding: List[int],
groups: int,
):
-
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
"""
Formula to apply to calculate the length of some dimension of the output
@@ -125,7 +122,8 @@ def meta_conv(
kernel_size[i],
stride[i],
output_padding_list[i],
- ))
+ )
+ )
else:
ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
return ret_shape
@@ -159,22 +157,42 @@ def meta_conv(
shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
mem_fmt = pick_memory_format()
- out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
+ out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
return out
@register_meta(aten._convolution.default)
-def meta_conv_1(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
- padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
- *extra_args):
+def meta_conv_1(
+ input_tensor: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor,
+ stride: List[int],
+ padding: List[int],
+ dilation: List[int],
+ is_transposed: bool,
+ output_padding: List[int],
+ groups: int,
+ *extra_args,
+):
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
return out
@register_meta(aten.convolution_backward.default)
-def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
- padding, dilation, transposed, output_padding, groups, output_mask):
- return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta')
+def meta_conv_backward(
+ grad_output: torch.Tensor,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ bias_sizes,
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ output_mask,
+):
+ return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device="meta")
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
@@ -208,7 +226,6 @@ def meta_cuda_rnn(
batch_sizes,
dropout_state,
):
-
is_input_packed = len(batch_sizes) != 0
if is_input_packed:
seq_length = len(batch_sizes)
@@ -224,8 +241,11 @@ def meta_cuda_rnn(
if is_input_packed:
out_shape = [batch_sizes_sum, out_size * num_directions]
else:
- out_shape = ([mini_batch, seq_length, out_size *
- num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
+ out_shape = (
+ [mini_batch, seq_length, out_size * num_directions]
+ if batch_first
+ else [seq_length, mini_batch, out_size * num_directions]
+ )
output = input.new_empty(out_shape)
cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
@@ -242,18 +262,20 @@ def meta_cuda_rnn(
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
@register_meta(aten._cudnn_rnn_backward.default)
-def meta_cudnn_rnn_backward(input: torch.Tensor,
- weight: torch.Tensor,
- weight_stride0: int,
- hx: torch.Tensor,
- cx: Optional[torch.Tensor] = None,
- *args,
- **kwargs):
+def meta_cudnn_rnn_backward(
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ weight_stride0: int,
+ hx: torch.Tensor,
+ cx: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+):
print(input, weight, hx, cx)
grad_input = torch.empty_like(input)
grad_weight = torch.empty_like(weight)
grad_hx = torch.empty_like(hx)
- grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device='meta')
+ grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device="meta")
return grad_input, grad_weight, grad_hx, grad_cx
@@ -298,15 +320,25 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini
n_input = input.size(1)
output = torch.empty_like(input)
- running_mean = torch.empty((n_input), device='meta')
- running_var = torch.empty((n_input), device='meta')
+ running_mean = torch.empty((n_input), device="meta")
+ running_var = torch.empty((n_input), device="meta")
return output, running_mean, running_var
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
@register_meta(aten.native_batch_norm_backward.default)
-def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
- save_invstd, train, eps, output_mask):
+def meta_bn_backward(
+ dY: torch.Tensor,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ running_mean,
+ running_var,
+ save_mean,
+ save_invstd,
+ train,
+ eps,
+ output_mask,
+):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(weight)
@@ -319,9 +351,9 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var,
n_input = input.size(1)
output = torch.empty_like(input)
- running_mean = torch.empty((n_input), device='meta')
- running_var = torch.empty((n_input), device='meta')
- reserve = torch.empty((0), dtype=torch.uint8, device='meta')
+ running_mean = torch.empty((n_input), device="meta")
+ running_var = torch.empty((n_input), device="meta")
+ reserve = torch.empty((0), dtype=torch.uint8, device="meta")
return output, running_mean, running_var, reserve
@@ -330,8 +362,17 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var,
# in training mode (evaluation mode batchnorm has a different algorithm),
# which is why this doesn't accept a 'training' parameter.
@register_meta(aten.cudnn_batch_norm_backward.default)
-def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
- save_mean, save_invstd, eps, reserve):
+def meta_cudnn_bn_backward(
+ dY: torch.Tensor,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ running_mean,
+ running_var,
+ save_mean,
+ save_invstd,
+ eps,
+ reserve,
+):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(weight)
@@ -345,15 +386,16 @@ def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
n_input = input.size(1)
output = torch.empty_like(input)
- running_mean = torch.empty((bs, n_input, 1), device='meta')
- running_var = torch.empty((bs, n_input, 1), device='meta')
+ running_mean = torch.empty((bs, n_input, 1), device="meta")
+ running_var = torch.empty((bs, n_input, 1), device="meta")
return output, running_mean, running_var
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
@register_meta(aten.native_layer_norm_backward.default)
-def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
- grad_input_mask):
+def meta_ln_backward(
+ dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask
+):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(bias)
@@ -397,16 +439,19 @@ def meta_index_Tensor(self, indices):
result: List[Optional[torch.Tensor]] = []
for i, index in enumerate(indices):
if index is not None:
- assert index.dtype in [torch.long, torch.int8, torch.bool],\
- "tensors used as indices must be long, byte or bool tensors"
+ assert index.dtype in [
+ torch.long,
+ torch.int8,
+ torch.bool,
+ ], "tensors used as indices must be long, byte or bool tensors"
if index.dtype in [torch.int8, torch.bool]:
nonzero = index.nonzero()
k = len(result)
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
for j in range(index.ndim):
- assert index.shape[j] == self.shape[
- k +
- j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
+ assert (
+ index.shape[j] == self.shape[k + j]
+ ), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
result.append(nonzero.select(1, j))
else:
result.append(index)
@@ -482,12 +527,15 @@ def meta_index_Tensor(self, indices):
# ============================== Embedding =========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
@register_meta(aten.embedding_dense_backward.default)
-def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
- scale_grad_by_freq):
- return torch.empty((num_weights, grad_output.size(-1)),
- dtype=grad_output.dtype,
- device=grad_output.device,
- layout=grad_output.layout)
+def meta_embedding_dense_backward(
+ grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq
+):
+ return torch.empty(
+ (num_weights, grad_output.size(-1)),
+ dtype=grad_output.dtype,
+ device=grad_output.device,
+ layout=grad_output.layout,
+ )
# ============================== Dropout ===========================================
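The convolution meta kernels reformatted above compute output shapes without allocating real storage; the `_formula` helper they call is the standard convolution output-length formula. A minimal sketch with hypothetical sizes, assuming a PyTorch build where convolution has a meta kernel (which is what the `register_meta(aten.convolution.default)` entry in this file provides for torch 1.12):

```python
import torch
import torch.nn.functional as F

# Standard convolution output-length formula used by meta_conv's _formula helper.
def conv_out_len(ln: int, p: int, d: int, k: int, s: int) -> int:
    # ln: input length, p: padding, d: dilation, k: kernel size, s: stride
    return (ln + 2 * p - d * (k - 1) - 1) // s + 1

x = torch.empty(1, 3, 32, 32, device="meta")   # meta tensors carry only shape/dtype
w = torch.empty(8, 3, 3, 3, device="meta")
out = F.conv2d(x, w, stride=2, padding=1)
print(out.shape)                                # torch.Size([1, 8, 16, 16])
print(conv_out_len(32, p=1, d=1, k=3, s=2))     # 16
```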
diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py
index 5a72cb9ca923bcdf5f9b3daddad9b5e7c339d5c5..dfb5754d71c17e17386f265526cb4ff02f19d173 100644
--- a/colossalai/fx/codegen/activation_checkpoint_codegen.py
+++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, List, Tuple
+from typing import Any, Dict, Iterable, List, Tuple
import torch
@@ -18,6 +18,7 @@ try:
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
+
CODEGEN_AVAILABLE = True
except:
from torch.fx.graph import (
@@ -32,12 +33,13 @@ except:
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
+
CODEGEN_AVAILABLE = False
if CODEGEN_AVAILABLE:
- __all__ = ['ActivationCheckpointCodeGen']
+ __all__ = ["ActivationCheckpointCodeGen"]
else:
- __all__ = ['python_code_with_activation_checkpoint']
+ __all__ = ["python_code_with_activation_checkpoint"]
def _gen_saved_tensors_hooks():
@@ -125,15 +127,14 @@ def _find_ckpt_regions(nodes: List[Node]):
Find the checkpoint regions given a list of consecutive nodes. The outputs will be list
of tuples, each tuple is in the form of (start_index, end_index).
"""
- ckpt_nodes = []
ckpt_regions = []
start = -1
end = -1
current_region = None
for idx, node in enumerate(nodes):
- if 'activation_checkpoint' in node.meta:
- act_ckpt_label = node.meta['activation_checkpoint']
+ if "activation_checkpoint" in node.meta:
+ act_ckpt_label = node.meta["activation_checkpoint"]
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
@@ -150,7 +151,7 @@ def _find_ckpt_regions(nodes: List[Node]):
current_region = act_ckpt_label
start = idx
end = -1
- elif current_region is not None and not 'activation_checkpoint' in node.meta:
+ elif current_region is not None and not "activation_checkpoint" in node.meta:
# used to check the case below
# node ckpt states = [ckpt, ckpt, non-ckpt]
end = idx - 1
@@ -178,8 +179,8 @@ def _find_offload_regions(nodes: List[Node]):
current_region = None
for idx, node in enumerate(nodes):
- if 'activation_offload' in node.meta and isinstance(node.meta['activation_offload'], Iterable):
- act_offload_label = node.meta['activation_offload']
+ if "activation_offload" in node.meta and isinstance(node.meta["activation_offload"], Iterable):
+ act_offload_label = node.meta["activation_offload"]
if current_region == None:
current_region = act_offload_label
@@ -226,9 +227,9 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen
"""
Generate the checkpoint function call code text
"""
- outputs = ', '.join(output_vars)
- inputs = ', '.join(input_vars)
- return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
+ outputs = ", ".join(output_vars)
+ inputs = ", ".join(input_vars)
+ return f"{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})"
def _end_of_ckpt(node: Node, check_idx: int) -> bool:
@@ -240,9 +241,9 @@ def _end_of_ckpt(node: Node, check_idx: int) -> bool:
Returns:
bool
"""
- if 'activation_checkpoint' in node.meta:
- if isinstance(node.meta['activation_checkpoint'], list):
- return node.meta['activation_checkpoint'][check_idx] == None
+ if "activation_checkpoint" in node.meta:
+ if isinstance(node.meta["activation_checkpoint"], list):
+ return node.meta["activation_checkpoint"][check_idx] == None
else:
return False
else:
@@ -260,11 +261,11 @@ def _find_nested_ckpt_regions(nodes, check_idx=0):
current_region = None
for idx, node in enumerate(nodes):
- if 'activation_checkpoint' in node.meta:
- if isinstance(node.meta['activation_checkpoint'], int):
- act_ckpt_label = node.meta['activation_checkpoint']
+ if "activation_checkpoint" in node.meta:
+ if isinstance(node.meta["activation_checkpoint"], int):
+ act_ckpt_label = node.meta["activation_checkpoint"]
else:
- act_ckpt_label = node.meta['activation_checkpoint'][check_idx]
+ act_ckpt_label = node.meta["activation_checkpoint"][check_idx]
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
@@ -298,13 +299,9 @@ def _find_nested_ckpt_regions(nodes, check_idx=0):
return ckpt_regions
-def emit_ckpt_func(body,
- ckpt_func,
- node_list: List[Node],
- emit_node_func,
- delete_unused_value_func,
- level=0,
- in_ckpt=False):
+def emit_ckpt_func(
+ body, ckpt_func, node_list: List[Node], emit_node_func, delete_unused_value_func, level=0, in_ckpt=False
+):
"""Emit ckpt function in nested way
Args:
body: forward code, in recursive calls, this part will be checkpoint
@@ -321,17 +318,17 @@ def emit_ckpt_func(body,
inputs, outputs = _find_input_and_output_nodes(node_list)
# if the current checkpoint function use int as label, using old generation method
- if isinstance(node_list[0].meta['activation_checkpoint'], int):
- label = node_list[0].meta['activation_checkpoint']
+ if isinstance(node_list[0].meta["activation_checkpoint"], int):
+ label = node_list[0].meta["activation_checkpoint"]
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
- ckpt_func.append(f'{ckpt_fn_def}\n')
+ ckpt_func.append(f"{ckpt_fn_def}\n")
for node in node_list:
emit_node_func(node, ckpt_func)
- ckpt_func[-1] = ' ' + ckpt_func[-1]
+ ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
- ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
- activation_offload = node_list[0].meta.get('activation_offload', False)
+ ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
+ activation_offload = node_list[0].meta.get("activation_offload", False)
usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
usage += "\n"
body.append(usage)
@@ -340,12 +337,12 @@ def emit_ckpt_func(body,
else:
# label given by each layer, e.g. if you are currently at level [0, 1, 1]
# the label will be '0_1_1'
- label = "_".join([str(idx) for idx in node_list[0].meta['activation_checkpoint'][:level + 1]])
+ label = "_".join([str(idx) for idx in node_list[0].meta["activation_checkpoint"][: level + 1]])
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
- ckpt_func.append(f'{ckpt_fn_def}\n')
+ ckpt_func.append(f"{ckpt_fn_def}\n")
# if there is more level to fetch
- if level + 1 < len(node_list[0].meta['activation_checkpoint']):
+ if level + 1 < len(node_list[0].meta["activation_checkpoint"]):
ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
@@ -358,38 +355,45 @@ def emit_ckpt_func(body,
break
if node_idx in start_idx:
- ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
- emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func,
- delete_unused_value_func, level + 1, True)
+ ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
+ emit_ckpt_func(
+ ckpt_func,
+ ckpt_func_buffer,
+ ckpt_node_list,
+ emit_node_func,
+ delete_unused_value_func,
+ level + 1,
+ True,
+ )
node_idx += len(ckpt_node_list)
else:
node = node_list[node_idx]
emit_node_func(node, ckpt_func)
- ckpt_func[-1] = ' ' + ckpt_func[-1]
+ ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
node_idx += 1
- ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
+ ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
ckpt_func += ckpt_func_buffer
- activation_offload = node_list[0].meta.get('activation_offload', False)
- usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
+ activation_offload = node_list[0].meta.get("activation_offload", False)
+ usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + "\n"
if in_ckpt:
- usage = ' ' + usage
+ usage = " " + usage
body.append(usage)
# last level
else:
for node in node_list:
emit_node_func(node, ckpt_func)
- ckpt_func[-1] = ' ' + ckpt_func[-1]
+ ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
- ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
- activation_offload = node_list[0].meta.get('activation_offload', False)
- usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
+ ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n")
+ activation_offload = node_list[0].meta.get("activation_offload", False)
+ usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + "\n"
if in_ckpt:
- usage = ' ' + usage
+ usage = " " + usage
body.append(usage)
@@ -420,7 +424,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
# find the input and output var names for each offload region
for idx, (start, end) in enumerate(offload_regions):
- offload_node_list = node_list[start:end + 1]
+ offload_node_list = node_list[start : end + 1]
inputs, outputs = _find_input_and_output_nodes(offload_node_list)
offload_inputs.append(inputs)
offload_outputs.append(outputs)
@@ -436,7 +440,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
# process ckpt_regions
if node_idx in start_idx:
- ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
+ ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
node_idx += len(ckpt_node_list)
@@ -470,7 +474,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
if within_offload_region:
emit_node_func(node, body)
- body[-1] = ' ' + body[-1]
+ body[-1] = " " + body[-1]
delete_unused_value_func(node, body)
else:
@@ -508,14 +512,14 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# find the input and output var names for each region
for idx, (start, end) in enumerate(ckpt_regions):
- ckpt_node_list = node_list[start:end + 1]
+ ckpt_node_list = node_list[start : end + 1]
inputs, outputs = _find_input_and_output_nodes(ckpt_node_list)
input_vars.append(inputs)
output_vars.append(outputs)
# find the input and output var names for each offload region
for idx, (start, end) in enumerate(offload_regions):
- offload_node_list = node_list[start:end + 1]
+ offload_node_list = node_list[start : end + 1]
inputs, outputs = _find_input_and_output_nodes(offload_node_list)
offload_inputs.append(inputs)
offload_outputs.append(outputs)
@@ -523,11 +527,11 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# append code text to body
for idx, node in enumerate(node_list):
# if this is the first node of the ckpt region
- # append the ckpt function defition
+ # append the ckpt function definition
if idx in start_idx:
label = start_idx.index(idx)
ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label])
- ckpt_func.append(f'{ckpt_fn_def}\n')
+ ckpt_func.append(f"{ckpt_fn_def}\n")
within_ckpt_region = True
if idx in offload_starts:
@@ -559,12 +563,12 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# NOTE: currently we separate body and ckpt_func definition
if within_ckpt_region:
emit_node_func(node, ckpt_func)
- ckpt_func[-1] = ' ' + ckpt_func[-1]
+ ckpt_func[-1] = " " + ckpt_func[-1]
delete_unused_value_func(node, ckpt_func)
elif within_offload_region:
emit_node_func(node, body)
- body[-1] = ' ' + body[-1]
+ body[-1] = " " + body[-1]
delete_unused_value_func(node, body)
else:
@@ -576,13 +580,13 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# generate return statement
label = end_idx.index(idx)
return_statement = _gen_ckpt_output(output_vars[label])
- return_statement = f' {return_statement}\n\n'
+ return_statement = f" {return_statement}\n\n"
ckpt_func.append(return_statement)
# we need to check if the checkpoint need to offload the input
start_node_idx = start_idx[label]
- if 'activation_offload' in node_list[start_node_idx].meta:
- activation_offload = node_list[start_node_idx].meta['activation_offload']
+ if "activation_offload" in node_list[start_node_idx].meta:
+ activation_offload = node_list[start_node_idx].meta["activation_offload"]
else:
activation_offload = False
@@ -594,8 +598,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
if input_node.op != "placeholder":
non_leaf_input = 1
for user in input_node.users:
- if 'activation_checkpoint' in user.meta:
- if user.meta['activation_checkpoint'] == label:
+ if "activation_checkpoint" in user.meta:
+ if user.meta["activation_checkpoint"] == label:
if user.op == "call_module":
if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"):
use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace
@@ -610,7 +614,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
# generate checkpoint function call in a new line
usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant)
- usage += '\n'
+ usage += "\n"
body.append(usage)
within_ckpt_region = False
@@ -621,7 +625,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
if CODEGEN_AVAILABLE:
class ActivationCheckpointCodeGen(CodeGen):
-
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
free_vars: List[str] = []
body: List[str] = []
@@ -629,7 +632,7 @@ if CODEGEN_AVAILABLE:
wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference
- maybe_return_annotation: List[str] = ['']
+ maybe_return_annotation: List[str] = [""]
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
@@ -637,7 +640,7 @@ if CODEGEN_AVAILABLE:
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
- if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
+ if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
@@ -662,16 +665,16 @@ if CODEGEN_AVAILABLE:
def type_repr(o: Any):
if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()]
- return '()'
+ return "()"
typename = _type_repr(o)
- if hasattr(o, '__origin__'):
+ if hasattr(o, "__origin__"):
# This is a generic type, e.g. typing.List[torch.Tensor]
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type)
- if hasattr(o, '__args__'):
+ if hasattr(o, "__args__"):
# Assign global names for each of the inner type variables.
args = [type_repr(arg) for arg in o.__args__]
@@ -690,19 +693,18 @@ if CODEGEN_AVAILABLE:
return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
-
def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global.
- if isinstance(arg, tuple) and hasattr(arg, '_fields'):
+ if isinstance(arg, tuple) and hasattr(arg, "_fields"):
qualified_name = _get_qualified_name(type(arg))
global_name = add_global(qualified_name, type(arg))
return f"{global_name}{repr(tuple(arg))}"
return repr(arg)
- args_s = ', '.join(_get_repr(a) for a in args)
- kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
+ args_s = ", ".join(_get_repr(a) for a in args)
+ kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
if args_s and kwargs_s:
- return f'{args_s}, {kwargs_s}'
+ return f"{args_s}, {kwargs_s}"
return args_s or kwargs_s
# Run through reverse nodes and record the first instance of a use
@@ -728,90 +730,101 @@ if CODEGEN_AVAILABLE:
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
- if user.op == 'placeholder':
+ if user.op == "placeholder":
return
- if user.op == 'output':
- body.append('\n')
+ if user.op == "output":
+ body.append("\n")
return
nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete):
- to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
- body.append(f'; {to_delete_str}\n')
+ to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
+ body.append(f"; {to_delete_str}\n")
else:
- body.append('\n')
+ body.append("\n")
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
- maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
- if node.op == 'placeholder':
+ maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
+ if node.op == "placeholder":
assert isinstance(node.target, str)
- maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
- free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
- raw_name = node.target.replace('*', '')
+ maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
+ free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
+ raw_name = node.target.replace("*", "")
if raw_name != repr(node):
- body.append(f'{repr(node)} = {raw_name}\n')
+ body.append(f"{repr(node)} = {raw_name}\n")
return
- elif node.op == 'call_method':
+ elif node.op == "call_method":
assert isinstance(node.target, str)
body.append(
- f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
- f'({_format_args(node.args[1:], node.kwargs)})')
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
+ f"({_format_args(node.args[1:], node.kwargs)})"
+ )
return
- elif node.op == 'call_function':
+ elif node.op == "call_function":
assert callable(node.target)
# pretty print operators
- if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
+ if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
- body.append(f'{repr(node)}{maybe_type_annotation} = '
- f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
+ )
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
- if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
- body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; '
- f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
+ if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
+ body.append(
+ f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
+ f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
+ )
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
- if global_name == 'getattr' and \
- isinstance(node.args, tuple) and \
- isinstance(node.args[1], str) and \
- node.args[1].isidentifier() and \
- len(node.args) == 2:
+ if (
+ global_name == "getattr"
+ and isinstance(node.args, tuple)
+ and isinstance(node.args[1], str)
+ and node.args[1].isidentifier()
+ and len(node.args) == 2
+ ):
body.append(
- f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
+ )
return
body.append(
- f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
- if node.meta.get('is_wrapped', False):
+ f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
+ )
+ if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name)
return
- elif node.op == 'call_module':
+ elif node.op == "call_module":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = '
- f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
+ )
return
- elif node.op == 'get_attr':
+ elif node.op == "get_attr":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
+ body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
return
- elif node.op == 'output':
+ elif node.op == "output":
if node.type is not None:
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
body.append(self.generate_output(node.args[0]))
return
- raise NotImplementedError(f'node: {node.op} {node.target}')
+ raise NotImplementedError(f"node: {node.op} {node.target}")
# Modified for activation checkpointing
ckpt_func = []
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
- if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in nodes):
+ if any(isinstance(node.meta.get("activation_checkpoint", None), Iterable) for node in nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
@@ -820,13 +833,13 @@ if CODEGEN_AVAILABLE:
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
- body.append('pass\n')
+ body.append("pass\n")
if len(wrapped_fns) > 0:
- wrap_name = add_global('wrap', torch.fx.wrap)
- wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
+ wrap_name = add_global("wrap", torch.fx.wrap)
+ wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else:
- wrap_stmts = ''
+ wrap_stmts = ""
if self._body_transformer:
body = self._body_transformer(body)
@@ -837,11 +850,11 @@ if CODEGEN_AVAILABLE:
# as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
- prologue = ''.join(ckpt_func) + prologue
+ prologue = "".join(ckpt_func) + prologue
prologue = prologue
- code = ''.join(body)
- code = '\n'.join(' ' + line for line in code.split('\n'))
+ code = "".join(body)
+ code = "\n".join(" " + line for line in code.split("\n"))
fn_code = f"""
{wrap_stmts}
{prologue}
@@ -861,7 +874,7 @@ else:
wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference
- maybe_return_annotation: List[str] = ['']
+ maybe_return_annotation: List[str] = [""]
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
@@ -869,7 +882,7 @@ else:
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
- if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
+ if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
@@ -894,12 +907,12 @@ else:
def type_repr(o: Any):
if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()]
- return '()'
+ return "()"
typename = _type_repr(o)
# This is a generic type, e.g. typing.List[torch.Tensor]
- if hasattr(o, '__origin__'):
+ if hasattr(o, "__origin__"):
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type)
@@ -934,84 +947,94 @@ else:
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
- if user.op == 'placeholder':
+ if user.op == "placeholder":
return
- if user.op == 'output':
- body.append('\n')
+ if user.op == "output":
+ body.append("\n")
return
nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete):
- to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
- body.append(f'; {to_delete_str}\n')
+ to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
+ body.append(f"; {to_delete_str}\n")
else:
- body.append('\n')
+ body.append("\n")
# NOTE: we add a variable to distinguish body and ckpt_func
def emit_node(node: Node, body):
- maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
- if node.op == 'placeholder':
+ maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
+ if node.op == "placeholder":
assert isinstance(node.target, str)
- maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
- free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
- raw_name = node.target.replace('*', '')
+ maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
+ free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
+ raw_name = node.target.replace("*", "")
if raw_name != repr(node):
- body.append(f'{repr(node)} = {raw_name}\n')
+ body.append(f"{repr(node)} = {raw_name}\n")
return
- elif node.op == 'call_method':
+ elif node.op == "call_method":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
- f'({_format_args(node.args[1:], node.kwargs)})')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
+ f"({_format_args(node.args[1:], node.kwargs)})"
+ )
return
- elif node.op == 'call_function':
+ elif node.op == "call_function":
assert callable(node.target)
# pretty print operators
- if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
+ if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
- body.append(f'{repr(node)}{maybe_type_annotation} = '
- f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
+ )
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
- if global_name == 'getattr' and \
- isinstance(node.args, tuple) and \
- isinstance(node.args[1], str) and \
- node.args[1].isidentifier() and \
- len(node.args) == 2:
+ if (
+ global_name == "getattr"
+ and isinstance(node.args, tuple)
+ and isinstance(node.args[1], str)
+ and node.args[1].isidentifier()
+ and len(node.args) == 2
+ ):
body.append(
- f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
+ )
return
body.append(
- f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
- if node.meta.get('is_wrapped', False):
+ f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
+ )
+ if node.meta.get("is_wrapped", False):
wrapped_fns.setdefault(global_name)
return
- elif node.op == 'call_module':
+ elif node.op == "call_module":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = '
- f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = "
+ f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
+ )
return
- elif node.op == 'get_attr':
+ elif node.op == "get_attr":
assert isinstance(node.target, str)
- body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
+ body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
return
- elif node.op == 'output':
+ elif node.op == "output":
if node.type is not None:
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
if self._pytree_info is None:
- body.append(f'return {repr(node.args[0])}')
+ body.append(f"return {repr(node.args[0])}")
else:
- body.append(f'return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)')
+ body.append(f"return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)")
return
- raise NotImplementedError(f'node: {node.op} {node.target}')
+ raise NotImplementedError(f"node: {node.op} {node.target}")
# Modified for activation checkpointing
ckpt_func = []
# if any node has a list of labels for activation_checkpoint, we
    # will use the nested activation checkpoint codegen
- if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in self.nodes):
+ if any(isinstance(node.meta.get("activation_checkpoint", None), Iterable) for node in self.nodes):
emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
else:
emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
@@ -1020,33 +1043,34 @@ else:
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
- body.append('pass\n')
+ body.append("pass\n")
if self._pytree_info is not None:
orig_args = self._pytree_info.orig_args
- has_orig_self = (orig_args[0] == 'self')
+ has_orig_self = orig_args[0] == "self"
if has_orig_self:
- free_vars.insert(0, 'self')
- if len(free_vars) > 0: # pytree has placeholders in it
+ free_vars.insert(0, "self")
+ if len(free_vars) > 0: # pytree has placeholders in it
body.insert(
0,
- f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n")
+ f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n",
+ )
else:
orig_args = free_vars
if len(wrapped_fns) > 0:
- wrap_name = add_global('wrap', torch.fx.wrap)
- wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
+ wrap_name = add_global("wrap", torch.fx.wrap)
+ wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else:
- wrap_stmts = ''
+ wrap_stmts = ""
- ckpt_func = ''.join(ckpt_func)
+ ckpt_func = "".join(ckpt_func)
    # If the original function didn't have self as its first argument,
    # we add it here.
- if len(orig_args) == 0 or orig_args[0] != 'self':
- orig_args.insert(0, 'self')
- code = ''.join(body)
- code = '\n'.join(' ' + line for line in code.split('\n'))
+ if len(orig_args) == 0 or orig_args[0] != "self":
+ orig_args.insert(0, "self")
+ code = "".join(body)
+ code = "\n".join(" " + line for line in code.split("\n"))
    # since we need colossalai.utils.checkpoint, we need to import colossalai
    # in the forward function
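
The codegen above dispatches between the flat and nested activation-checkpoint generators based purely on the type of each node's `activation_checkpoint` meta entry. A minimal sketch of that rule, using stand-in objects rather than real torch.fx nodes (the `FakeNode` class is hypothetical):

from collections.abc import Iterable

class FakeNode:
    def __init__(self, ckpt):
        self.meta = {"activation_checkpoint": ckpt}

# Plain int labels select the flat codegen; as soon as any node carries a
# *list* of labels, the nested codegen path is taken instead.
nodes = [FakeNode(0), FakeNode([0, 0]), FakeNode(None)]
use_nested = any(isinstance(n.meta.get("activation_checkpoint", None), Iterable) for n in nodes)
print("nested" if use_nested else "flat")  # -> nested
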
diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py
index ebb9975f27dbf312870fdb9d7beea98faf7889b7..8429a9607f7a7aca4d15f89a2a6207dbd94083e4 100644
--- a/colossalai/fx/graph_module.py
+++ b/colossalai/fx/graph_module.py
@@ -1,32 +1,35 @@
import os
import warnings
from pathlib import Path
-from typing import Any, Dict, List, Optional, Set, Type, Union
+from typing import Any, Dict, Optional, Union
import torch
import torch.nn as nn
from torch.nn.modules.module import _addindent
try:
- from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen
- from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall
+ from torch.fx.graph import Graph, PythonCode, _PyTreeCodeGen
+ from torch.fx.graph_module import GraphModule, _exec_with_source, _forward_from_src, _WrappedCall
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
+
COLOGM = True
except:
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
+
COLOGM = False
if COLOGM:
class ColoGraphModule(GraphModule):
-
- def __init__(self,
- root: Union[torch.nn.Module, Dict[str, Any]],
- graph: Graph,
- class_name: str = 'GraphModule',
- ckpt_codegen: bool = True):
+ def __init__(
+ self,
+ root: Union[torch.nn.Module, Dict[str, Any]],
+ graph: Graph,
+ class_name: str = "GraphModule",
+ ckpt_codegen: bool = True,
+ ):
if ckpt_codegen:
graph.set_codegen(ActivationCheckpointCodeGen())
super().__init__(root, graph, class_name)
@@ -60,7 +63,7 @@ if COLOGM:
if isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
- python_code = self._graph.python_code(root_module='self')
+ python_code = self._graph.python_code(root_module="self")
self._code = python_code.src
# To split ckpt functions code and forward code
@@ -83,8 +86,8 @@ if COLOGM:
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
cls_call = cls.__call__ if "__call__" in vars(cls) else None
- if '_wrapped_call' not in vars(cls):
- cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
+ if "_wrapped_call" not in vars(cls):
+ cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
def call_wrapped(self, *args, **kwargs):
return self._wrapped_call(self, *args, **kwargs)
@@ -108,7 +111,7 @@ if COLOGM:
"""
folder = Path(folder)
Path(folder).mkdir(exist_ok=True)
- torch.save(self.state_dict(), folder / 'state_dict.pt')
+ torch.save(self.state_dict(), folder / "state_dict.pt")
tab = " " * 4
# we add import colossalai here
@@ -125,7 +128,13 @@ class {module_name}(torch.nn.Module):
def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
safe_reprs = [
- nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
+ nn.Linear,
+ nn.Conv1d,
+ nn.Conv2d,
+ nn.Conv3d,
+ nn.BatchNorm1d,
+ nn.BatchNorm2d,
+ nn.BatchNorm3d,
]
if type(module) in safe_reprs:
return f"{module.__repr__()}"
@@ -136,10 +145,10 @@ class {module_name}(torch.nn.Module):
for module_name, module in self.named_children():
module_str = _gen_model_repr(module_name, module)
if module_str is None:
- module_file = folder / f'{module_name}.pt'
+ module_file = folder / f"{module_name}.pt"
torch.save(module, module_file)
blobified_modules.append(module_name)
- module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
+ module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
module_str = f"torch.load(r'{module_file}') # {module_repr}"
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
@@ -156,19 +165,20 @@ class {module_name}(torch.nn.Module):
model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
model_str += f"{_addindent(self.code, 4)}\n"
- module_file = folder / 'module.py'
+ module_file = folder / "module.py"
module_file.write_text(model_str)
- init_file = folder / '__init__.py'
- init_file.write_text('from .module import *')
+ init_file = folder / "__init__.py"
+ init_file.write_text("from .module import *")
if len(blobified_modules) > 0:
- warnings.warn("Was not able to save the following children modules as reprs -"
- f"saved as pickled files instead: {blobified_modules}")
+ warnings.warn(
+ "Was not able to save the following children modules as reprs -"
+ f"saved as pickled files instead: {blobified_modules}"
+ )
else:
class ColoGraphModule(GraphModule):
-
- def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
+ def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = "GraphModule"):
super().__init__(root, graph, class_name)
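
A hedged usage sketch of `ColoGraphModule` as defined above: wrapping a traced graph so the forward source is emitted by the activation-checkpoint codegen. The toy model and folder name are hypothetical; `save` writes `module.py`, `__init__.py`, and `state_dict.pt` into the given folder, per the method above.

import torch.nn as nn
from torch.fx import symbolic_trace

from colossalai.fx.graph_module import ColoGraphModule

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
traced = symbolic_trace(model)

# ckpt_codegen=True swaps in ActivationCheckpointCodeGen before recompiling.
gm = ColoGraphModule(model, traced.graph, ckpt_codegen=True)
print(gm.code)          # generated forward source
gm.save("exported_gm")  # hypothetical folder name
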
diff --git a/colossalai/fx/passes/adding_split_node_pass.py b/colossalai/fx/passes/adding_split_node_pass.py
index 2c7b842b530cc12fb669154456dc51e489ccc85d..99c8faaa0cc6927b3aaf7649b3048a071bbeb2dd 100644
--- a/colossalai/fx/passes/adding_split_node_pass.py
+++ b/colossalai/fx/passes/adding_split_node_pass.py
@@ -1,8 +1,6 @@
import numpy as np
import torch
import tqdm
-from torch.fx import symbolic_trace
-from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module
@@ -29,15 +27,15 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
accumulate_bwd_flop = 0
block_nodes = []
for node in gm.graph.nodes:
- if 'block_split' in node.name:
+ if "block_split" in node.name:
continue
accumulate_fwd_flop += node.fwd_flop
accumulate_bwd_flop += node.bwd_flop
if accumulate_fwd_flop + accumulate_bwd_flop >= per_block_flop:
with gm.graph.inserting_after(node):
- block_node = gm.graph.create_node('call_function', block_split)
- setattr(block_node, 'fwd_flop', accumulate_fwd_flop)
- setattr(block_node, 'bwd_flop', accumulate_bwd_flop)
+ block_node = gm.graph.create_node("call_function", block_split)
+ setattr(block_node, "fwd_flop", accumulate_fwd_flop)
+ setattr(block_node, "bwd_flop", accumulate_bwd_flop)
accumulate_fwd_flop = 0
accumulate_bwd_flop = 0
block_nodes.append(block_node)
@@ -47,7 +45,7 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
def remove_blocks(gm: torch.fx.GraphModule):
for node in gm.graph.nodes:
- if (node.op, node.target) == ('call_function', block_split):
+ if (node.op, node.target) == ("call_function", block_split):
gm.graph.erase_node(node)
@@ -55,8 +53,8 @@ def get_compute_costs(node_list):
num_nodes = len(node_list)
all_compute_cost = np.full((num_nodes, num_nodes), np.inf, dtype=np.float64)
- for start in tqdm.tqdm(range(num_nodes), desc='start pos', position=0):
- for end in tqdm.tqdm(range(start, num_nodes), desc='end pos', position=1, leave=False):
+ for start in tqdm.tqdm(range(num_nodes), desc="start pos", position=0):
+ for end in tqdm.tqdm(range(start, num_nodes), desc="end pos", position=1, leave=False):
selected_flops = [(node_list[i].fwd_flop + node_list[i].bwd_flop) for i in range(start, end + 1)]
all_compute_cost[start, end] = sum(selected_flops)
@@ -78,12 +76,14 @@ def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_cost
# record start node index for next stage in this partition
f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32)
f[0, num_nodes] = 0
- for s in tqdm.tqdm(range(1, num_stages + 1), desc='stage', position=2, leave=False): # pylint: disable=too-many-nested-blocks
- for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc='start node', position=3, leave=False):
- for k in tqdm.tqdm(range(num_nodes, i, -1), desc='mid node', position=4, leave=False):
+ for s in tqdm.tqdm(
+ range(1, num_stages + 1), desc="stage", position=2, leave=False
+ ): # pylint: disable=too-many-nested-blocks
+ for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc="start node", position=3, leave=False):
+ for k in tqdm.tqdm(range(num_nodes, i, -1), desc="mid node", position=4, leave=False):
stage_cost = compute_costs[i, k - 1]
new_cost = f[s - 1, k] + stage_cost
- if (stage_cost <= max_compute_cost and new_cost < f[s, i]):
+ if stage_cost <= max_compute_cost and new_cost < f[s, i]:
f[s, i] = new_cost
f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost)
f_argmin[s, i] = k
@@ -113,7 +113,7 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
best_cost = np.inf
best_solution = None
last_max_compute_cost = 0.0
- gap = 1e6 # temporary magic number, unit: flops
+ gap = 1e6 # temporary magic number, unit: flops
for max_compute_cost in tqdm.tqdm(max_compute_costs):
# Pruning to reduce search space.
@@ -122,8 +122,9 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
if max_compute_cost - last_max_compute_cost < gap:
continue
- cost, solution = do_dp_split_gpipe_impl(len(node_list), num_stages, num_microbatches, compute_costs,
- max_compute_cost)
+ cost, solution = do_dp_split_gpipe_impl(
+ len(node_list), num_stages, num_microbatches, compute_costs, max_compute_cost
+ )
if cost < best_cost:
best_cost = cost
@@ -137,15 +138,15 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche
# split_mode:
# 'node': fx_node
# 'block': a block constructed from many fx_nodes
-def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode='block', block_limit=0.01):
- assert mode in ['node', 'block']
+def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode="block", block_limit=0.01):
+ assert mode in ["node", "block"]
    # nodes or blocks will be used in partitioning.
node_list = []
- if mode == 'node':
+ if mode == "node":
for node in gm.graph.nodes:
node_list.append(node)
- elif mode == 'block':
+ elif mode == "block":
node_list = construct_blocks(gm, limit=block_limit)
else:
pass
@@ -154,16 +155,16 @@ def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches
best_cost, best_solution = do_dp_split_gpipe(node_list, compute_costs, pp_size, num_microbatches)
- for (_, next_start_node) in best_solution:
+ for _, next_start_node in best_solution:
if pp_size <= 1:
break
node = node_list[next_start_node]
with gm.graph.inserting_before(node):
- split_node = gm.graph.create_node('call_function', pipe_split)
+ split_node = gm.graph.create_node("call_function", pipe_split)
pp_size -= 1
# remove block node if possible
- if mode == 'block':
+ if mode == "block":
remove_blocks(gm)
gm.recompile()
@@ -178,7 +179,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
    # To use avgcompute_split_pass, we need to run the meta_info_prop interpreter first.
    # If nodes don't have meta info, this pass will fall back to the normal balanced split pass.
check_node = list(mod_graph.nodes)[0]
- if 'tensor_meta' not in check_node.meta:
+ if "tensor_meta" not in check_node.meta:
return balanced_split_pass(gm, pp_size)
total_fwd_flop = 0
@@ -190,7 +191,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
for node in mod_graph.nodes:
if pp_size <= 1:
break
- if 'pipe_split' in node.name:
+ if "pipe_split" in node.name:
continue
accumulate_fwd_flop += node.fwd_flop
if accumulate_fwd_flop >= partition_flop:
@@ -199,14 +200,14 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
partition_flop = total_fwd_flop // pp_size
with mod_graph.inserting_after(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
"""
- In avgnode_split_pass, simpliy split graph by node number.
+ In avgnode_split_pass, simply split graph by node number.
"""
mod_graph = gm.graph
avg_num_node = len(mod_graph.nodes) // pp_size
@@ -218,12 +219,12 @@ def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
if accumulate_num_node >= avg_num_node:
accumulate_num_node = 0
pp_size -= 1
- if node.next.op == 'output':
+ if node.next.op == "output":
with mod_graph.inserting_before(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
else:
with mod_graph.inserting_after(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
@@ -250,18 +251,18 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
        # If the next node is the output node, we will insert the split annotation before
        # this node to make sure there is at least one node in the last partition.
- if node.next.op == 'output':
+ if node.next.op == "output":
with mod_graph.inserting_before(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
else:
with mod_graph.inserting_after(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
if pp_size > 1:
node_counter = 0
for node in mod_graph.nodes:
if pp_size <= 1:
break
- if node.op == 'placeholder':
+ if node.op == "placeholder":
continue
elif node_counter == 0:
node_counter += 1
@@ -269,7 +270,7 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
node_counter = 0
with mod_graph.inserting_before(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
@@ -283,7 +284,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
    # To use balanced_split_pass_v2, we need to run the meta_info_prop interpreter first.
    # If nodes don't have meta info, this pass will fall back to the normal balanced split pass.
check_node = list(mod_graph.nodes)[0]
- if 'tensor_meta' not in check_node.meta:
+ if "tensor_meta" not in check_node.meta:
return balanced_split_pass(gm, pp_size)
total_element_size = 0
@@ -295,7 +296,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
for node in mod_graph.nodes:
if pp_size <= 1:
break
- if 'pipe_split' in node.name:
+ if "pipe_split" in node.name:
continue
accumulate_node_size += node.node_size
if accumulate_node_size >= partition_size:
@@ -304,7 +305,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
pp_size -= 1
partition_size = total_element_size // pp_size
with mod_graph.inserting_after(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
@@ -333,7 +334,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
accumulate_layer_amount = 0
pp_size -= 1
with mod_graph.inserting_after(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
@@ -346,7 +347,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output
def split_callback(n: torch.fx.Node):
nonlocal part_idx
- if (n.op, n.target) == ('call_function', pipe_split):
+ if (n.op, n.target) == ("call_function", pipe_split):
part_idx += 1
return part_idx
@@ -355,7 +356,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output
for name, submodule in split_mod.named_modules():
if isinstance(submodule, torch.fx.GraphModule):
for node in submodule.graph.nodes:
- if (node.op, node.target) == ('call_function', pipe_split):
+ if (node.op, node.target) == ("call_function", pipe_split):
submodule.graph.erase_node(node)
submodule.recompile()
split_submodules.append(submodule)
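
A hedged end-to-end sketch of the passes in this file: annotate a traced module with `pipe_split` markers, then cut it into per-stage submodules. The toy model is hypothetical, and the exact return shape of `split_with_split_nodes_pass` is assumed from the `split_submodules` list it builds above.

import torch.nn as nn
from torch.fx import symbolic_trace

from colossalai.fx.passes.adding_split_node_pass import (
    avgnode_split_pass,
    split_with_split_nodes_pass,
)

model = nn.Sequential(*[nn.Linear(16, 16) for _ in range(8)])
gm = symbolic_trace(model)

# Insert pipe_split call_function nodes so each of the 4 stages holds
# roughly the same number of fx nodes (no meta info needed for this pass).
annotated = avgnode_split_pass(gm, pp_size=4)

# Split at the markers; the pipe_split nodes themselves are erased from
# every resulting submodule before it is recompiled.
split_gm, split_submodules = split_with_split_nodes_pass(annotated)
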
diff --git a/colossalai/fx/passes/concrete_info_prop.py b/colossalai/fx/passes/concrete_info_prop.py
index 81ac6420552815a5cea2d3fe5aef175efb36ece0..5440a4eadbbfbd336668e63a773c652323bac12e 100644
--- a/colossalai/fx/passes/concrete_info_prop.py
+++ b/colossalai/fx/passes/concrete_info_prop.py
@@ -1,5 +1,5 @@
from dataclasses import asdict
-from typing import Any, Dict, List, NamedTuple, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.fx
@@ -85,10 +85,10 @@ class ConcreteInfoProp(torch.fx.Interpreter):
self._is_proped = True
result, meta_info = super().run_node(n)
- n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
+ n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
- setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
- n.meta['type'] = type(result)
+ setattr(n, "node_size", n.meta.get("fwd_mem_tmp", 0) + n.meta.get("fwd_mem_out", 0))
+ n.meta["type"] = type(result)
# retain the autograd graph
for param in self.module.parameters():
@@ -98,7 +98,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
# Main Node running APIs
@compatibility(is_backward_compatible=True)
- def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``placeholder`` node. Note that this is stateful:
``Interpreter`` maintains an internal iterator over
@@ -119,7 +119,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return super().placeholder(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
- def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``get_attr`` node. Will retrieve an attribute
value from the ``Module`` hierarchy of ``self.module``.
@@ -138,7 +138,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return super().get_attr(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
- def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
@@ -157,7 +157,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return profile_function(target, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
- def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
@@ -175,7 +175,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return profile_method(target, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
- def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
@@ -197,7 +197,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
return profile_module(submod, self.device)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
- def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute an ``output`` node. This really just retrieves
the value referenced by the ``output`` node and returns it.
@@ -228,7 +228,7 @@ class ConcreteInfoProp(torch.fx.Interpreter):
"""
return self.run(*args)
- def summary(self, unit: str = 'MB') -> str:
+ def summary(self, unit: str = "MB") -> str:
"""
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
@@ -238,9 +238,11 @@ class ConcreteInfoProp(torch.fx.Interpreter):
try:
from tabulate import tabulate
except ImportError:
- print("`summary` relies on the library `tabulate`, "
- "which could not be found on this machine. Run `pip "
- "install tabulate` to install the library.")
+ print(
+ "`summary` relies on the library `tabulate`, "
+ "which could not be found on this machine. Run `pip "
+ "install tabulate` to install the library."
+ )
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
@@ -249,10 +251,10 @@ class ConcreteInfoProp(torch.fx.Interpreter):
def mem_repr(mem: int) -> str:
unit_divisor_map = {
- 'kb': 1024,
- 'mb': 1024**2,
- 'gb': 1024**3,
- 'tb': 1024**4,
+ "kb": 1024,
+ "mb": 1024**2,
+ "gb": 1024**3,
+ "tb": 1024**4,
}
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
@@ -261,30 +263,32 @@ class ConcreteInfoProp(torch.fx.Interpreter):
for node in self.module.graph.nodes:
node: Node
- node_summaries.append([
- node.op,
- str(node),
- time_repr(node.meta['fwd_time']),
- time_repr(node.meta['bwd_time']),
- node.meta['save_fwd_in'],
- mem_repr(node.meta['fwd_mem_out']),
- mem_repr(node.meta['fwd_mem_tmp']),
- mem_repr(node.meta['bwd_mem_out']),
- mem_repr(node.meta['bwd_mem_tmp']),
- ])
+ node_summaries.append(
+ [
+ node.op,
+ str(node),
+ time_repr(node.meta["fwd_time"]),
+ time_repr(node.meta["bwd_time"]),
+ node.meta["save_fwd_in"],
+ mem_repr(node.meta["fwd_mem_out"]),
+ mem_repr(node.meta["fwd_mem_tmp"]),
+ mem_repr(node.meta["bwd_mem_out"]),
+ mem_repr(node.meta["bwd_mem_tmp"]),
+ ]
+ )
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
- 'Op type',
- 'Op',
- 'Forward time',
- 'Backward time',
- 'SAVE_FWD_IN',
- 'FWD_OUT',
- 'FWD_TMP',
- 'BWD_OUT',
- 'BWD_TMP',
+ "Op type",
+ "Op",
+ "Forward time",
+ "Backward time",
+ "SAVE_FWD_IN",
+ "FWD_OUT",
+ "FWD_TMP",
+ "BWD_OUT",
+ "BWD_TMP",
]
- return tabulate(node_summaries, headers=headers, stralign='right')
+ return tabulate(node_summaries, headers=headers, stralign="right")
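
A hedged usage sketch of `ConcreteInfoProp`: run a module once with concrete tensors so each node's `fwd_time`/`bwd_time` and memory fields are recorded, then render the table. The constructor call is assumed from the `torch.fx.Interpreter` base class; `summary` needs `tabulate` installed, as the fallback message above notes.

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

from colossalai.fx.passes.concrete_info_prop import ConcreteInfoProp

gm = symbolic_trace(nn.Sequential(nn.Linear(32, 32), nn.ReLU()))
interp = ConcreteInfoProp(gm)      # assumed: a GraphModule suffices here
interp.run(torch.randn(4, 32))     # propagates and fills n.meta per node
print(interp.summary(unit="MB"))   # requires `pip install tabulate`
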
diff --git a/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py b/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py
index f28d65e2668ac39e7b189c7d181d018468648614..3d032a27db638b4aea1fff17afdb5a7117a30431 100644
--- a/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py
+++ b/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py
@@ -1,14 +1,11 @@
-import torch
-from typing import List
-from torch.fx import symbolic_trace
-from torch.fx.node import Node
-from colossalai.fx.passes.split_module import split_module
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
import builtins
import operator
-from copy import deepcopy
+from typing import List
+
+import torch
+
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec
def apply(*args, **kwargs):
@@ -16,7 +13,7 @@ def apply(*args, **kwargs):
return shape_consistency_manager.apply(*args, **kwargs)
-def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh):
+def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh):
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
@@ -24,16 +21,16 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
origin_node_sharding_spec_dict = {}
for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
strategies_vector = node.strategies_vector
- setattr(node, 'best_strategy', strategies_vector[strategy_index])
- setattr(node, 'sharding_spec', strategies_vector[strategy_index].output_sharding_spec)
+ setattr(node, "best_strategy", strategies_vector[strategy_index])
+ setattr(node, "sharding_spec", strategies_vector[strategy_index].output_sharding_spec)
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].output_sharding_spec
    # apply the sharding specs to the parameters
for node in nodes:
- if node.op == 'call_module':
+ if node.op == "call_module":
target_module = node.graph.owning_module.get_submodule(node.target)
origin_sharding_spec = ShardingSpec(device_mesh, target_module.weight.shape, {})
- setattr(target_module.weight, 'sharding_spec', origin_sharding_spec)
+ setattr(target_module.weight, "sharding_spec", origin_sharding_spec)
target_weight_sharding_spec = node.best_strategy.input_shardings[1]
target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3))
apply(target_module.weight, target_weight_sharding_spec)
@@ -51,10 +48,10 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
    # add the dicts above into the graph
for node in nodes:
- if node.op != 'placeholder':
+ if node.op != "placeholder":
with mod_graph.inserting_before(node):
- input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
- origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
+ input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict")
+ origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict")
break
return sharding_spec_convert_dict, origin_node_sharding_spec_dict
@@ -70,13 +67,13 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
node_to_index_dict = {}
index = 0
for node in nodes:
- if node.target == 'sharding_spec_convert_dict':
+ if node.target == "sharding_spec_convert_dict":
input_dict_node = node
continue
- if node.target == 'origin_node_sharding_spec_dict':
+ if node.target == "origin_node_sharding_spec_dict":
origin_dict_node = node
continue
- if not hasattr(node, 'best_strategy'):
+ if not hasattr(node, "best_strategy"):
continue
node_to_index_dict[node] = index
index += 1
@@ -84,28 +81,28 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
# add shape consistency apply function into graph
for node in nodes:
- if not hasattr(node, 'best_strategy'):
+ if not hasattr(node, "best_strategy"):
continue
with mod_graph.inserting_after(node):
- origin_spec_node = mod_graph.create_node('call_function',
- operator.getitem,
- args=(origin_dict_node, node_to_index_dict[node]))
+ origin_spec_node = mod_graph.create_node(
+ "call_function", operator.getitem, args=(origin_dict_node, node_to_index_dict[node])
+ )
with mod_graph.inserting_after(origin_spec_node):
- set_sharding_spec_node = mod_graph.create_node('call_function',
- builtins.setattr,
- args=(node, 'sharding_spec', origin_spec_node))
+ set_sharding_spec_node = mod_graph.create_node(
+ "call_function", builtins.setattr, args=(node, "sharding_spec", origin_spec_node)
+ )
for user_node in node.strategies_vector.successor_nodes:
node_index = user_node.strategies_vector.predecessor_nodes.index(node)
with mod_graph.inserting_before(user_node):
- input_specs_node = mod_graph.create_node('call_function',
- operator.getitem,
- args=(input_dict_node, node_to_index_dict[node]))
+ input_specs_node = mod_graph.create_node(
+ "call_function", operator.getitem, args=(input_dict_node, node_to_index_dict[node])
+ )
with mod_graph.inserting_before(user_node):
- sharding_spec_node = mod_graph.create_node('call_function',
- operator.getitem,
- args=(input_specs_node, node_index))
+ sharding_spec_node = mod_graph.create_node(
+ "call_function", operator.getitem, args=(input_specs_node, node_index)
+ )
with mod_graph.inserting_before(user_node):
- shape_consistency_node = mod_graph.create_node('call_function', apply, args=(node, sharding_spec_node))
+ shape_consistency_node = mod_graph.create_node("call_function", apply, args=(node, sharding_spec_node))
return gm
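
A hedged sketch of chaining the two passes above: `solution_annotation_pass` attaches the solver-chosen strategy and sharding specs to every node and creates the two spec-dict placeholders, after which `shape_consistency_pass` inserts the `apply` calls. The `gm`, `solution`, and `device_mesh` inputs are assumed to come from an auto-parallel solver run that is not shown here.

from colossalai.fx.passes.experimental.adding_shape_consistency_pass import (
    shape_consistency_pass,
    solution_annotation_pass,
)

# gm, solution, device_mesh: produced by the strategy solver (omitted);
# solution[i] is the index of the strategy chosen for the i-th node.
convert_dict, origin_dict = solution_annotation_pass(gm, solution, device_mesh)
gm = shape_consistency_pass(gm)
# At runtime the two dicts are fed in through the placeholders created by
# the first pass: 'sharding_spec_convert_dict' and
# 'origin_node_sharding_spec_dict'.
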
diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py
index 2b4a8749cfd776e3ac22d75fc2e47c2475c521d6..1720aa58da2b0f2a36befa9c4270030194fb0c32 100644
--- a/colossalai/fx/passes/meta_info_prop.py
+++ b/colossalai/fx/passes/meta_info_prop.py
@@ -31,7 +31,7 @@ class TensorMetadata(NamedTuple):
numel: int
is_tensor: bool
# TODO: we can add a list of sharding spec here, and record the sharding
- # behaviour by appending sharding spec into list.
+ # behavior by appending sharding spec into list.
def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
@@ -109,13 +109,13 @@ class MetaInfoProp(torch.fx.Interpreter):
return TensorMetadata(None, None, False, None, 0, False)
tensor_meta = tree_map(extract_tensor_meta, result)
- n.meta['tensor_meta'] = tensor_meta
- n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
+ n.meta["tensor_meta"] = tensor_meta
+ n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
- setattr(n, 'node_size', activation_size(n.meta.get('fwd_out', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
- setattr(n, 'fwd_flop', n.meta.get('fwd_flop', 0))
- setattr(n, 'bwd_flop', n.meta.get('bwd_flop', 0))
- n.meta['type'] = type(result)
+ setattr(n, "node_size", activation_size(n.meta.get("fwd_out", 0)) + activation_size(n.meta.get("fwd_tmp", 0)))
+ setattr(n, "fwd_flop", n.meta.get("fwd_flop", 0))
+ setattr(n, "bwd_flop", n.meta.get("bwd_flop", 0))
+ n.meta["type"] = type(result)
# retain the autograd graph
for param in self.module.parameters():
@@ -125,7 +125,7 @@ class MetaInfoProp(torch.fx.Interpreter):
# Main Node running APIs
@compatibility(is_backward_compatible=True)
- def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``placeholder`` node. Note that this is stateful:
``Interpreter`` maintains an internal iterator over
@@ -146,7 +146,7 @@ class MetaInfoProp(torch.fx.Interpreter):
return super().placeholder(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
- def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``get_attr`` node. Will retrieve an attribute
value from the ``Module`` hierarchy of ``self.module``.
@@ -165,7 +165,7 @@ class MetaInfoProp(torch.fx.Interpreter):
return super().get_attr(target, args, kwargs), GraphInfo()
@compatibility(is_backward_compatible=True)
- def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
@@ -184,7 +184,7 @@ class MetaInfoProp(torch.fx.Interpreter):
return profile_function(target)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
- def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
@@ -202,7 +202,7 @@ class MetaInfoProp(torch.fx.Interpreter):
return profile_method(target)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
- def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
@@ -224,7 +224,7 @@ class MetaInfoProp(torch.fx.Interpreter):
return profile_module(submod)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
- def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute an ``output`` node. This really just retrieves
the value referenced by the ``output`` node and returns it.
@@ -240,7 +240,7 @@ class MetaInfoProp(torch.fx.Interpreter):
result (Any): The argument value that was retrieved
meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`.
"""
- if hasattr(args[0], '_tensor'):
+ if hasattr(args[0], "_tensor"):
return args[0], GraphInfo(fwd_in=[args[0]._tensor])
return args[0], GraphInfo(save_fwd_in=True)
@@ -257,7 +257,7 @@ class MetaInfoProp(torch.fx.Interpreter):
"""
return super().run(*args)
- def summary(self, unit: str = 'MB') -> str:
+ def summary(self, unit: str = "MB") -> str:
"""
Summarizes the memory and FLOPs statistics of the `GraphModule` in
tabular format. Note that this API requires the ``tabulate`` module
@@ -267,9 +267,11 @@ class MetaInfoProp(torch.fx.Interpreter):
try:
from tabulate import tabulate
except ImportError:
- print("`summary` relies on the library `tabulate`, "
- "which could not be found on this machine. Run `pip "
- "install tabulate` to install the library.")
+ print(
+ "`summary` relies on the library `tabulate`, "
+ "which could not be found on this machine. Run `pip "
+ "install tabulate` to install the library."
+ )
assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`."
@@ -278,10 +280,10 @@ class MetaInfoProp(torch.fx.Interpreter):
def mem_repr(mem: int) -> str:
unit_divisor_map = {
- 'kb': 1024,
- 'mb': 1024**2,
- 'gb': 1024**3,
- 'tb': 1024**4,
+ "kb": 1024,
+ "mb": 1024**2,
+ "gb": 1024**3,
+ "tb": 1024**4,
}
return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}"
@@ -292,35 +294,37 @@ class MetaInfoProp(torch.fx.Interpreter):
for node in self.module.graph.nodes:
node: Node
accumulate_size += calculate_fwd_out(node) + calculate_fwd_tmp(node)
- node_summaries.append([
- node.op,
- str(node),
- flops_repr(node.meta['fwd_flop']),
- flops_repr(node.meta['bwd_flop']),
- mem_repr(accumulate_size),
- mem_repr(calculate_fwd_in(node)),
- mem_repr(calculate_fwd_out(node)),
- mem_repr(calculate_fwd_tmp(node)),
- mem_repr(node.meta['bwd_mem_out']),
- mem_repr(node.meta['bwd_mem_tmp']),
- ])
+ node_summaries.append(
+ [
+ node.op,
+ str(node),
+ flops_repr(node.meta["fwd_flop"]),
+ flops_repr(node.meta["bwd_flop"]),
+ mem_repr(accumulate_size),
+ mem_repr(calculate_fwd_in(node)),
+ mem_repr(calculate_fwd_out(node)),
+ mem_repr(calculate_fwd_tmp(node)),
+ mem_repr(node.meta["bwd_mem_out"]),
+ mem_repr(node.meta["bwd_mem_tmp"]),
+ ]
+ )
# Use the ``tabulate`` library to create a well-formatted table
# presenting our summary information
headers: List[str] = [
- 'Op type',
- 'Op',
- 'Forward FLOPs',
- 'Backward FLOPs',
- 'Accumulated Memory',
- 'FWD_IN',
- 'FWD_OUT',
- 'FWD_TMP',
- 'BWD_OUT',
- 'BWD_TMP',
+ "Op type",
+ "Op",
+ "Forward FLOPs",
+ "Backward FLOPs",
+ "Accumulated Memory",
+ "FWD_IN",
+ "FWD_OUT",
+ "FWD_TMP",
+ "BWD_OUT",
+ "BWD_TMP",
]
- return tabulate(node_summaries, headers=headers, stralign='right')
+ return tabulate(node_summaries, headers=headers, stralign="right")
def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> torch.fx.GraphModule:
@@ -344,15 +348,16 @@ def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit:
Returns:
torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo.
"""
- device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+ device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
interp = MetaInfoProp(gm.to(device))
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
+
args = tree_map(lambda x: MetaTensor(x, fake_device=device), args)
kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs)
interp.propagate(*args, **kwargs)
if verbose:
interp.summary(unit)
- gm.to('cpu')
+ gm.to("cpu")
del interp
return gm
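
A hedged usage sketch of `metainfo_trace` as defined above: inputs are wrapped into `MetaTensor`s when the installed torch version is meta-compatible, so the propagation needs no real forward compute. The toy model is hypothetical.

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

from colossalai.fx.passes.meta_info_prop import metainfo_trace

gm = symbolic_trace(nn.Sequential(nn.Linear(64, 64), nn.GELU()))
# verbose=True prints the per-node FLOPs/memory table via summary().
gm = metainfo_trace(gm, torch.randn(8, 64), verbose=True, unit="MB")
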
diff --git a/colossalai/fx/passes/passes_for_gpt2_test.py b/colossalai/fx/passes/passes_for_gpt2_test.py
index abc1a089e9a90ebaff0681879aaa68d488edb624..73379f73689c4da70feb3650a553678e57156945 100644
--- a/colossalai/fx/passes/passes_for_gpt2_test.py
+++ b/colossalai/fx/passes/passes_for_gpt2_test.py
@@ -5,7 +5,6 @@ import torch
from packaging import version
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
-from torch.fx.node import Node
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, pipe_split
from colossalai.fx.passes.meta_info_prop import TensorMetadata
@@ -13,9 +12,9 @@ from colossalai.fx.passes.split_module import Partition
def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]):
- '''
+ """
This pass is only used for the GPT-2 performance test; it may be moved into adding_split_node_pass.py and will be deprecated in the future.
- '''
+ """
mod_graph = gm.graph
valid_children_size = 0
valid_children = []
@@ -39,40 +38,40 @@ def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, parti
part_index += 1
pp_size -= 1
with mod_graph.inserting_after(node):
- split_node = mod_graph.create_node('call_function', pipe_split)
+ split_node = mod_graph.create_node("call_function", pipe_split)
gm.recompile()
return gm
def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule):
- '''
+ """
This pass will be used in the GPT-2 test; only part of these changes may be added into
split_with_split_nodes_pass, and it will be deprecated in the future.
- '''
+ """
part_idx = 0
def eliminate_unused_placeholders(gm):
for node in gm.graph.nodes:
- if node.op == 'placeholder':
+ if node.op == "placeholder":
if not len(node.users):
gm.graph.erase_node(node)
gm.recompile()
return gm
def refill_outputs_and_placeholders(gm, next_partition_placeholders):
- '''
+ """
This method is used to eliminate the outputs of the previous partition which are unused in the next partition.
The split module pass treats partitions as a DAG, but in pipeline parallel we need to treat them as a singly linked list.
The difference: if an output of partition 0 is an input argument of partition 3, the DAG will not transfer it
through partition 1 and partition 2. In a singly linked list, however, we need to do so.
- '''
+ """
output_type = None
output_args = []
non_output_list = []
new_placeholder_list = []
for node in gm.graph.nodes:
- if node.op == 'output':
+ if node.op == "output":
if isinstance(node.args[0], (tuple, list)):
output_type = node.args[0].__class__
output_args.extend([n.name for n in node.args[0]])
@@ -114,7 +113,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
continue
for node in gm.graph.nodes:
- if node.op == 'placeholder':
+ if node.op == "placeholder":
new_placeholder_list.append(node.name)
if output_type is not None:
gm.graph.output(output_type(output_args))
@@ -125,7 +124,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
def split_callback(n: torch.fx.Node):
nonlocal part_idx
- if (n.op, n.target) == ('call_function', pipe_split):
+ if (n.op, n.target) == ("call_function", pipe_split):
part_idx += 1
return part_idx
@@ -134,7 +133,7 @@ def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule)
for name, submodule in split_mod.named_modules():
if isinstance(submodule, torch.fx.GraphModule):
for node in submodule.graph.nodes:
- if (node.op, node.target) == ('call_function', pipe_split):
+ if (node.op, node.target) == ("call_function", pipe_split):
submodule.graph.erase_node(node)
submodule.recompile()
split_submodules.append(submodule)
@@ -200,13 +199,12 @@ def split_module_for_gpt2_test(
_gen_all_ancestors_set(node)
for n in list(all_ancestors):
- if n.op != 'placeholder' and n._fx_partition > partition_name:
+ if n.op != "placeholder" and n._fx_partition > partition_name:
n._fx_partition = partition_name
- def record_cross_partition_use(def_node: torch.fx.node.Node,
- use_node: Optional[torch.fx.node.Node]): # noqa: B950
- def_partition_name = getattr(def_node, '_fx_partition', None)
- use_partition_name = getattr(use_node, '_fx_partition', None)
+ def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
+ def_partition_name = getattr(def_node, "_fx_partition", None)
+ use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name:
# if 'tensor_meta' in def_node.meta:
# if not _node_with_all_tensor_element(def_node.meta['tensor_meta']):
@@ -230,14 +228,14 @@ def split_module_for_gpt2_test(
use_partition.partitions_dependent_on.setdefault(def_partition_name)
node_process_list = list(m.graph.nodes)
- # split nodes into parititons
+ # split nodes into partitions
while node_process_list:
node = node_process_list.pop(0)
orig_nodes[node.name] = node
if node.op in ["placeholder"]:
continue
- if node.op == 'output':
+ if node.op == "output":
# partition_name = str(split_callback(node))
# def _set_output_args_partition(n, partition_name):
# n._fx_partition = partition_name
@@ -252,12 +250,12 @@ def split_module_for_gpt2_test(
partitions[partition_name] = partition = Partition(partition_name)
partition.node_names.append(node.name)
- origin_partition_name = getattr(node, '_fx_partition', None)
+ origin_partition_name = getattr(node, "_fx_partition", None)
if origin_partition_name is None:
node._fx_partition = partition_name
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
- torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
+ torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
# find partitions with no dependencies
root_partitions: List[str] = []
@@ -277,7 +275,7 @@ def split_module_for_gpt2_test(
if len(sorted_partitions) != len(partitions):
raise RuntimeError("cycle exists between partitions!")
- # add placeholders to parititons
+ # add placeholders to partitions
for partition_name in sorted_partitions:
partition = partitions[partition_name]
for input in partition.inputs:
@@ -287,7 +285,7 @@ def split_module_for_gpt2_test(
# Transform nodes and collect targets for partition's submodule
for node in m.graph.nodes:
- if hasattr(node, '_fx_partition'):
+ if hasattr(node, "_fx_partition"):
partition = partitions[node._fx_partition]
# swap out old graph nodes in kw/args with references to new nodes in this submodule
@@ -295,26 +293,24 @@ def split_module_for_gpt2_test(
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])
- if node.op not in ['call_module', 'get_attr']:
+ if node.op not in ["call_module", "get_attr"]:
target = node.target
else:
- target_atoms = node.target.split('.')
+ target_atoms = node.target.split(".")
target_attr = m
for atom in target_atoms:
if not hasattr(target_attr, atom):
- raise RuntimeError(f'Operator target {node.target} not found!')
+ raise RuntimeError(f"Operator target {node.target} not found!")
target_attr = getattr(target_attr, atom)
# target = target_atoms[-1]
- target = '_'.join(target_atoms)
+ target = "_".join(target_atoms)
partition.targets[target] = target_attr
assert isinstance(gathered_args, tuple)
assert isinstance(gathered_kwargs, dict)
- new_node = partition.graph.create_node(op=node.op,
- target=target,
- args=gathered_args,
- kwargs=gathered_kwargs,
- name=node.name)
+ new_node = partition.graph.create_node(
+ op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs, name=node.name
+ )
new_node.meta = node.meta.copy()
partition.environment[node] = new_node
@@ -323,14 +319,14 @@ def split_module_for_gpt2_test(
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes:
- if node.op == 'placeholder':
- if version.parse(torch.__version__) < version.parse('1.11.0'):
+ if node.op == "placeholder":
+ if version.parse(torch.__version__) < version.parse("1.11.0"):
base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
- base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
- type_expr=node.type,
- default_value=default_value)
+ base_mod_env[node.name] = base_mod_graph.placeholder(
+ node.name, type_expr=node.type, default_value=default_value
+ )
base_mod_env[node.name].meta = node.meta.copy()
# Do some things iterating over the partitions in topological order again:
@@ -344,13 +340,14 @@ def split_module_for_gpt2_test(
# Set correct output values
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
- output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
+ output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
partition.graph.output(output_vals)
# Construct GraphModule for this partition
- submod_name = f'submod_{partition_name}'
- base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets,
- partition.graph) # noqa: B950
+ submod_name = f"submod_{partition_name}"
+ base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(
+ partition.targets, partition.graph
+ ) # noqa: B950
# Emit call in base graph to this submodule
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
@@ -358,14 +355,14 @@ def split_module_for_gpt2_test(
# Unpack multiple return values from submodule
output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs):
- base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
+ base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
else:
if not partition.outputs:
continue
base_mod_env[list(partition.outputs)[0]] = output_val
for node in m.graph.nodes:
- if node.op == 'output':
- base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
+ if node.op == "output":
+ base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
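
To make the DAG-versus-linked-list point in the `refill_outputs_and_placeholders` docstring concrete, here is a minimal illustration with hypothetical stage names: a value produced by stage 0 but consumed only by stage 3 travels directly along a DAG edge, whereas in the linear pipeline every intermediate stage must re-output it.

# DAG view: the value "x" moves straight from stage 0 to stage 3.
dag_transfer = {("stage0", "stage3"): ["x"]}

# Pipeline view: stages 1 and 2 must accept and re-emit "x" so it can
# travel hop by hop down the singly linked list of partitions.
pipeline_io = {
    "stage1": {"inputs": ["x", "h0"], "outputs": ["x", "h1"]},
    "stage2": {"inputs": ["x", "h1"], "outputs": ["x", "h2"]},
    "stage3": {"inputs": ["x", "h2"], "outputs": ["logits"]},
}
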
diff --git a/colossalai/fx/passes/shard_1d_pass.py b/colossalai/fx/passes/shard_1d_pass.py
index d2bad06bb45a1393543039ef55dea6c8d1e9f50b..be8261f2a3f4fd408824ad4003d455217d456960 100644
--- a/colossalai/fx/passes/shard_1d_pass.py
+++ b/colossalai/fx/passes/shard_1d_pass.py
@@ -1,19 +1,32 @@
+import operator
+
import torch
import torch.nn as nn
-import operator
-from colossalai.tensor import ProcessGroup
-from colossalai.tensor.distspec import ShardSpec
-from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec
+
+from colossalai.legacy.tensor import ProcessGroup
+from colossalai.legacy.tensor.compute_spec import ComputePattern, ComputeSpec
+from colossalai.legacy.tensor.distspec import ShardSpec
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
- torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
- operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
+ torch.add,
+ operator.add,
+ torch.abs,
+ torch.cos,
+ torch.exp,
+ torch.mul,
+ operator.mul,
+ operator.floordiv,
+ operator.truediv,
+ operator.neg,
+ torch.multiply,
+ torch.nn.functional.relu,
+ torch.nn.functional.dropout,
]
def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter:
- """weight_split
+ """weight_split
split an nn.Parameter
Args:
@@ -60,9 +73,9 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule):
def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: ProcessGroup):
"""
- This IR pass checks for transformer MLP like structure and annotate column and row sharding to the linear layers.
+    This IR pass checks for a transformer-MLP-like structure and annotates column and row sharding on the linear layers.
"""
- #TODO: Needs to handle special cases, like x = linear(x) + linear(x)
+ # TODO: Needs to handle special cases, like x = linear(x) + linear(x)
graph = graph_module.graph
world_size = process_group.world_size()
@@ -70,7 +83,7 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
# traverse the graph to look for consecutive linear layers
is_linear_module = False
- if node.op == 'call_module':
+ if node.op == "call_module":
# look for the linear layer
module = node.graph.owning_module.get_submodule(node.target)
if isinstance(module, nn.Linear):
@@ -80,31 +93,31 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
# it means the first linear has been found and the current module
# is the second linear
# set the current linear module to be row-sharded
- annotation_record['row'] = module
+ annotation_record["row"] = module
for shard_type, module in annotation_record.items():
# add row sharding spec
- if shard_type == 'row':
+ if shard_type == "row":
dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size])
comp_spec = ComputeSpec(ComputePattern.TP1D)
- setattr(module.weight, 'pg', process_group)
- setattr(module.weight, 'dist_spec', dist_spec)
- setattr(module.weight, 'comp_spec', comp_spec)
- elif shard_type == 'col':
+ setattr(module.weight, "pg", process_group)
+ setattr(module.weight, "dist_spec", dist_spec)
+ setattr(module.weight, "comp_spec", comp_spec)
+ elif shard_type == "col":
weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
weight_comp_spec = ComputeSpec(ComputePattern.TP1D)
weight_comp_spec.output_replicate = False
- setattr(module.weight, 'pg', process_group)
- setattr(module.weight, 'dist_spec', weight_dist_spec)
- setattr(module.weight, 'comp_spec', weight_comp_spec)
+ setattr(module.weight, "pg", process_group)
+ setattr(module.weight, "dist_spec", weight_dist_spec)
+ setattr(module.weight, "comp_spec", weight_comp_spec)
if module.bias is not None:
bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size])
bias_comp_spec = ComputeSpec(ComputePattern.TP1D)
bias_comp_spec.output_replicate = False
- setattr(module.bias, 'pg', process_group)
- setattr(module.bias, 'dist_spec', bias_dist_spec)
- setattr(module.bias, 'comp_spec', bias_comp_spec)
+ setattr(module.bias, "pg", process_group)
+ setattr(module.bias, "dist_spec", bias_dist_spec)
+ setattr(module.bias, "comp_spec", bias_comp_spec)
start_tracking = False
annotation_record.clear()
else:
@@ -112,16 +125,16 @@ def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: Proc
# it means the current layer is the first linear
# set the linear layer to be col-sharded
start_tracking = True
- annotation_record['col'] = module
+ annotation_record["col"] = module
if start_tracking and not is_linear_module:
            # check against the whitelist
            # if a non-elementwise op is found, we reset the tracking
- if node.op == 'call_module':
+ if node.op == "call_module":
module = node.graph.owning_module.get_submodule(node.target)
if module.__class__ not in ELEMENTWISE_MODULE_OP:
start_tracking = False
- elif node.op == 'call_function' or node.op == 'call_method':
+ elif node.op == "call_function" or node.op == "call_method":
if node.target not in ELEMENTWISE_FUNC_OP:
start_tracking = False
elif len(node.users.keys()) > 1:
diff --git a/colossalai/fx/passes/split_module.py b/colossalai/fx/passes/split_module.py
index 5ce5b969cbdefc0e347eb9f156f5268658e988fc..67a2432595d691197861fb5d72c27ce137dbd808 100644
--- a/colossalai/fx/passes/split_module.py
+++ b/colossalai/fx/passes/split_module.py
@@ -25,12 +25,14 @@ class Partition:
self.targets: Dict[str, Any] = {}
def __repr__(self) -> str:
- return f"name: {self.name},\n" \
- f" nodes: {self.node_names},\n" \
- f" inputs: {self.inputs},\n" \
- f" outputs: {self.outputs},\n" \
- f" partitions depenent on: {self.partitions_dependent_on},\n" \
- f" parition dependents: {self.partition_dependents}"
+ return (
+ f"name: {self.name},\n"
+ f" nodes: {self.node_names},\n"
+ f" inputs: {self.inputs},\n"
+ f" outputs: {self.outputs},\n"
+ f" partitions dependent on: {self.partitions_dependent_on},\n"
+ f" partition dependents: {self.partition_dependents}"
+ )
# Creates subgraphs out of the main graph
@@ -117,10 +119,9 @@ def split_module(
partitions: Dict[str, Partition] = {}
orig_nodes: Dict[str, torch.fx.node.Node] = {}
- def record_cross_partition_use(def_node: torch.fx.node.Node,
- use_node: Optional[torch.fx.node.Node]): # noqa: B950
- def_partition_name = getattr(def_node, '_fx_partition', None)
- use_partition_name = getattr(use_node, '_fx_partition', None)
+ def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
+ def_partition_name = getattr(def_node, "_fx_partition", None)
+ use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name:
if def_partition_name is not None:
def_partition = partitions[def_partition_name]
@@ -134,7 +135,7 @@ def split_module(
if def_partition_name is not None:
use_partition.partitions_dependent_on.setdefault(def_partition_name)
- def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
+ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, "_fx_partition", None)
use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name:
@@ -161,7 +162,7 @@ def split_module(
if node.op in ["placeholder"]:
continue
- if node.op == 'output':
+ if node.op == "output":
if merge_output:
torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev))
else:
@@ -178,7 +179,7 @@ def split_module(
node._fx_partition = partition_name
torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
- torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
+ torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950
# find partitions with no dependencies
root_partitions: List[str] = []
@@ -208,7 +209,7 @@ def split_module(
# Transform nodes and collect targets for partition's submodule
for node in m.graph.nodes:
- if hasattr(node, '_fx_partition'):
+ if hasattr(node, "_fx_partition"):
partition = partitions[node._fx_partition]
# swap out old graph nodes in kw/args with references to new nodes in this submodule
@@ -216,25 +217,24 @@ def split_module(
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n])
- if node.op not in ['call_module', 'get_attr']:
+ if node.op not in ["call_module", "get_attr"]:
target = node.target
else:
- target_atoms = node.target.split('.')
+ target_atoms = node.target.split(".")
target_attr = m
for atom in target_atoms:
if not hasattr(target_attr, atom):
- raise RuntimeError(f'Operator target {node.target} not found!')
+ raise RuntimeError(f"Operator target {node.target} not found!")
target_attr = getattr(target_attr, atom)
# target = target_atoms[-1]
- target = '_'.join(target_atoms)
+ target = "_".join(target_atoms)
partition.targets[target] = target_attr
assert isinstance(gathered_args, tuple)
assert isinstance(gathered_kwargs, dict)
- new_node = partition.graph.create_node(op=node.op,
- target=target,
- args=gathered_args,
- kwargs=gathered_kwargs)
+ new_node = partition.graph.create_node(
+ op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs
+ )
new_node.meta = node.meta.copy()
partition.environment[node] = new_node
@@ -243,14 +243,14 @@ def split_module(
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes:
- if node.op == 'placeholder':
- if version.parse(torch.__version__) < version.parse('1.11.0'):
+ if node.op == "placeholder":
+ if version.parse(torch.__version__) < version.parse("1.11.0"):
base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type)
else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
- base_mod_env[node.name] = base_mod_graph.placeholder(node.target,
- type_expr=node.type,
- default_value=default_value)
+ base_mod_env[node.name] = base_mod_graph.placeholder(
+ node.target, type_expr=node.type, default_value=default_value
+ )
base_mod_env[node.name].meta = node.meta.copy()
# Do some things iterating over the partitions in topological order again:
@@ -264,13 +264,14 @@ def split_module(
# Set correct output values
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
- output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
+ output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
partition.graph.output(output_vals)
# Construct GraphModule for this partition
- submod_name = f'submod_{partition_name}'
- base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets,
- partition.graph) # noqa: B950
+ submod_name = f"submod_{partition_name}"
+ base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(
+ partition.targets, partition.graph
+ ) # noqa: B950
# Emit call in base graph to this submodule
output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
@@ -278,15 +279,15 @@ def split_module(
# Unpack multiple return values from submodule
output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs):
- base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
+ base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
else:
if not partition.outputs:
continue
base_mod_env[list(partition.outputs)[0]] = output_val
for node in m.graph.nodes:
- if node.op == 'output':
- base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
+ if node.op == "output":
+ base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950
for partition_name in sorted_partitions:
partition = partitions[partition_name]
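For reference, a usage sketch of the reformatted `split_module` pass, assuming it keeps the upstream `torch.fx` signature `(module, root_module, split_callback)`; the two-partition callback below is our own:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

from colossalai.fx.passes.split_module import split_module

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 8)

    def forward(self, x):
        return self.fc2(self.fc1(x).relu())

net = Net()
gm = symbolic_trace(net)
order = {node: i for i, node in enumerate(gm.graph.nodes)}
# nodes 0-2 (placeholder, fc1, relu) -> partition 0; the rest -> partition 1
split = split_module(gm, net, lambda n: 0 if order[n] < 3 else 1)

x = torch.randn(2, 8)
assert torch.allclose(split(x), net(x))  # same semantics, now two submodules
```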
diff --git a/colossalai/fx/passes/utils.py b/colossalai/fx/passes/utils.py
index bb4f3cd6a4908177ca13f9d7fb82ff42b5ad1d5e..c51f49a30e8ae4c703ce2f7820f59ad901f4bbb0 100644
--- a/colossalai/fx/passes/utils.py
+++ b/colossalai/fx/passes/utils.py
@@ -1,7 +1,9 @@
-import torch
from typing import Dict
-from torch.fx.node import Node, map_arg
+
+import torch
from torch.fx.graph import Graph
+from torch.fx.node import Node, map_arg
+
def get_comm_size(prev_partition, next_partition):
"""
@@ -23,7 +25,7 @@ def get_comm_size(prev_partition, next_partition):
map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
for n in input_nodes:
if n.name in parent_node_names and n not in visited_nodes:
- comm_size += n.meta['tensor_meta'].numel
+ comm_size += n.meta["tensor_meta"].numel
visited_nodes.add(n)
return comm_size
@@ -36,12 +38,12 @@ def get_leaf(graph: Graph):
"""
input_nodes: Dict[Node, None] = {}
for node in graph.nodes:
- if node.op == 'output':
+ if node.op == "output":
map_arg(node.args, lambda n: input_nodes.setdefault(n))
map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
placeholder_nodes = []
for node in input_nodes.keys():
- if node.op == 'placeholder':
+ if node.op == "placeholder":
placeholder_nodes.append(node)
for node in placeholder_nodes:
input_nodes.pop(node)
@@ -60,13 +62,13 @@ def get_top(graph: Graph):
"""
top_node_list = set()
for node in graph.nodes:
- if node.op == 'output':
+ if node.op == "output":
continue
is_top = False
def _get_top(node):
nonlocal is_top
- if node.op == 'placeholder':
+ if node.op == "placeholder":
is_top = True
map_arg(node.args, lambda n: _get_top(n))
@@ -83,7 +85,7 @@ def is_top(graph: Graph, node: Node):
def get_all_consumers(graph: Graph, node: Node):
"""
Given a graph and a node of this graph, return all consumers of the node.
-
+
Returns:
List of ``Node``s in whose ``args`` or ``kwargs`` the given node appears.
"""
@@ -120,7 +122,7 @@ def assign_bfs_level_to_nodes(graph: Graph):
for node in gm.graph.nodes:
if hasattr(node, 'bfs_level'):
print(node.name, node.bfs_level)
-
+
Output:
graph():
%x : [#users=2] = placeholder[target=x]
@@ -148,7 +150,7 @@ def assign_bfs_level_to_nodes(graph: Graph):
while nodes_to_process:
new_process_list = []
for node in nodes_to_process:
- if node.op == 'output':
+ if node.op == "output":
continue
node.bfs_level = current_level
new_process_list.extend(get_all_consumers(graph, node))
@@ -165,8 +167,9 @@ def get_node_module(node) -> torch.nn.Module:
torch.nn.Module: the module associated with the given node
"""
- assert node.graph.owning_module is not None, 'Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object'
- assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}'
+ assert (
+ node.graph.owning_module is not None
+ ), "Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object"
+ assert node.op == "call_module", f"Expected node.op to be call_module, but found {node.op}"
module = node.graph.owning_module.get_submodule(node.target)
return module
-
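A short usage sketch of the BFS-level helper reformatted above, mirroring the example in its own docstring:

```python
import torch.nn as nn
from torch.fx import symbolic_trace

from colossalai.fx.passes.utils import assign_bfs_level_to_nodes

gm = symbolic_trace(nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2)))
assign_bfs_level_to_nodes(gm.graph)
for node in gm.graph.nodes:
    if hasattr(node, "bfs_level"):
        print(node.name, node.bfs_level)  # levels grow from the graph's top toward the output
```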
diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py
index 8bcbde0eb23b806b7e37e407d57a962d8ff71573..89dd2b3df617e352b324d9dab421bea69a84a6a1 100644
--- a/colossalai/fx/profiler/__init__.py
+++ b/colossalai/fx/profiler/__init__.py
@@ -12,7 +12,16 @@ if is_compatible_with_meta():
)
from .tensor import MetaTensor
else:
- from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
+ from .experimental import (
+ meta_profiler_function,
+ meta_profiler_module,
+ profile_function,
+ profile_method,
+ profile_module,
+ calculate_fwd_in,
+ calculate_fwd_tmp,
+ calculate_fwd_out,
+ )
from .dataflow import GraphInfo
from .memory_utils import activation_size, is_inplace, parameter_size
diff --git a/colossalai/fx/profiler/constants.py b/colossalai/fx/profiler/constants.py
index 5763a46dc83f19dadefbebb32dbcf9a59578a2b3..fad9bb272bff76022928aa518f99bd92ffb27892 100644
--- a/colossalai/fx/profiler/constants.py
+++ b/colossalai/fx/profiler/constants.py
@@ -1,6 +1,6 @@
import torch
-__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN', 'RELU_LIKE_OPS', 'RELU_LIKE_MOD']
+__all__ = ["ALIAS_ATEN", "INPLACE_NEW", "INPLACE_MATH_ATEN", "CLONE_ATEN", "RELU_LIKE_OPS", "RELU_LIKE_MOD"]
aten = torch.ops.aten
diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py
index a5e8880322b84f54e6c4742a821f4de76dfb664a..05f9b50ce575312011d9b0ac9937959f7b22e298 100644
--- a/colossalai/fx/profiler/dataflow.py
+++ b/colossalai/fx/profiler/dataflow.py
@@ -1,6 +1,5 @@
from dataclasses import dataclass, field
from enum import Enum
-from functools import partial
from typing import Dict, List
from torch.fx import Graph, Node
@@ -69,8 +68,8 @@ class GraphInfo:
def is_phase(n: Node, phase: Phase) -> bool:
- assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
- return n.meta['phase'] == phase
+ assert "phase" in n.meta, f"Node meta of {n} has no key `phase`!"
+ return n.meta["phase"] == phase
@compatibility(is_backward_compatible=False)
@@ -103,9 +102,9 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
peak_mem = 0
for k, v in deps.items():
if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k):
- peak_mem += activation_size(k.meta['saved_tensor'])
- if v <= float('-inf') and is_phase(k, Phase.FORWARD):
- peak_mem -= activation_size(k.meta['saved_tensor'])
+ peak_mem += activation_size(k.meta["saved_tensor"])
+ if v <= float("-inf") and is_phase(k, Phase.FORWARD):
+ peak_mem -= activation_size(k.meta["saved_tensor"])
return peak_mem
# deps is used to track all the memory dependencies of the graph.
@@ -123,19 +122,19 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
# Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
# the node, `fwd_mem_tmp` can be freed.
if is_phase(n, Phase.PLACEHOLDER):
- graph_info.fwd_in += n.meta['saved_tensor']
+ graph_info.fwd_in += n.meta["saved_tensor"]
if is_phase(n, Phase.FORWARD):
- graph_info.fwd_tmp += n.meta['saved_tensor']
+ graph_info.fwd_tmp += n.meta["saved_tensor"]
elif is_phase(n, Phase.BACKWARD):
if len(n.users):
graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
else:
# TODO: some of the bwd_mem_out might be model parameters.
# basically a backward node without user is a `grad_out` node
- graph_info.bwd_mem_out += activation_size(n.meta['saved_tensor'])
+ graph_info.bwd_mem_out += activation_size(n.meta["saved_tensor"])
for input_n in n.all_input_nodes:
if input_n in deps:
deps[input_n] -= 1
if deps[input_n] <= 0:
- deps[input_n] = float('-inf')
+ deps[input_n] = float("-inf")
return graph_info
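The `_peak_memory` walk above keeps, for each producing node, a countdown of unconsumed users and treats a count of `-inf` as "already freed". A pure-Python sketch of that bookkeeping on a linear schedule (the schedule encoding is ours, not the pass's):

```python
def peak_memory_sketch(schedule):
    # schedule: list of (buffer_size, indices_of_buffers_this_step_frees)
    live, cur, peak = {}, 0, 0
    for i, (size, freed) in enumerate(schedule):
        live[i] = size
        cur += size
        peak = max(peak, cur)  # measure after allocation, before frees
        for j in freed:
            cur -= live.pop(j, 0)  # last user ran: buffer drops out of the peak
    return peak

# 4 bytes, then 8 (running total 12), then 2 while freeing the first buffer
assert peak_memory_sketch([(4, []), (8, []), (2, [0])]) == 14
```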
diff --git a/colossalai/fx/profiler/experimental/constants.py b/colossalai/fx/profiler/experimental/constants.py
index 57ff3fd91299b5bb8938125bf2d3243c9a9c4c2b..02758e7643af0936a5c7907ca6675ef10ec4db36 100644
--- a/colossalai/fx/profiler/experimental/constants.py
+++ b/colossalai/fx/profiler/experimental/constants.py
@@ -2,7 +2,7 @@ from operator import add, floordiv, getitem, mul, neg, pos, setitem, sub
import torch
-__all__ = ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
+__all__ = ["INPLACE_OPS", "INPLACE_METHOD", "NON_INPLACE_METHOD"]
# TODO fill out the inplace ops
INPLACE_OPS = [
@@ -20,25 +20,25 @@ INPLACE_OPS = [
# TODO: list all call_methods that are inplace here
INPLACE_METHOD = [
- 'transpose',
- 'permute',
+ "transpose",
+ "permute",
# TODO: reshape may return a copy of the data if the data is not contiguous
- 'reshape',
- 'dim',
- 'flatten',
- 'size',
- 'view',
- 'unsqueeze',
- 'to',
- 'type',
- 'flatten',
+ "reshape",
+ "dim",
+ "flatten",
+ "size",
+ "view",
+ "unsqueeze",
+ "to",
+ "type",
+ "flatten",
]
# TODO: list all call_methods that are not inplace here
NON_INPLACE_METHOD = [
- 'chunk',
- 'contiguous',
- 'expand',
- 'mean',
- 'split',
+ "chunk",
+ "contiguous",
+ "expand",
+ "mean",
+ "split",
]
diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py
index 5c545260e72b723bfa54beacfb20def3e758413f..d890fdb66fc2dc7497a6c6bf2f4a2534c1885d60 100644
--- a/colossalai/fx/profiler/experimental/profiler.py
+++ b/colossalai/fx/profiler/experimental/profiler.py
@@ -9,7 +9,7 @@ from ..memory_utils import activation_size
from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
from .registry import meta_profiler_function, meta_profiler_module
-__all__ = ['profile_function', 'profile_module', 'profile_method']
+__all__ = ["profile_function", "profile_module", "profile_method"]
# this is for compatibility use
@@ -42,6 +42,7 @@ class GraphInfo:
bwd_mem_tmp (int): See the above illustration.
bwd_mem_out (int): See the above illustration.
"""
+
fwd_flop: int = 0
bwd_flop: int = 0
fwd_mem_in: int = 0
@@ -50,8 +51,7 @@ class GraphInfo:
bwd_mem_out: int = 0
-CALL_FUNCTION_MSG = \
-"""
+CALL_FUNCTION_MSG = """
Colossal-AI does not support profiling for {} yet; you can manually patch it with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_function
@meta_profiler_function.register(YOUR_FUNCTION)
@@ -60,9 +60,8 @@ def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]:
macs = ...
return flops, macs
"""
-CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}'
-CALL_MODULE_MSG = \
-"""
+CALL_METHOD_MSG = "Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}"
+CALL_MODULE_MSG = """
Colossal-AI does not support profiling for {} yet; you can manually patch it with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_module
@meta_profiler_module.register(YOUR_MODULE)
@@ -74,7 +73,7 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int
@compatibility(is_backward_compatible=True)
-def profile_function(target: 'Target') -> Callable:
+def profile_function(target: "Target") -> Callable:
"""
Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution.
@@ -92,12 +91,13 @@ def profile_function(target: 'Target') -> Callable:
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_function.has(target) or meta_profiler_function.has(
- target.__name__), CALL_FUNCTION_MSG.format(target)
+ target.__name__
+ ), CALL_FUNCTION_MSG.format(target)
fwd_tmp = 0
fwd_out = 0
out = func(*args, **kwargs)
- if target not in INPLACE_OPS and not kwargs.get('inplace', False):
+ if target not in INPLACE_OPS and not kwargs.get("inplace", False):
fwd_out = activation_size(out)
if meta_profiler_function.has(target):
profiler = meta_profiler_function.get(target)
@@ -112,7 +112,7 @@ def profile_function(target: 'Target') -> Callable:
@compatibility(is_backward_compatible=True)
-def profile_method(target: 'Target') -> Callable:
+def profile_method(target: "Target") -> Callable:
"""
Wrap a `call_method` node to
record the memory cost and FLOPs of the execution.
@@ -126,11 +126,12 @@ def profile_method(target: 'Target') -> Callable:
self_obj, *args_tail = args
# execute the method and return the result
- assert isinstance(target, str), f'{target} instance is not str.'
+ assert isinstance(target, str), f"{target} instance is not str."
out = getattr(self_obj, target)(*args_tail, **kwargs)
assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
- target, INPLACE_METHOD, NON_INPLACE_METHOD)
+ target, INPLACE_METHOD, NON_INPLACE_METHOD
+ )
# call_method nodes have no parameters, are MOSTLY(?) inplace, and incur no FLOPs or MACs.
fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out)
fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out)
@@ -161,7 +162,7 @@ def profile_module(module: torch.nn.Module) -> Callable:
fwd_tmp = 0
fwd_out = 0
out = func(*args, **kwargs)
- if getattr(module, 'inplace', False):
+ if getattr(module, "inplace", False):
fwd_out = activation_size(out)
profiler = meta_profiler_module.get(type(module))
fwd_flop, _ = profiler(module, *args, **kwargs)
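Following the patch template embedded in `CALL_FUNCTION_MSG` above, registering a fallback rule looks like this; the op and the per-element cost are illustrative choices, not a rule shipped by the library:

```python
from typing import Tuple

import torch

from colossalai.fx.profiler.experimental import meta_profiler_function

@meta_profiler_function.register(torch.sign)
def profile_sign(input: torch.Tensor, *args) -> Tuple[int, int]:
    flops = input.numel()  # one comparison-like op per element (our rough estimate)
    macs = 0
    return flops, macs
```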
diff --git a/colossalai/fx/profiler/experimental/profiler_function/activation_function.py b/colossalai/fx/profiler/experimental/profiler_function/activation_function.py
index a43aef063e197de12c23fc5a81fb13e8183eaae9..c518ec28da418ed86716a5839dace4f1fd8bc160 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/activation_function.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/activation_function.py
@@ -1,5 +1,7 @@
from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_function
# TODO: different activation has different FLOPs count, currently unused.
diff --git a/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py b/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py
index 8d1c8a8c6877fc12a7ad47b7b0103c309b8ed597..f1b9bb97c6c679c9b297f9feda3deec5e490631e 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py
@@ -41,15 +41,15 @@ def _elementwise_flops_compute(input, other):
@meta_profiler_function.register(torch.sub)
@meta_profiler_function.register(torch.mul)
@meta_profiler_function.register(torch.floor_divide)
-@meta_profiler_function.register('add') # for built-in op +
-@meta_profiler_function.register('iadd') # for built-in op +=
-@meta_profiler_function.register('eq') # for built-in op =
-@meta_profiler_function.register('sub') # for built-in op -
-@meta_profiler_function.register('isub') # for built-in op -=
-@meta_profiler_function.register('mul') # for built-in op *
-@meta_profiler_function.register('imul') # for built-in op *=
-@meta_profiler_function.register('floordiv') # for built-in op //
-@meta_profiler_function.register('ifloordiv') # for built-in op //=
+@meta_profiler_function.register("add") # for built-in op +
+@meta_profiler_function.register("iadd") # for built-in op +=
+@meta_profiler_function.register("eq") # for built-in op =
+@meta_profiler_function.register("sub") # for built-in op -
+@meta_profiler_function.register("isub") # for built-in op -=
+@meta_profiler_function.register("mul") # for built-in op *
+@meta_profiler_function.register("imul") # for built-in op *=
+@meta_profiler_function.register("floordiv") # for built-in op //
+@meta_profiler_function.register("ifloordiv") # for built-in op //=
def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
return _elementwise_flops_compute(input, other)
@@ -62,7 +62,7 @@ def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = N
@meta_profiler_function.register(torch.matmul)
-@meta_profiler_function.register('matmul') # for built-in op @
+@meta_profiler_function.register("matmul") # for built-in op @
@meta_profiler_function.register(torch.Tensor.matmul)
def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
macs = reduce(operator.mul, input.shape) * other.shape[-1]
@@ -78,13 +78,15 @@ def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.T
@meta_profiler_function.register(torch.var_mean)
-def torch_var_mean(input: torch.Tensor,
- dim: Union[int, Tuple[int, ...]],
- unbiased: Optional[bool] = True,
- keepdim: Optional[bool] = False,
- *,
- out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
- assert out is None, 'saving to out is not supported yet'
+def torch_var_mean(
+ input: torch.Tensor,
+ dim: Union[int, Tuple[int, ...]],
+ unbiased: Optional[bool] = True,
+ keepdim: Optional[bool] = False,
+ *,
+ out: Optional[torch.Tensor] = None,
+) -> Tuple[int, int]:
+ assert out is None, "saving to out is not supported yet"
flops = input.numel() * 3
macs = 0
return flops, macs
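A quick numeric check of the `torch_matmul` rule above, where `macs = prod(input.shape) * other.shape[-1]` and `flops = 2 * macs`:

```python
import operator
from functools import reduce

import torch

a, b = torch.empty(2, 4, 8), torch.empty(8, 16)
macs = reduce(operator.mul, a.shape) * b.shape[-1]  # 2 * 4 * 8 * 16 = 1024
flops = 2 * macs                                    # 2048
assert (macs, flops) == (1024, 2048)
```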
diff --git a/colossalai/fx/profiler/experimental/profiler_function/embedding.py b/colossalai/fx/profiler/experimental/profiler_function/embedding.py
index d6e43d781b8b64ab78cf3299daba3df1d17a5420..1d362015fc8b8c8fcfbcf5985e459f1de0aaa5b9 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/embedding.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/embedding.py
@@ -1,5 +1,7 @@
-import torch
from typing import Optional
+
+import torch
+
from ..registry import meta_profiler_function
diff --git a/colossalai/fx/profiler/experimental/profiler_function/linear.py b/colossalai/fx/profiler/experimental/profiler_function/linear.py
index 01fe4c87137083db2458c560a88cc6faa0af377e..ecc578d61b911c4e09086035a0f35082aa1ddd55 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/linear.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/linear.py
@@ -1,5 +1,7 @@
from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_function
diff --git a/colossalai/fx/profiler/experimental/profiler_function/normalization.py b/colossalai/fx/profiler/experimental/profiler_function/normalization.py
index c4ea508d70f80f33bbc8ae354e9743a4939d5e8c..2ad029eda03922e782e1f42aaafa2ca7577eb7db 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/normalization.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/normalization.py
@@ -1,5 +1,7 @@
from typing import List, Optional, Tuple
+
import torch
+
from ..registry import meta_profiler_function
@@ -21,11 +23,13 @@ def torch_nn_func_instancenorm(
@meta_profiler_function.register(torch.nn.functional.group_norm)
-def torch_nn_func_groupnorm(input: torch.Tensor,
- num_groups: int,
- weight: Optional[torch.Tensor] = None,
- bias: Optional[torch.Tensor] = None,
- eps: float = 1e-5) -> Tuple[int, int]:
+def torch_nn_func_groupnorm(
+ input: torch.Tensor,
+ num_groups: int,
+ weight: Optional[torch.Tensor] = None,
+ bias: Optional[torch.Tensor] = None,
+ eps: float = 1e-5,
+) -> Tuple[int, int]:
has_affine = weight is not None
flops = input.numel() * (5 if has_affine else 4)
macs = 0
diff --git a/colossalai/fx/profiler/experimental/profiler_function/pooling.py b/colossalai/fx/profiler/experimental/profiler_function/pooling.py
index a639f5ee83c1f4d2b75a3c120ea6ae3884fc422f..c91deab906d44c49201da4e08600d677ae8daccc 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/pooling.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/pooling.py
@@ -1,5 +1,7 @@
-from typing import Tuple, Union
+from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_function
diff --git a/colossalai/fx/profiler/experimental/profiler_function/python_ops.py b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py
index 1e8561206ba0e7a874b202a31dd13a040533d1db..58c9889ad98e24f5f163107f5095c07c8bf7f3a7 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/python_ops.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py
@@ -1,6 +1,6 @@
import operator
from typing import Any, Tuple
-import torch
+
from ..registry import meta_profiler_function
diff --git a/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py b/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py
index abdd7ad565ba237d7d6eab9e3c9b77d7afb10abf..67e90fb69acd0a1eebb4343118c601adbfd38e1a 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py
@@ -1,7 +1,9 @@
-from functools import reduce
import operator
+from functools import reduce
from typing import Any, Optional, Tuple
+
import torch
+
from ..registry import meta_profiler_function
@@ -43,13 +45,11 @@ def torch_where(condition: torch.Tensor, x: Any, y: Any) -> Tuple[int, int]:
@meta_profiler_function.register(torch.max)
-def torch_max(input: torch.Tensor,
- dim: int = None,
- keepdim: bool = False,
- *,
- out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
+def torch_max(
+ input: torch.Tensor, dim: int = None, keepdim: bool = False, *, out: Optional[torch.Tensor] = None
+) -> Tuple[int, int]:
macs = 0
- assert out is None, 'assigning value to out is not supported yet'
+ assert out is None, "assigning value to out is not supported yet"
if dim is not None:
shape = list(input.shape)
shape.pop(int(dim))
diff --git a/colossalai/fx/profiler/experimental/profiler_module/activation_function.py b/colossalai/fx/profiler/experimental/profiler_module/activation_function.py
index 2ebf514ad2699cc4e71741b9c3e143cedcb63041..ae065e0c7c176461090e700640c4dba0436174bf 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/activation_function.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/activation_function.py
@@ -1,5 +1,7 @@
from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_module
# TODO: different activation has different FLOPs count, currently unused.
diff --git a/colossalai/fx/profiler/experimental/profiler_module/attention.py b/colossalai/fx/profiler/experimental/profiler_module/attention.py
index 8daf74b232bf91d41933a2184e7b0c30d516d51a..dfaee75e04320a0451b7cf7eb02154082a8c5209 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/attention.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/attention.py
@@ -1,19 +1,23 @@
from typing import Optional, Tuple
+
import torch
+
from ..registry import meta_profiler_module
# TODO: the memory cost of this module is hard to compute
@meta_profiler_module.register(torch.nn.MultiheadAttention)
-def torch_nn_msa(self: torch.nn.MultiheadAttention,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- key_padding_mask: Optional[torch.Tensor] = None,
- need_weights: bool = True,
- attn_mask: Optional[torch.Tensor] = None,
- average_attn_weights: bool = True) -> Tuple[int, int]:
- if getattr(self, 'batch_first', False):
+def torch_nn_msa(
+ self: torch.nn.MultiheadAttention,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ key_padding_mask: Optional[torch.Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[torch.Tensor] = None,
+ average_attn_weights: bool = True,
+) -> Tuple[int, int]:
+ if getattr(self, "batch_first", False):
batch_size = query.shape[0]
len_idx = 1
else:
@@ -44,15 +48,9 @@ def torch_nn_msa(self: torch.nn.MultiheadAttention,
flops += qlen * qdim
# Initial projections
- flops += 2 * ((qlen * qdim * qdim) # QW
- + (klen * kdim * kdim) # KW
- + (vlen * vdim * vdim) # VW
- )
+    flops += 2 * ((qlen * qdim * qdim) + (klen * kdim * kdim) + (vlen * vdim * vdim))  # QW + KW + VW
- macs += ((qlen * qdim * qdim) # QW
- + (klen * kdim * kdim) # KW
- + (vlen * vdim * vdim) # VW
- )
+    macs += (qlen * qdim * qdim) + (klen * kdim * kdim) + (vlen * vdim * vdim)  # QW + KW + VW
if self.in_proj_bias is not None:
flops += (qlen + klen + vlen) * qdim
@@ -62,13 +60,9 @@ def torch_nn_msa(self: torch.nn.MultiheadAttention,
v_head_dim = vdim // num_heads
head_flops = (
- 2 * (qlen * klen * qk_head_dim) # QK^T
- + (qlen * klen) # softmax
- + 2 * (qlen * klen * v_head_dim) # AV
+        2 * (qlen * klen * qk_head_dim) + (qlen * klen) + 2 * (qlen * klen * v_head_dim)  # QK^T + softmax + AV
)
- head_macs = ((qlen * klen * qk_head_dim) # QK^T
- + 2 * (qlen * klen * v_head_dim) # AV
- )
+    head_macs = (qlen * klen * qk_head_dim) + 2 * (qlen * klen * v_head_dim)  # QK^T + AV
flops += num_heads * head_flops
    macs += num_heads * head_macs
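Worked numbers for the per-head terms above, with toy sizes qlen = klen = 10 and 64-dimensional heads:

```python
qlen = klen = 10
qk_head_dim = v_head_dim = 64

head_flops = 2 * (qlen * klen * qk_head_dim) + (qlen * klen) + 2 * (qlen * klen * v_head_dim)
head_macs = (qlen * klen * qk_head_dim) + 2 * (qlen * klen * v_head_dim)
assert (head_flops, head_macs) == (25700, 19200)  # QK^T + softmax + AV, and QK^T + AV
```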
diff --git a/colossalai/fx/profiler/experimental/profiler_module/convolution.py b/colossalai/fx/profiler/experimental/profiler_module/convolution.py
index a4c15b91e611d5d1398eeb38ec5c106c2652749b..90e494c77f5b5d38e942e4e749f6400e8886868e 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/convolution.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/convolution.py
@@ -17,8 +17,9 @@ def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, in
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
c_in, l_in = input.shape[-2:]
c_out = self.out_channels
- l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] *
- (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
+ l_out = math.floor(
+ (l_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
result_shape = input.shape[:-2] + (
c_out,
l_out,
@@ -38,10 +39,12 @@ def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, in
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
c_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels
- h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] *
- (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
- w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] *
- (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
+ h_out = math.floor(
+ (h_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
+ w_out = math.floor(
+ (w_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
+ )
result_shape = input.shape[:-3] + (
c_out,
h_out,
@@ -62,12 +65,15 @@ def torch_nn_conv3d(self: torch.nn.Conv3d, input: torch.Tensor) -> Tuple[int, in
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html
c_in, d_in, h_in, w_in = input.shape[-4:]
c_out = self.out_channels
- d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] *
- (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
- h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] *
- (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
- w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] *
- (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1)
+ d_out = math.floor(
+ (d_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
+ h_out = math.floor(
+ (h_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
+ )
+ w_out = math.floor(
+ (w_in + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1
+ )
result_shape = input.shape[:-4] + (
c_out,
d_out,
@@ -89,8 +95,13 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
c_in, l_in = input.shape[-2:]
c_out = self.out_channels
- l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
- (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
+ l_out = math.floor(
+ (l_in - 1) * self.stride[0]
+ - 2 * self.padding[0]
+ + self.dilation[0] * (self.kernel_size[0] - 1)
+ + self.output_padding[0]
+ + 1
+ )
result_shape = input.shape[:-2] + (
c_out,
l_out,
@@ -98,7 +109,7 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor
macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
num_elem = reduce(
operator.mul, input.shape
- ) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604
+ ) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
@@ -112,10 +123,20 @@ def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
c_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels
- h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
- (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
- w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
- (self.kernel_size[1] - 1) + self.output_padding[1] + 1)
+ h_out = math.floor(
+ (h_in - 1) * self.stride[0]
+ - 2 * self.padding[0]
+ + self.dilation[0] * (self.kernel_size[0] - 1)
+ + self.output_padding[0]
+ + 1
+ )
+ w_out = math.floor(
+ (w_in - 1) * self.stride[1]
+ - 2 * self.padding[1]
+ + self.dilation[1] * (self.kernel_size[1] - 1)
+ + self.output_padding[1]
+ + 1
+ )
result_shape = input.shape[:-3] + (
c_out,
h_out,
@@ -136,12 +157,27 @@ def torch_nn_convtranspose3d(self: torch.nn.ConvTranspose3d, input: torch.Tensor
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
c_in, d_in, h_in, w_in = input.shape[-4:]
c_out = self.out_channels
- d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
- (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
- h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
- (self.kernel_size[1] - 1) + self.output_padding[1] + 1)
- w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] *
- (self.kernel_size[2] - 1) + self.output_padding[2] + 1)
+ d_out = math.floor(
+ (d_in - 1) * self.stride[0]
+ - 2 * self.padding[0]
+ + self.dilation[0] * (self.kernel_size[0] - 1)
+ + self.output_padding[0]
+ + 1
+ )
+ h_out = math.floor(
+ (h_in - 1) * self.stride[1]
+ - 2 * self.padding[1]
+ + self.dilation[1] * (self.kernel_size[1] - 1)
+ + self.output_padding[1]
+ + 1
+ )
+ w_out = math.floor(
+ (w_in - 1) * self.stride[2]
+ - 2 * self.padding[2]
+ + self.dilation[2] * (self.kernel_size[2] - 1)
+ + self.output_padding[2]
+ + 1
+ )
result_shape = input.shape[:-4] + (
c_out,
d_out,
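The conv rules above all derive the output extent from the standard PyTorch formula; a quick self-check against an actual `Conv2d`:

```python
import math

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
h_in = w_in = 32
h_out = math.floor(
    (h_in + 2 * conv.padding[0] - conv.dilation[0] * (conv.kernel_size[0] - 1) - 1) / conv.stride[0] + 1
)
w_out = math.floor(
    (w_in + 2 * conv.padding[1] - conv.dilation[1] * (conv.kernel_size[1] - 1) - 1) / conv.stride[1] + 1
)
out = conv(torch.randn(1, 3, h_in, w_in))
assert out.shape[-2:] == (h_out, w_out) == (16, 16)
```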
diff --git a/colossalai/fx/profiler/experimental/profiler_module/dropout.py b/colossalai/fx/profiler/experimental/profiler_module/dropout.py
index 417e0ed468637a5ce049ffa8137a73e5b266c971..7361239eb1bdf087c4ae3fd457e121cc8390ffa7 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/dropout.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/dropout.py
@@ -1,5 +1,7 @@
from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_module
diff --git a/colossalai/fx/profiler/experimental/profiler_module/linear.py b/colossalai/fx/profiler/experimental/profiler_module/linear.py
index e1ffb6f244d2ed7d5764339d61fdb46f71ae59a2..71fed3196c1326c0c3121d9f6357548d699536ca 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/linear.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/linear.py
@@ -1,5 +1,7 @@
from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_module
diff --git a/colossalai/fx/profiler/experimental/profiler_module/normalization.py b/colossalai/fx/profiler/experimental/profiler_module/normalization.py
index 49e5e6fa5384b07412abe7ecc947d7963e88bd1a..5a64e44947b748b8b8ff7554dcef54d9933fd016 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/normalization.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/normalization.py
@@ -16,8 +16,12 @@ from ..registry import meta_profiler_module
@meta_profiler_module.register(torch.nn.BatchNorm1d)
@meta_profiler_module.register(torch.nn.BatchNorm2d)
@meta_profiler_module.register(torch.nn.BatchNorm3d)
-def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
- torch.nn.BatchNorm3d], input: torch.Tensor) -> Tuple[int, int]:
+def torch_nn_normalize(
+ self: Union[
+ torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d
+ ],
+ input: torch.Tensor,
+) -> Tuple[int, int]:
# adopted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L615
has_affine = self.weight is not None
if self.training:
@@ -30,6 +34,7 @@ def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch
try:
import apex
+
meta_profiler_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize)
meta_profiler_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize)
meta_profiler_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize)
diff --git a/colossalai/fx/profiler/experimental/profiler_module/pooling.py b/colossalai/fx/profiler/experimental/profiler_module/pooling.py
index e429ac3eea28055f42af2ea8f84663a5a6fd2a83..b3b630b2dee91f16358b9102963287cb81883bd5 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/pooling.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/pooling.py
@@ -1,5 +1,7 @@
from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_module
diff --git a/colossalai/fx/profiler/experimental/profiler_module/rnn.py b/colossalai/fx/profiler/experimental/profiler_module/rnn.py
index 6e733d6da9156db13b2bac35af63a52ad89ad5a3..8a4c828dbd276dc6cab6d2e36209a16e4b48230a 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/rnn.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/rnn.py
@@ -1,12 +1,15 @@
-from functools import reduce
import operator
+from functools import reduce
+from typing import Optional, Tuple
+
import torch
+
from ..registry import meta_profiler_module
-from typing import Optional, Tuple, Union
-def _rnn_flops(flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor,
- w_hh: torch.Tensor) -> Tuple[int, int]:
+def _rnn_flops(
+ flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor, w_hh: torch.Tensor
+) -> Tuple[int, int]:
# copied from https://github.com/sovrasov/flops-counter.pytorch/blob/master/ptflops/pytorch_ops.py
# matrix matrix mult ih state and internal state
@@ -42,12 +45,12 @@ def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch
flops = 0
macs = 0
for i in range(self.num_layers):
- w_ih = self.__getattr__('weight_ih_l' + str(i))
- w_hh = self.__getattr__('weight_hh_l' + str(i))
+ w_ih = self.__getattr__("weight_ih_l" + str(i))
+ w_hh = self.__getattr__("weight_hh_l" + str(i))
flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
if self.bias:
- b_ih = self.__getattr__('bias_ih_l' + str(i))
- b_hh = self.__getattr__('bias_hh_l' + str(i))
+ b_ih = self.__getattr__("bias_ih_l" + str(i))
+ b_hh = self.__getattr__("bias_hh_l" + str(i))
flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh)
flops *= reduce(operator.mul, input.shape[:2])
macs *= reduce(operator.mul, input.shape[:2])
@@ -63,12 +66,12 @@ def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch
def torch_nn_rnn(self: torch.nn.RNNCellBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]:
flops = 0
macs = 0
- w_ih = self.__getattr__('weight_ih_l')
- w_hh = self.__getattr__('weight_hh_l')
+ w_ih = self.__getattr__("weight_ih_l")
+ w_hh = self.__getattr__("weight_hh_l")
flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
if self.bias:
- b_ih = self.__getattr__('bias_ih_l')
- b_hh = self.__getattr__('bias_hh_l')
+ b_ih = self.__getattr__("bias_ih_l")
+ b_hh = self.__getattr__("bias_hh_l")
flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh)
flops *= input.shape[0]
macs *= input.shape[0]
diff --git a/colossalai/fx/profiler/experimental/profiler_module/torch_op.py b/colossalai/fx/profiler/experimental/profiler_module/torch_op.py
index d3aed874eb10af76dc94e21b23c566178afe6264..06be25246a712ae6af9596b5f398d9887b8f42bb 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/torch_op.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/torch_op.py
@@ -1,7 +1,8 @@
-import operator
+from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_module
-from typing import Optional, Tuple, Union
@meta_profiler_module.register(torch.nn.Flatten)
diff --git a/colossalai/fx/profiler/experimental/registry.py b/colossalai/fx/profiler/experimental/registry.py
index 7d73bce321e43d7c1284bf4d78dbac2bc7c4abfc..d47129cd2978422b8d669a61cbe684afe399e337 100644
--- a/colossalai/fx/profiler/experimental/registry.py
+++ b/colossalai/fx/profiler/experimental/registry.py
@@ -1,11 +1,9 @@
class ProfilerRegistry:
-
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
-
def wrapper(func):
self.store[source] = func
return func
@@ -21,5 +19,5 @@ class ProfilerRegistry:
return source in self.store
-meta_profiler_function = ProfilerRegistry(name='patched_functions_for_meta_profile')
-meta_profiler_module = ProfilerRegistry(name='patched_modules_for_meta_profile')
+meta_profiler_function = ProfilerRegistry(name="patched_functions_for_meta_profile")
+meta_profiler_module = ProfilerRegistry(name="patched_modules_for_meta_profile")
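Consuming the registry works as in `profile_function` above: guard with `.has(...)`, then fetch with `.get(...)`. A brief sketch, assuming the arithmetic rules shown earlier are imported so that `torch.matmul` is already registered:

```python
import torch

from colossalai.fx.profiler.experimental import meta_profiler_function

if meta_profiler_function.has(torch.matmul):
    handler = meta_profiler_function.get(torch.matmul)
    flops, macs = handler(torch.empty(2, 4), torch.empty(4, 8))
    print(flops, macs)  # 128, 64: 2*4*8 MACs, doubled for FLOPs
```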
diff --git a/colossalai/fx/profiler/experimental/shard_utils.py b/colossalai/fx/profiler/experimental/shard_utils.py
index 1e53ed0bf8ec657d8916d49c8c4b97f1d996010a..90e8c3b7cfe46be47f29198efaf1f9a21a2fd174 100644
--- a/colossalai/fx/profiler/experimental/shard_utils.py
+++ b/colossalai/fx/profiler/experimental/shard_utils.py
@@ -1,8 +1,6 @@
# for PyTorch 1.11 compatibility uses
-from typing import Dict, List, Tuple, Union
-import torch
-from torch.fx import GraphModule, Node
+from torch.fx import Node
from ..._compatibility import compatibility
@@ -19,7 +17,7 @@ def calculate_fwd_in(n: Node) -> bool:
Returns:
save_fwd_in (bool): the result of `save_fwd_in`
"""
- return n.meta['save_fwd_in']
+ return n.meta["save_fwd_in"]
@compatibility(is_backward_compatible=True)
@@ -45,4 +43,4 @@ def calculate_fwd_out(n: Node) -> int:
Returns:
fwd_out (int): the result of `fwd_out`
"""
- return n.meta['fwd_mem_out']
+ return n.meta["fwd_mem_out"]
diff --git a/colossalai/fx/profiler/memory_utils.py b/colossalai/fx/profiler/memory_utils.py
index 6ccbcb01cdc14045fbdd4906f1fc6f2a5ad728db..e8eb5f25cb6c663936125ff38fed69bbddbf8ae4 100644
--- a/colossalai/fx/profiler/memory_utils.py
+++ b/colossalai/fx/profiler/memory_utils.py
@@ -1,11 +1,11 @@
from typing import Dict, List, Tuple, Union
import torch
-from torch.fx import GraphModule, Node
+from torch.fx import Node
from .._compatibility import compatibility, is_compatible_with_meta
-__all__ = ['activation_size', 'parameter_size', 'is_inplace']
+__all__ = ["activation_size", "parameter_size", "is_inplace"]
@compatibility(is_backward_compatible=True)
@@ -63,6 +63,7 @@ def is_inplace(n: Node):
inplace = n.kwargs.get("inplace", False)
if is_compatible_with_meta():
from .constants import ALIAS_ATEN
+
if n.target in ALIAS_ATEN:
inplace = True
elif n.op == "call_module":
diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py
index ba090a2ec51bd7d1d83a6dd5d75c877c0708577f..8fae0f2ecb45a4a2aa12449248f7a451a6907b01 100644
--- a/colossalai/fx/profiler/opcount.py
+++ b/colossalai/fx/profiler/opcount.py
@@ -173,8 +173,11 @@ def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
# Inputs[0] contains the shape of the input.
input_shape = inputs[input_arg_index].shape
- has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index],
- 'shape') else inputs[affine_arg_index]
+ has_affine = (
+ inputs[affine_arg_index].shape is not None
+ if hasattr(inputs[affine_arg_index], "shape")
+ else inputs[affine_arg_index]
+ )
assert 2 <= len(input_shape) <= 5, input_shape
# 5 is just a rough estimate
flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4)
@@ -188,7 +191,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N
training = inputs[-3]
assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
if training:
- return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
+ return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
has_affine = inputs[1].shape is not None
input_shape = reduce(operator.mul, inputs[0].shape)
return input_shape * (2 if has_affine else 1)
@@ -218,15 +221,16 @@ def elementwise_flop_counter(input_scale: float = 1, output_scale: float = 0) ->
def zero_flop_jit(*args):
"""
- Count flops for zero flop layers.
+ Count flops for zero flop layers.
"""
return 0
-if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
- torch.__version__) < version.parse('2.0.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0") and version.parse(torch.__version__) < version.parse(
+ "2.0.0"
+):
flop_mapping = {
- # gemm, gemv and dot
+ # gemm, gemv and dot
aten.mm.default: matmul_flop_jit,
aten.mv.default: matmul_flop_jit,
aten.dot.default: matmul_flop_jit,
@@ -234,13 +238,11 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse
aten.addmm.default: addmm_flop_jit,
aten.bmm.default: bmm_flop_jit,
aten.baddbmm.default: baddbmm_flop_jit,
-
- # convolution
+ # convolution
aten.convolution.default: conv_flop_jit,
aten._convolution.default: conv_flop_jit,
aten.convolution_backward.default: conv_backward_flop_jit,
-
- # normalization
+ # normalization
aten.native_batch_norm.default: batchnorm_flop_jit,
aten.native_batch_norm_backward.default: batchnorm_flop_jit,
aten.cudnn_batch_norm.default: batchnorm_flop_jit,
@@ -249,8 +251,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
aten.native_group_norm.default: norm_flop_counter(2, 0),
aten.native_group_norm_backward.default: norm_flop_counter(2, 0),
-
- # pooling
+ # pooling
aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
aten.avg_pool2d.default: elementwise_flop_counter(1, 0),
aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
@@ -275,7 +276,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse
}
elementwise_flop_aten = [
- # basic op
+ # basic op
aten.add.Tensor,
aten.add_.Tensor,
aten.div.Tensor,
@@ -296,8 +297,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse
aten.exp.default,
aten.sin.default,
aten.cos.default,
-
- # activation op
+ # activation op
aten.hardswish.default,
aten.hardswish_.default,
aten.hardswish_backward.default,
@@ -320,8 +320,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse
aten.tanh.default,
aten.tanh_backward.default,
aten.threshold_backward.default,
-
- # dropout
+ # dropout
aten.native_dropout.default,
aten.native_dropout_backward.default,
]
@@ -362,7 +361,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse
aten.zero_.default,
aten.zeros_like.default,
aten.fill_.Scalar,
- aten.stack.default
+ aten.stack.default,
] # yapf: disable
for op in zero_flop_aten:
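Dispatch through `flop_mapping` follows the calling convention `profiler.py` uses (`flop_mapping[func](args, normalize_tuple(out))`). A hedged sketch for a plain `aten.mm`, valid only under the version gate above (torch >= 1.12, < 2.0) and assuming the fvcore-style convention where `mm` counts `m*k*n`:

```python
import torch

from colossalai.fx.profiler.opcount import flop_mapping

aten = torch.ops.aten
a = torch.empty(4, 8, device="meta")
b = torch.empty(8, 16, device="meta")
out = torch.empty(4, 16, device="meta")
print(flop_mapping[aten.mm.default]((a, b), (out,)))  # 4 * 8 * 16 = 512
```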
diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py
index c87cd4321d31c59ebe369a2903b679a439fb4f96..97e70db6290e58d62c113ca37f4323be07ef3e58 100644
--- a/colossalai/fx/profiler/profiler.py
+++ b/colossalai/fx/profiler/profiler.py
@@ -15,7 +15,7 @@ from .memory_utils import activation_size, parameter_size
from .opcount import flop_mapping
from .tensor import MetaTensor
-__all__ = ['profile_function', 'profile_module', 'profile_method']
+__all__ = ["profile_function", "profile_module", "profile_method"]
# super-dainiu: this cache should be global, otherwise it cannot
# track duplicated tensors between nodes
@@ -174,7 +174,6 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
# backward is executed.
# Hopefully, this attempt will provide a better estimation of memory.
class FlopTensor(MetaTensor):
-
_node: Node = None
def __repr__(self):
@@ -186,24 +185,24 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args)
kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs)
- node = subgraph.create_node('call_function', func, args_node, kwargs_node)
+ node = subgraph.create_node("call_function", func, args_node, kwargs_node)
out = super().__torch_dispatch__(func, types, args, kwargs)
flop_count[phase] += flop_mapping[func](args, normalize_tuple(out))
- node.meta['phase'] = phase
+ node.meta["phase"] = phase
# super-dainiu: in `nn.MultiheadAttention` this weird thing occurs,
# i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during
# `Phase.FORWARD`
if phase == Phase.FORWARD:
if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN:
- node.meta['phase'] = Phase.PLACEHOLDER
+ node.meta["phase"] = Phase.PLACEHOLDER
# TODO(yby): specify `saved_tensors` for backward memory estimation
- node.meta['saved_tensor'] = []
+ node.meta["saved_tensor"] = []
if phase == Phase.BACKWARD:
- node.meta['saved_tensor'] = normalize_tuple(out)
+ node.meta["saved_tensor"] = normalize_tuple(out)
def wrap(x):
if isinstance(x, MetaTensor):
@@ -219,11 +218,14 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
x = FlopTensor(x)
if is_autogradable(x):
x.requires_grad_(True)
- x._node = subgraph.create_node('placeholder',
- 'placeholder', (subgraph._root,),
- name=subgraph._graph_namespace.create_name('input', x._tensor))
- x._node.meta['phase'] = Phase.PLACEHOLDER
- x._node.meta['saved_tensor'] = []
+ x._node = subgraph.create_node(
+ "placeholder",
+ "placeholder",
+ (subgraph._root,),
+ name=subgraph._graph_namespace.create_name("input", x._tensor),
+ )
+ x._node.meta["phase"] = Phase.PLACEHOLDER
+ x._node.meta["saved_tensor"] = []
return x
# Basically, we need to detach the args and kwargs from the outer graph.
@@ -235,7 +237,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache:
tensor = x._tensor.detach()
tensor.data_ptr = x._tensor.data_ptr
- x._node.meta['saved_tensor'] += [tensor]
+ x._node.meta["saved_tensor"] += [tensor]
if not do_not_cache:
cache.add(x._tensor.data_ptr())
return x
@@ -284,7 +286,7 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G
@compatibility(is_backward_compatible=True)
-def profile_function(target: 'Target', device: str = 'meta') -> Callable:
+def profile_function(target: "Target", device: str = "meta") -> Callable:
"""
Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution.
@@ -300,7 +302,6 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-
# find the grad for parameter in args and kwargs
param_size = 0
@@ -316,18 +317,18 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
# still run the profiling but discard some results regarding `target`
global do_not_cache
- inplace = kwargs.get('inplace', False)
+ inplace = kwargs.get("inplace", False)
if target in OUTPUT_SAVED_OPS:
do_not_cache = True
if inplace:
do_not_cache = True
- kwargs['inplace'] = False
- if device == 'meta':
+ kwargs["inplace"] = False
+ if device == "meta":
out, meta = _profile_meta(func, *args, **kwargs)
else:
out, meta = _profile_concrete(func, *args, **kwargs)
if inplace:
- kwargs['inplace'] = True
+ kwargs["inplace"] = True
meta.bwd_mem_tmp = 0
meta.bwd_mem_out = 0
do_not_cache = False
@@ -341,7 +342,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
@compatibility(is_backward_compatible=True)
-def profile_method(target: 'Target', device: str = 'meta') -> Callable:
+def profile_method(target: "Target", device: str = "meta") -> Callable:
"""
Wrap a `call_method` node to
record the memory cost and FLOPs of the execution.
@@ -349,8 +350,8 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# execute the method and return the result
- assert isinstance(target, str), f'{target} instance is not str.'
- if device == 'meta':
+ assert isinstance(target, str), f"{target} instance is not str."
+ if device == "meta":
out, meta = _profile_meta(target, *args, **kwargs)
else:
out, meta = _profile_concrete(target, *args, **kwargs)
@@ -360,7 +361,7 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable:
@compatibility(is_backward_compatible=True)
-def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
+def profile_module(module: torch.nn.Module, device: str = "meta") -> Callable:
"""
Wrap a `call_module` node or `torch.nn` in order to
record the memory cost and FLOPs of the execution.
@@ -376,7 +377,6 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-
# calculate parameter size
param_size = parameter_size(module)
@@ -384,13 +384,13 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable:
# still run the profiling but discard some results regarding `module`.
global do_not_cache
- inplace = getattr(module, 'inplace', False)
+ inplace = getattr(module, "inplace", False)
if type(module) in OUTPUT_SAVED_MOD:
do_not_cache = True
if inplace:
do_not_cache = True
module.inplace = False
- if device == 'meta':
+ if device == "meta":
out, meta = _profile_meta(func, *args, **kwargs)
else:
out, meta = _profile_concrete(func, *args, **kwargs)
diff --git a/colossalai/fx/profiler/shard_utils.py b/colossalai/fx/profiler/shard_utils.py
index 34feefb4336ab4a7924f7023b6f887ba8610b25c..75b7c814f05f4fd6f170959da4fb5d451bfcc75a 100644
--- a/colossalai/fx/profiler/shard_utils.py
+++ b/colossalai/fx/profiler/shard_utils.py
@@ -59,9 +59,9 @@ def calculate_fwd_tmp(n: Node) -> int:
Returns:
bool: Whether the node is a ReLU-like node
"""
- if n.op == 'call_function':
+ if n.op == "call_function":
return n.target in OUTPUT_SAVED_OPS
- elif n.op == 'call_module':
+ elif n.op == "call_module":
return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD
return False
diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py
index 2ee5e5c47750c4bbc714987ce4bcae9f29e6ac71..7c14b48bdaa1315e28494cb5b1cbbe8920921a53 100644
--- a/colossalai/fx/profiler/tensor.py
+++ b/colossalai/fx/profiler/tensor.py
@@ -1,13 +1,13 @@
import uuid
import torch
-from torch.types import _bool, _device, _dtype
-from torch.utils._pytree import tree_flatten, tree_map
+from torch.types import _device
+from torch.utils._pytree import tree_map
from .._compatibility import compatibility
from .constants import ALIAS_ATEN
-__all__ = ['MetaTensor']
+__all__ = ["MetaTensor"]
def set_data_ptr(x):
@@ -43,12 +43,13 @@ class MetaTensor(torch.Tensor):
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
- device=fake_device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')),
- requires_grad=elem.requires_grad) # deceive the frontend for aten selections
+ device=fake_device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
+ requires_grad=elem.requires_grad,
+ ) # deceive the frontend for aten selections
r._tensor = elem
# ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta:
- r._tensor = r._tensor.to(torch.device('meta'))
+ r._tensor = r._tensor.to(torch.device("meta"))
# only tensor not on `meta` should be copied to `meta`
set_data_ptr(r._tensor)
return r
@@ -69,15 +70,15 @@ class MetaTensor(torch.Tensor):
x = x._tensor
elif isinstance(x, torch.Tensor):
fake_device = x.device
- x = x.to(torch.device('meta'))
+ x = x.to(torch.device("meta"))
return x
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
- if 'device' in kwargs:
- fake_device = kwargs['device']
- kwargs['device'] = torch.device('meta')
+ if "device" in kwargs:
+ fake_device = kwargs["device"]
+ kwargs["device"] = torch.device("meta")
# run aten for backend=CPU but actually on backend=Meta
out = func(*args, **kwargs)
@@ -93,7 +94,7 @@ class MetaTensor(torch.Tensor):
if isinstance(x, torch.Tensor):
nonlocal fake_device
if not x.is_meta:
- x = x.to(torch.device('meta'))
+ x = x.to(torch.device("meta"))
return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x
return tree_map(wrap, out)
@@ -120,18 +121,18 @@ class MetaTensor(torch.Tensor):
nonlocal fake_device
if isinstance(x, str) or isinstance(x, _device):
fake_device = x
- return 'meta'
+ return "meta"
return x
elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
return MetaTensor(elem, fake_device=fake_device)
def cpu(self, *args, **kwargs):
- if self.device.type == 'cpu':
+ if self.device.type == "cpu":
return self.to(*args, **kwargs)
- return self.to(*args, device='cpu', **kwargs)
+ return self.to(*args, device="cpu", **kwargs)
def cuda(self, device=None, non_blocking=False):
if device is not None:
return self.to(device=device, non_blocking=non_blocking)
- return self.to(device='cuda:0', non_blocking=non_blocking)
+ return self.to(device="cuda:0", non_blocking=non_blocking)
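`MetaTensor` keeps the real storage on the `meta` device while reporting a `fake_device` to the frontend, so every aten call runs shape-only; the `cpu()`/`cuda()` overrides above simply re-dispatch through the patched `to()`. A minimal illustration of the meta-device execution this wrapper builds on (plain PyTorch, no ColossalAI imports):

```python
import torch

# Shape-only execution on the meta device: no real memory is allocated,
# yet shapes and dtypes propagate exactly as on a real device.
x = torch.empty(32, 128, device="meta")
w = torch.empty(256, 128, device="meta")
y = torch.nn.functional.linear(x, w)
print(y.shape, y.device)  # torch.Size([32, 256]) meta
```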
diff --git a/colossalai/fx/proxy.py b/colossalai/fx/proxy.py
index 7317072c6298b66280810c48536eb22b7edca7f0..887832223fd6cf263e26f4bdbb43b9a39f194302 100644
--- a/colossalai/fx/proxy.py
+++ b/colossalai/fx/proxy.py
@@ -1,12 +1,11 @@
-import operator
-from typing import Any, List, Union
+from typing import Any
import torch
-from torch.fx.proxy import Attribute, Proxy
+from torch.fx.proxy import Proxy
from colossalai.fx.tracer.meta_patch import meta_patched_function
-__all__ = ['ColoProxy']
+__all__ = ["ColoProxy"]
class ColoProxy(Proxy):
@@ -39,11 +38,12 @@ class ColoProxy(Proxy):
return self._meta_data is not None
def _assert_meta_data_is_tensor(self):
- assert torch.is_tensor(
- self._meta_data) and self._meta_data.is_meta, f'Meta data is not a meta tensor for {self.node.name}'
+ assert (
+ torch.is_tensor(self._meta_data) and self._meta_data.is_meta
+ ), f"Meta data is not a meta tensor for {self.node.name}"
def _assert_has_meta_data(self):
- assert self._meta_data is not None, f'Meta data is not set for {self.node.name}'
+ assert self._meta_data is not None, f"Meta data is not set for {self.node.name}"
def __len__(self):
self._assert_has_meta_data()
@@ -62,7 +62,6 @@ class ColoProxy(Proxy):
return self.meta_data
def __getattr__(self, k):
-
return ColoAttribute(self, k)
def __contains__(self, key):
@@ -92,7 +91,6 @@ def extract_meta(*args, **kwargs):
class ColoAttribute(ColoProxy):
-
def __init__(self, root, attr: str):
self.root = root
self.attr = attr
diff --git a/colossalai/fx/tracer/_meta_trace.py b/colossalai/fx/tracer/_meta_trace.py
index 1c5abb81d271144ab666049d3ae2868cd9568497..63a7bab654d5d821a0541976a0f86368a6bae419 100644
--- a/colossalai/fx/tracer/_meta_trace.py
+++ b/colossalai/fx/tracer/_meta_trace.py
@@ -39,7 +39,7 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr
_tensor: torch.Tensor
_node: Node
- __slots__ = ['_tensor', '_node']
+ __slots__ = ["_tensor", "_node"]
@staticmethod
def __new__(cls, tensor, fake_device=None, placeholder=False, name=None):
@@ -51,22 +51,22 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr
dtype=tensor.dtype,
layout=tensor.layout,
device=fake_device if fake_device is not None else tensor.device,
- requires_grad=tensor.requires_grad) # deceive the frontend for aten selections
+ requires_grad=tensor.requires_grad,
+ ) # deceive the frontend for aten selections
r._tensor = tensor
if placeholder:
if name is None:
- name = 'input'
- r._node = graph.create_node('placeholder',
- 'placeholder', (graph._root,),
- name=namespace.create_name(name, tensor))
+ name = "input"
+ r._node = graph.create_node(
+ "placeholder", "placeholder", (graph._root,), name=namespace.create_name(name, tensor)
+ )
# ...the real tensor is held as an element on the tensor.
if not r._tensor.is_meta:
- r._tensor = r._tensor.to(torch.device('meta'))
+ r._tensor = r._tensor.to(torch.device("meta"))
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-
def unwrap(x):
nonlocal fake_device
if isinstance(x, MetaProxy):
@@ -75,21 +75,21 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr
# assert not isinstance(x, MetaProxy)
elif isinstance(x, torch.Tensor):
fake_device = x.device
- x = x.to(torch.device('meta'))
+ x = x.to(torch.device("meta"))
return x
def get_node(x):
- if isinstance(x, torch.Tensor) and not hasattr(x, '_node'):
- x = MetaProxy(x, placeholder=True, name='weight')
- return x if not hasattr(x, '_node') else x._node
+ if isinstance(x, torch.Tensor) and not hasattr(x, "_node"):
+ x = MetaProxy(x, placeholder=True, name="weight")
+ return x if not hasattr(x, "_node") else x._node
args_node = tree_map(get_node, args)
kwargs_node = tree_map(get_node, kwargs)
- node = graph.create_node('call_function', func, args_node, kwargs_node)
+ node = graph.create_node("call_function", func, args_node, kwargs_node)
- if 'device' in kwargs:
- fake_device = kwargs['device']
- kwargs['device'] = torch.device('meta')
+ if "device" in kwargs:
+ fake_device = kwargs["device"]
+ kwargs["device"] = torch.device("meta")
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
@@ -103,9 +103,12 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr
if isinstance(x, torch.Tensor):
nonlocal fake_device
if not x.is_meta:
- x = x.to(torch.device('meta'))
- return MetaProxy(
- x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x
+ x = x.to(torch.device("meta"))
+ return (
+ MetaProxy(x, fake_device=fake_device)
+ if isinstance(x, torch.Tensor) and not hasattr(x, "_tensor")
+ else x
+ )
def set_node(x):
x._node = node
@@ -125,9 +128,12 @@ def meta_trace(module: torch.nn.Module, fake_device=None, *args, **kwargs) -> Gr
for tensor in normalize_tuple(out):
if is_autogradable(tensor) and tensor.requires_grad:
- grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance(
- tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta'))
- torch.autograd.backward(tensor,
- MetaProxy(grad, fake_device=tensor.device, placeholder=True),
- retain_graph=True)
+ grad = (
+ torch.empty_like(tensor._tensor, device=torch.device("meta"))
+ if isinstance(tensor, MetaProxy)
+ else torch.empty_like(tensor, device=torch.device("meta"))
+ )
+ torch.autograd.backward(
+ tensor, MetaProxy(grad, fake_device=tensor.device, placeholder=True), retain_graph=True
+ )
return graph
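Per the signature shown above, `meta_trace(module, fake_device, *args, **kwargs)` replays the forward pass and a synthetic backward pass through `__torch_dispatch__`, recording every aten call as an FX node. A hedged usage sketch; the exact node layout depends on the PyTorch version:

```python
import torch
from colossalai.fx.tracer._meta_trace import meta_trace

model = torch.nn.Linear(8, 4)
# requires_grad on the input makes meta_trace record the backward aten calls too.
graph = meta_trace(model, None, torch.empty(2, 8, device="meta", requires_grad=True))
graph.print_tabular()  # forward and backward ops as FX nodes
```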
diff --git a/colossalai/fx/tracer/_tracer_utils.py b/colossalai/fx/tracer/_tracer_utils.py
index e160497a74444fd34eea50943eb0430875fe1252..9cf1961d45ff96c987ebe81bb0897581be32947c 100644
--- a/colossalai/fx/tracer/_tracer_utils.py
+++ b/colossalai/fx/tracer/_tracer_utils.py
@@ -2,10 +2,10 @@ from typing import Any, List, Union
import torch
-from ..proxy import ColoAttribute, ColoProxy
-from .meta_patch import meta_patched_function, meta_patched_module
+from ..proxy import ColoProxy
+from .meta_patch import meta_patched_function
-__all__ = ['is_element_in_list', 'extract_meta']
+__all__ = ["is_element_in_list", "extract_meta"]
def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
@@ -21,7 +21,6 @@ def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
def extract_meta(*args, **kwargs):
-
def _convert(val):
if isinstance(val, ColoProxy):
return val.meta_data
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py
index 859a19bf6241bbbf4061e0f7564975682527b8c2..84c09109877ece5c129bdfd862d160631cec7719 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py
@@ -1,7 +1,4 @@
-import operator
-
import torch
-import torch.nn.functional as F
from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc
@@ -10,13 +7,12 @@ from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_method.register(torch.Tensor.addbmm)
@bias_addition_function.register(torch.addbmm)
class Addbmm(LinearBasedBiasFunc):
-
def extract_kwargs_from_origin_func(self):
kwargs = {}
- if 'beta' in self.kwargs:
- kwargs['beta'] = self.kwargs['beta']
- if 'alpha' in self.kwargs:
- kwargs['alpha'] = self.kwargs['alpha']
+ if "beta" in self.kwargs:
+ kwargs["beta"] = self.kwargs["beta"]
+ if "alpha" in self.kwargs:
+ kwargs["alpha"] = self.kwargs["alpha"]
return kwargs
def create_non_bias_func_proxy(self, input_proxy, other_proxy):
@@ -25,7 +21,7 @@ class Addbmm(LinearBasedBiasFunc):
        perform the main computation, such as convolution, with the bias option disabled.
"""
assert self.substitute_func == torch.bmm
- node_kind = 'call_function'
+ node_kind = "call_function"
node_target = self.substitute_func
node_args = (input_proxy, other_proxy)
@@ -35,10 +31,10 @@ class Addbmm(LinearBasedBiasFunc):
return non_bias_func_proxy
def insert_sum_node(self, input_proxy, sum_dims=0):
- '''
+ """
This method is used to sum the input_proxy through the sum_dims.
- '''
- node_kind = 'call_function'
+ """
+ node_kind = "call_function"
node_target = torch.sum
node_args = (input_proxy, sum_dims)
node_kwargs = {}
@@ -55,15 +51,15 @@ class Addbmm(LinearBasedBiasFunc):
sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy)
kwargs = self.extract_kwargs_from_origin_func()
- if 'beta' in kwargs:
- beta = kwargs['beta']
+ if "beta" in kwargs:
+ beta = kwargs["beta"]
            # multiply by beta if it is given (temp_2 = beta * input)
beta_proxy = self.create_mul_node(self.args[0], beta)
else:
beta_proxy = self.args[0]
- if 'alpha' in kwargs:
- alpha = kwargs['alpha']
+ if "alpha" in kwargs:
+ alpha = kwargs["alpha"]
            # multiply by alpha if it is given (temp_3 = alpha * temp_1)
alpha_proxy = self.create_mul_node(alpha, sum_proxy)
else:
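The `Addbmm` patch rewrites `addbmm` into primitives the tracer can reason about: a `torch.bmm`, a `torch.sum` over the batch dimension, optional scalar multiplications for `alpha`/`beta`, and a final add. The algebraic identity it relies on, checked numerically:

```python
import torch

# addbmm(input, b1, b2, beta=b, alpha=a) == b * input + a * sum(bmm(b1, b2), dim=0)
inp = torch.randn(3, 5)
b1, b2 = torch.randn(10, 3, 4), torch.randn(10, 4, 5)
ref = torch.addbmm(inp, b1, b2, beta=2.0, alpha=0.5)
dec = 2.0 * inp + 0.5 * torch.sum(torch.bmm(b1, b2), dim=0)
print(torch.allclose(ref, dec, atol=1e-5))  # True
```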
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py
index fe7d8d07aac941028d7c682043b5af2bcdf2537a..d087b291300594b1fdc15be4c8d6aad35aba023c 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py
@@ -1,7 +1,4 @@
-import operator
-
import torch
-import torch.nn.functional as F
from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc
@@ -10,17 +7,16 @@ from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_method.register(torch.Tensor.addmm)
@bias_addition_function.register(torch.addmm)
class Addmm(LinearBasedBiasFunc):
-
def extract_kwargs_from_origin_func(self):
kwargs = {}
- if 'beta' in self.kwargs:
- kwargs['beta'] = self.kwargs['beta']
- if 'alpha' in self.kwargs:
- kwargs['alpha'] = self.kwargs['alpha']
+ if "beta" in self.kwargs:
+ kwargs["beta"] = self.kwargs["beta"]
+ if "alpha" in self.kwargs:
+ kwargs["alpha"] = self.kwargs["alpha"]
return kwargs
def transpose_other_operand_for_linear(self, other_proxy):
- '''
+ """
        This method is used to transpose the other operand for the linear function.
For example:
input = torch.rand(3, 4)
@@ -30,8 +26,8 @@ class Addmm(LinearBasedBiasFunc):
        # To keep the computation graph consistent with the original one, we need to transpose m2
# before we call the linear function.
        new_output = torch.nn.functional.linear(m1, m2.transpose(0, 1)) + input
- '''
- node_kind = 'call_function'
+ """
+ node_kind = "call_function"
node_target = torch.transpose
node_args = (other_proxy, 0, 1)
node_kwargs = {}
@@ -43,14 +39,14 @@ class Addmm(LinearBasedBiasFunc):
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], transpose_proxy)
kwargs = self.extract_kwargs_from_origin_func()
- if 'beta' in kwargs:
- beta = kwargs['beta']
+ if "beta" in kwargs:
+ beta = kwargs["beta"]
beta_proxy = self.create_mul_node(self.args[0], beta)
else:
beta_proxy = self.args[0]
- if 'alpha' in kwargs:
- alpha = kwargs['alpha']
+ if "alpha" in kwargs:
+ alpha = kwargs["alpha"]
alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy)
else:
alpha_proxy = non_bias_linear_func_proxy
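`Addmm` is decomposed the same way, except the matrix product is routed through `F.linear`, which expects its weight transposed; hence the extra `transpose(0, 1)` node described above. The identity, checked numerically:

```python
import torch
import torch.nn.functional as F

# addmm(input, m1, m2, beta=b, alpha=a) == b * input + a * F.linear(m1, m2.transpose(0, 1))
inp = torch.randn(3, 5)
m1, m2 = torch.randn(3, 4), torch.randn(4, 5)
ref = torch.addmm(inp, m1, m2, beta=2.0, alpha=0.5)
dec = 2.0 * inp + 0.5 * F.linear(m1, m2.transpose(0, 1))
print(torch.allclose(ref, dec, atol=1e-5))  # True
```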
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py
index 8a3786332c08d9a3320ce4c7bee8221dd3d10abd..42178b7b786e95493893a6464aedba7b10e9b34e 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py
@@ -29,7 +29,6 @@ class BiasAdditionFunc(ABC):
to insert two more operator.mul nodes for the computation graph to compute the
final result.
"""
- pass
@abstractmethod
def generate(self):
@@ -50,7 +49,6 @@ class BiasAdditionFunc(ABC):
%mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {})
%add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {})
"""
- pass
def create_mul_node(self, input_proxy, coefficent):
"""
@@ -59,7 +57,7 @@ class BiasAdditionFunc(ABC):
        Therefore, we need to use this method to insert two more operator.mul nodes for
the computation graph to compute the final result.
"""
- node_kind = 'call_function'
+ node_kind = "call_function"
node_target = operator.mul
node_args = (
input_proxy,
@@ -82,7 +80,7 @@ class LinearBasedBiasFunc(BiasAdditionFunc):
        perform the main computation, such as convolution, with the bias option disabled.
"""
assert self.substitute_func == torch.nn.functional.linear
- node_kind = 'call_function'
+ node_kind = "call_function"
node_target = self.substitute_func
node_args = (input_proxy, other_proxy)
@@ -96,7 +94,7 @@ class LinearBasedBiasFunc(BiasAdditionFunc):
        This method is used to create the bias_addition_proxy; the node created by this proxy will
        compute the sum of the non_bias_func result and the bias, with a reshape operation if needed.
"""
- bias_add_node_kind = 'call_function'
+ bias_add_node_kind = "call_function"
bias_add_node_target = operator.add
bias_add_args = (non_bias_func_proxy, bias_proxy)
bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {})
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py
index e11ec0a364f1e5ee1445c97ae0c9b054d02bcfa2..ed060a350739ad31658264807f060595b0c5e217 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py
@@ -1,6 +1,3 @@
-import operator
-
-import torch
import torch.nn.functional as F
from ...registry import bias_addition_function
@@ -9,17 +6,16 @@ from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_function.register(F.linear)
class Linear(LinearBasedBiasFunc):
-
def extract_kwargs_from_origin_func(self):
- assert 'bias' in self.kwargs
+ assert "bias" in self.kwargs
kwargs = {}
- if 'bias' in self.kwargs:
- kwargs['bias'] = self.kwargs['bias']
+ if "bias" in self.kwargs:
+ kwargs["bias"] = self.kwargs["bias"]
return kwargs
def generate(self):
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[0], self.args[1])
kwargs = self.extract_kwargs_from_origin_func()
- bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs['bias'])
+ bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs["bias"])
return bias_addition_proxy
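For `F.linear` itself the split is the simplest case: one bias-free linear node plus one `operator.add` node. Numerically:

```python
import torch
import torch.nn.functional as F

# F.linear(x, w, bias=b) == F.linear(x, w) + b
x, w, b = torch.randn(2, 8), torch.randn(4, 8), torch.randn(4)
print(torch.allclose(F.linear(x, w, bias=b), F.linear(x, w) + b))  # True
```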
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py
index 85f1553e304c9c45b2b8f1373e76e7023d452d47..19c0e21d7c17637f4ad7956ce62e35f46328930d 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py
@@ -27,8 +27,8 @@ class BiasAdditionModule(ABC):
        Note: this function will be invoked during module initialization;
        you should never call it yourself.
"""
- weight_node_kind = 'get_attr'
- weight_node_target = self.target + '.weight'
+ weight_node_kind = "get_attr"
+ weight_node_target = self.target + ".weight"
weight_proxy = self.tracer.create_proxy(weight_node_kind, weight_node_target, (), {})
return weight_proxy
@@ -39,8 +39,8 @@ class BiasAdditionModule(ABC):
        Note: this function will be invoked during module initialization;
        you should never call it yourself.
"""
- bias_node_kind = 'get_attr'
- bias_node_target = self.target + '.bias'
+ bias_node_kind = "get_attr"
+ bias_node_target = self.target + ".bias"
bias_proxy = self.tracer.create_proxy(bias_node_kind, bias_node_target, (), {})
return bias_proxy
@@ -51,17 +51,16 @@ class BiasAdditionModule(ABC):
For example:
        The kwargs for the conv2d module are {} because attributes like 'padding' or 'groups' are
- considered during module initilizing. However, we need to consider those attributes as kwargs
+ considered during module initializing. However, we need to consider those attributes as kwargs
in F.conv2d.
"""
- pass
def create_non_bias_func_proxy(self, input_proxy=None):
"""
        This method is used to create the non_bias_func proxy; the node created by this proxy will
        perform the main computation, such as convolution, with the bias option disabled.
"""
- node_kind = 'call_function'
+ node_kind = "call_function"
node_target = self.substitute_func
if input_proxy is None:
input_proxy = self.args[0]
@@ -75,7 +74,7 @@ class BiasAdditionModule(ABC):
        This method is used to create the bias_addition_proxy; the node created by this proxy will
        compute the sum of the non_bias_func result and the bias, with a reshape operation if needed.
"""
- bias_add_node_kind = 'call_function'
+ bias_add_node_kind = "call_function"
bias_add_node_target = operator.add
bias_add_args = (non_bias_func_proxy, bias_proxy)
bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {})
@@ -100,7 +99,6 @@ class BiasAdditionModule(ABC):
%view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
%add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
"""
- pass
module_to_func_dict = {
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py
index 4b6c82a74f57d213ba3d8b68863053ddd1aabf9d..812a141c1eab699a780e6513ca3c625cd92f44dd 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py
@@ -1,6 +1,5 @@
import torch
-import torch.nn.functional as F
-from torch.nn.modules.utils import _pair, _reverse_repeat_tuple, _single, _triple
+from torch.nn.modules.utils import _pair, _single, _triple
from ...registry import bias_addition_module
from .bias_addition_module import BiasAdditionModule
@@ -10,17 +9,16 @@ from .bias_addition_module import BiasAdditionModule
@bias_addition_module.register(torch.nn.Conv2d)
@bias_addition_module.register(torch.nn.Conv3d)
class BiasAdditionConv(BiasAdditionModule):
-
def extract_kwargs_from_mod(self):
root = self.tracer.root
conv_module = root.get_submodule(self.target)
- kwarg_attributes = ['groups', 'dilation', 'stride']
+ kwarg_attributes = ["groups", "dilation", "stride"]
non_bias_kwargs = {}
for attr_name in kwarg_attributes:
if hasattr(conv_module, attr_name):
non_bias_kwargs[attr_name] = getattr(conv_module, attr_name)
if conv_module.padding_mode != "zeros":
- #TODO: non zeros mode requires some extra processing for input
+ # TODO: non zeros mode requires some extra processing for input
conv_type = type(conv_module)
if conv_type == "torch.nn.Conv1d":
padding_element = _single(0)
@@ -28,9 +26,9 @@ class BiasAdditionConv(BiasAdditionModule):
padding_element = _pair(0)
elif conv_type == "torch.nn.Conv3d":
padding_element = _triple(0)
- non_bias_kwargs['padding'] = padding_element
+ non_bias_kwargs["padding"] = padding_element
else:
- non_bias_kwargs['padding'] = getattr(conv_module, 'padding')
+ non_bias_kwargs["padding"] = getattr(conv_module, "padding")
return non_bias_kwargs
@@ -41,11 +39,12 @@ class BiasAdditionConv(BiasAdditionModule):
"""
bias_shape = [1] * (dimensions - 1)
bias_shape[0] = -1
- bias_reshape_node_kind = 'call_method'
- bias_reshape_node_target = 'view'
+ bias_reshape_node_kind = "call_method"
+ bias_reshape_node_target = "view"
bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape))
- bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target,
- bias_reshape_node_args, {})
+ bias_reshape_proxy = self.tracer.create_proxy(
+ bias_reshape_node_kind, bias_reshape_node_target, bias_reshape_node_args, {}
+ )
return bias_reshape_proxy
def generate(self):
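For conv modules the bias cannot be added directly: it has shape `(C_out,)` while the conv output is `(N, C_out, ...)`, so `create_bias_reshape_proxy` above inserts a `view` to `(-1, 1, 1)` (in the 2D case) before the add. The identity, checked numerically:

```python
import torch
import torch.nn.functional as F

# conv2d(x, w, bias=b) == conv2d(x, w) + b.view(-1, 1, 1)
x, w, b = torch.randn(2, 3, 8, 8), torch.randn(6, 3, 3, 3), torch.randn(6)
ref = F.conv2d(x, w, bias=b, padding=1)
dec = F.conv2d(x, w, padding=1) + b.view(-1, 1, 1)
print(torch.allclose(ref, dec, atol=1e-6))  # True
```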
diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py
index f6f7b6ddab401a637aa2b43b4dd8d2ce9193266e..b397f009846c082a8364ac0eeafa4bf085e4cf63 100644
--- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py
+++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py
@@ -1,5 +1,4 @@
import torch
-import torch.nn.functional as F
from ...registry import bias_addition_module
from .bias_addition_module import BiasAdditionModule
@@ -7,7 +6,6 @@ from .bias_addition_module import BiasAdditionModule
@bias_addition_module.register(torch.nn.Linear)
class BiasAdditionLinear(BiasAdditionModule):
-
def extract_kwargs_from_mod(self):
return {}
diff --git a/colossalai/fx/tracer/experimental.py b/colossalai/fx/tracer/experimental.py
index 88b65b6188fa67be33f7a15b55e7fd5d32d7c4cc..e6e511b72fbbc7b0e9d7efd0af25e00935a7bea1 100644
--- a/colossalai/fx/tracer/experimental.py
+++ b/colossalai/fx/tracer/experimental.py
@@ -1,4 +1,3 @@
-import enum
import functools
import inspect
import operator
@@ -10,7 +9,7 @@ from torch.fx import Graph, Node, Proxy, Tracer
from torch.utils._pytree import tree_map
from colossalai.fx import ColoGraphModule, compatibility, is_compatible_with_meta
-from colossalai.fx.tracer._tracer_utils import extract_meta, is_element_in_list
+from colossalai.fx.tracer._tracer_utils import is_element_in_list
from colossalai.fx.tracer.bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict
from colossalai.fx.tracer.registry import (
bias_addition_function,
@@ -24,31 +23,45 @@ if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
Target = Union[Callable[..., Any], str]
-Argument = Optional[Union[Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types
- List[Any], # actually Argument
- Dict[str, Any], # actually Argument
- slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
- 'Node',]]
-_CScriptMethod = ['add', 'mul', 'sub', 'div']
+Argument = Optional[
+ Union[
+ Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types
+ List[Any], # actually Argument
+ Dict[str, Any], # actually Argument
+ slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
+ "Node",
+ ]
+]
+_CScriptMethod = ["add", "mul", "sub", "div"]
_TorchNewMethod = [
- "arange", "zeros", "zeros_like", "ones", "ones_like", "full", "full_like", "empty", "empty_like", "eye", "tensor",
- "finfo"
+ "arange",
+ "zeros",
+ "zeros_like",
+ "ones",
+ "ones_like",
+ "full",
+ "full_like",
+ "empty",
+ "empty_like",
+ "eye",
+ "tensor",
+ "finfo",
]
_TensorPropertyMethod = ["dtype", "shape", "device", "requires_grad", "grad", "grad_fn", "data"]
def _truncate_suffix(s: str):
import re
- return re.sub(r'_\d+$', '', s)
+
+ return re.sub(r"_\d+$", "", s)
def default_device():
- return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+ return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
@compatibility(is_backward_compatible=False)
class ColoProxy(Proxy):
-
def __init__(self, *args, data=None, **kwargs):
super().__init__(*args, **kwargs)
self._meta_data = data
@@ -100,7 +113,7 @@ class ColoProxy(Proxy):
return ColoAttribute(self, k, getattr(self._meta_data, k, None))
def __setitem__(self, key, value):
- proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
+ proxy = self.tracer.create_proxy("call_function", operator.setitem, (self, key, value), {})
proxy.meta_data = self._meta_data
return proxy
@@ -125,29 +138,28 @@ class ColoProxy(Proxy):
@property
def device(self):
- proxy = self.tracer.create_proxy('call_function', getattr, (self, 'device'), {})
+ proxy = self.tracer.create_proxy("call_function", getattr, (self, "device"), {})
proxy.meta_data = self.meta_data.device
return proxy
@property
def dtype(self):
- proxy = self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {})
+ proxy = self.tracer.create_proxy("call_function", getattr, (self, "dtype"), {})
proxy.meta_data = self.meta_data.dtype
return proxy
def to(self, *args, **kwargs):
- return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs})
+ return self.tracer.create_proxy("call_method", "to", (self, *args), {**kwargs})
def cpu(self, *args, **kwargs):
- return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs})
+ return self.tracer.create_proxy("call_method", "cpu", (self, *args), {**kwargs})
def cuda(self, *args, **kwargs):
- return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs})
+ return self.tracer.create_proxy("call_method", "cuda", (self, *args), {**kwargs})
@compatibility(is_backward_compatible=False)
class ColoAttribute(ColoProxy):
-
def __init__(self, root, attr: str, data=None):
self.root = root
self.attr = attr
@@ -160,11 +172,11 @@ class ColoAttribute(ColoProxy):
# the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call
if self._node is None:
- self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
+ self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
return self._node
def __call__(self, *args, **kwargs):
- return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
+ return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
def __repr__(self):
return f"ColoAttribute({self.node.name}, attr={self.attr})"
@@ -172,7 +184,6 @@ class ColoAttribute(ColoProxy):
@compatibility(is_backward_compatible=False)
class ColoTracer(Tracer):
-
def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs):
super().__init__(*args, **kwargs)
self._disable_module_getattr = False
@@ -184,24 +195,28 @@ class ColoTracer(Tracer):
self.inside_torch_checkpoint_func = False
self.act_ckpt_region_count = 0
- def proxy(self, node: Node) -> 'ColoProxy':
+ def proxy(self, node: Node) -> "ColoProxy":
return ColoProxy(node, self)
- def create_proxy(self,
- kind: str,
- target: Target,
- args: Tuple[Any, ...],
- kwargs: Dict[str, Any],
- name: Optional[str] = None,
- type_expr: Optional[Any] = None,
- proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
-
+ def create_proxy(
+ self,
+ kind: str,
+ target: Target,
+ args: Tuple[Any, ...],
+ kwargs: Dict[str, Any],
+ name: Optional[str] = None,
+ type_expr: Optional[Any] = None,
+ proxy_factory_fn: Callable[[Node], "Proxy"] = None,
+ ):
proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
- if kind == 'placeholder':
- proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
- _truncate_suffix(target), None)
- elif kind == 'get_attr':
+ if kind == "placeholder":
+ proxy.meta_data = (
+ self.meta_args[target]
+ if target in self.meta_args
+ else self.concrete_args.get(_truncate_suffix(target), None)
+ )
+ elif kind == "get_attr":
self._disable_module_getattr = True
try:
attr_itr = self.root
@@ -211,20 +226,21 @@ class ColoTracer(Tracer):
proxy.meta_data = attr_itr
finally:
self._disable_module_getattr = False
- elif kind == 'call_function':
+ elif kind == "call_function":
proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
- elif kind == 'call_method':
+ elif kind == "call_method":
self._disable_module_getattr = True
try:
- if target == '__call__':
+ if target == "__call__":
proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
else:
if target not in _TensorPropertyMethod:
- proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
- **tree_map(unwrap_fn, kwargs))
+ proxy._meta_data = getattr(unwrap_fn(args[0]), target)(
+ *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)
+ )
finally:
self._disable_module_getattr = False
- elif kind == 'call_module':
+ elif kind == "call_module":
mod = self.root.get_submodule(target)
self._disable_module_getattr = True
try:
@@ -238,14 +254,15 @@ class ColoTracer(Tracer):
if self.inside_torch_checkpoint_func:
# annotate the activation checkpoint module
- node.meta['activation_checkpoint'] = self.act_ckpt_region_count
+ node.meta["activation_checkpoint"] = self.act_ckpt_region_count
return node
- def trace(self,
- root: torch.nn.Module,
- concrete_args: Optional[Dict[str, torch.Tensor]] = None,
- meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
-
+ def trace(
+ self,
+ root: torch.nn.Module,
+ concrete_args: Optional[Dict[str, torch.Tensor]] = None,
+ meta_args: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> Graph:
if meta_args is None:
meta_args = {}
@@ -260,20 +277,19 @@ class ColoTracer(Tracer):
# update concrete args with default values
non_meta_arg_names = sig_names - meta_arg_names
for k, v in sig.parameters.items():
- if k in non_meta_arg_names and \
- k not in concrete_args and \
- v.default is not inspect.Parameter.empty:
+ if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
# get non concrete arg names
concrete_arg_names = set(concrete_args.keys())
- non_concrete_arg_names = sig_names - concrete_arg_names
def _check_arg_name_valid(names):
success, element = is_element_in_list(names, sig_names)
if not success:
raise KeyError(
- f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function")
+ f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function"
+ )
_check_arg_name_valid(meta_arg_names)
_check_arg_name_valid(concrete_arg_names)
@@ -292,10 +308,9 @@ class ColoTracer(Tracer):
orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction
class PatchedCheckpointFunction(torch.autograd.Function):
-
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
- # signal that the current tracing occurs within activaton checkpoint part
+ # signal that the current tracing occurs within activation checkpoint part
self.inside_torch_checkpoint_func = True
out = run_function(*args)
self.inside_torch_checkpoint_func = False
@@ -305,7 +320,8 @@ class ColoTracer(Tracer):
@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError(
- "We do not implement the backward pass as we only trace the forward pass.")
+ "We do not implement the backward pass as we only trace the forward pass."
+ )
# override the checkpoint function
torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction
@@ -356,10 +372,13 @@ class ColoTracer(Tracer):
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
- if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
- kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
- lambda node: ColoProxy(self, node, n, attr_val))
- val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type]
+ if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
+ kwargs["proxy_factory_fn"] = (
+ None
+ if not self.param_shapes_constant
+ else lambda node: ColoProxy(self, node, n, attr_val)
+ )
+ val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
@@ -370,8 +389,9 @@ class ColoTracer(Tracer):
return maybe_buffer_proxy
if isinstance(attr_val, torch.nn.Parameter):
- maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
- parameter_proxy_cache)
+ maybe_parameter_proxy = maybe_get_proxy_for_attr(
+ attr_val, self.root.named_parameters(), parameter_proxy_cache
+ )
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
@@ -389,42 +409,41 @@ def symbolic_trace(
if meta_args is not None:
root.to(default_device())
wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x
- graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
- concrete_args=concrete_args,
- meta_args=tree_map(wrap_fn, meta_args))
+ graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(
+ root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args)
+ )
root.cpu()
else:
graph = Tracer().trace(root, concrete_args=concrete_args)
else:
from .tracer import ColoTracer as OrigColoTracer
- graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
- concrete_args=concrete_args,
- meta_args=meta_args)
+
+ graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(
+ root, concrete_args=concrete_args, meta_args=meta_args
+ )
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name)
@compatibility(is_backward_compatible=False)
class _TorchTensorOverride(object):
-
def __init__(self, tracer: Tracer):
self.overrides = {}
self.tracer = tracer
def __enter__(self):
-
def wrap_tensor_method(target):
-
@functools.wraps(target)
def wrapper(*args, **kwargs):
is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
- isinstance(p, ColoProxy) for p in kwargs.values())
+ isinstance(p, ColoProxy) for p in kwargs.values()
+ )
if is_proxy:
                    # if the arg is a proxy, we need to record that this function was called on it
# e.g. torch.ones(size) where size is an input proxy
self.tracer._disable_module_getattr = True
try:
- proxy = self.tracer.create_proxy('call_function', target, args, kwargs)
+ proxy = self.tracer.create_proxy("call_function", target, args, kwargs)
finally:
self.tracer._disable_module_getattr = False
return proxy
@@ -446,11 +465,12 @@ class _TorchTensorOverride(object):
setattr(torch, name, orig)
-def meta_prop_pass(gm: ColoGraphModule,
- root: torch.nn.Module,
- meta_args: Optional[Dict[str, Any]] = None,
- concrete_args: Optional[Dict[str, torch.Tensor]] = None):
-
+def meta_prop_pass(
+ gm: ColoGraphModule,
+ root: torch.nn.Module,
+ meta_args: Optional[Dict[str, Any]] = None,
+ concrete_args: Optional[Dict[str, torch.Tensor]] = None,
+):
if meta_args is None:
meta_args = {}
@@ -465,36 +485,36 @@ def meta_prop_pass(gm: ColoGraphModule,
# update concrete args with default values
non_meta_arg_names = sig_names - meta_arg_names
for k, v in sig.parameters.items():
- if k in non_meta_arg_names and \
- k not in concrete_args and \
- v.default is not inspect.Parameter.empty:
+ if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
for node in gm.graph.nodes:
- node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args,
- node.kwargs)
+ node._meta_data = _meta_data_computing(
+ meta_args, concrete_args, root, node.op, node.target, node.args, node.kwargs
+ )
def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
- if kind == 'placeholder':
+ if kind == "placeholder":
meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None)
- elif kind == 'get_attr':
+ elif kind == "get_attr":
attr_itr = root
atoms = target.split(".")
for atom in atoms:
attr_itr = getattr(attr_itr, atom)
meta_out = attr_itr
- elif kind == 'call_function':
+ elif kind == "call_function":
meta_out = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
- elif kind == 'call_method':
- if target == '__call__':
+ elif kind == "call_method":
+ if target == "__call__":
meta_out = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
else:
if target not in _TensorPropertyMethod:
- meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
- **tree_map(unwrap_fn, kwargs))
- elif kind == 'call_module':
+ meta_out = getattr(unwrap_fn(args[0]), target)(
+ *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)
+ )
+ elif kind == "call_module":
mod = root.get_submodule(target)
meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
else:
@@ -603,26 +623,30 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar
if kind == "call_function":
if bias_addition_function.has(target):
if target == torch.nn.functional.linear:
- if 'bias' in kwargs and kwargs['bias'] is not None:
+ if "bias" in kwargs and kwargs["bias"] is not None:
function_to_substitute = func_to_func_dict[target]
- handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
- function_to_substitute)
+ handle = bias_addition_function.get(target)(
+ tracer, target, args_proxy, kwargs_proxy, function_to_substitute
+ )
else:
function_to_substitute = func_to_func_dict[target]
- handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
- function_to_substitute)
+ handle = bias_addition_function.get(target)(
+ tracer, target, args_proxy, kwargs_proxy, function_to_substitute
+ )
elif bias_addition_function.has(target.__name__):
            # use the name for built-in ops like @ (matmul)
function_to_substitute = func_to_func_dict[target]
- handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy,
- function_to_substitute)
+ handle = bias_addition_function.get(target.__name__)(
+ tracer, target, args_proxy, kwargs_proxy, function_to_substitute
+ )
elif kind == "call_method":
method = getattr(args_metas[0].__class__, target)
if bias_addition_method.has(method):
function_to_substitute = method_to_func_dict[method]
- handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy,
- function_to_substitute)
+ handle = bias_addition_method.get(method)(
+ tracer, target, args_proxy, kwargs_proxy, function_to_substitute
+ )
elif kind == "call_module":
# if not hasattr(self, "orig_forward"):
@@ -631,8 +655,9 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar
mod_type = type(mod)
if bias_addition_module.has(mod_type) and mod.bias is not None:
function_to_substitute = module_to_func_dict[mod_type]
- handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy,
- function_to_substitute)
+ handle = bias_addition_module.get(mod_type)(
+ tracer, target, args_proxy, kwargs_proxy, function_to_substitute
+ )
if handle is not None:
handle.generate()
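Putting the experimental tracer together: `symbolic_trace` takes `meta_args` keyed by forward-argument name, wraps each tensor in `MetaTensor`, moves the root module to `default_device()` for the trace, and returns a `ColoGraphModule`. A hedged usage sketch, assuming a torch build for which `is_compatible_with_meta()` holds:

```python
import torch
from colossalai.fx.tracer.experimental import symbolic_trace

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
# "input" is the name of nn.Sequential.forward's argument; shapes propagate
# through the meta device without allocating real activations.
gm = symbolic_trace(model, meta_args={"input": torch.empty(4, 16, device="meta")})
gm.graph.print_tabular()
```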
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py
index 12c42514895e61777f0dbb1c348206af7d0ac5ec..75d7b18a067c0248f7236190f3b1e72e4cf18ede 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py
@@ -5,4 +5,4 @@ from ...registry import meta_patched_function
@meta_patched_function.register(torch.nn.functional.relu)
def torch_nn_func_relu(input, inplace=False):
- return torch.empty(input.shape, device='meta')
+ return torch.empty(input.shape, device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
index 042b92c5847a4ed0d78d4acf068a0a5030fa7089..3475f22e3b194559d29d1820f9768542e4e7e858 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
@@ -4,7 +4,7 @@ from ...registry import meta_patched_function
@meta_patched_function.register(torch.matmul)
-@meta_patched_function.register('matmul') # for built-in op @
+@meta_patched_function.register("matmul") # for built-in op @
def torch_matmul(input, other, *, out=None):
# copied from huggingface.utils.fx
d1 = input.dim()
@@ -44,8 +44,8 @@ def torch_matmul(input, other, *, out=None):
@meta_patched_function.register(torch.abs)
def torch_abs(input, *, out=None):
- assert out is None, 'out is not supported yet'
- return torch.empty(input.shape, device='meta')
+ assert out is None, "out is not supported yet"
+ return torch.empty(input.shape, device="meta")
@meta_patched_function.register(torch.bmm)
@@ -89,7 +89,7 @@ def torch_addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
@meta_patched_function.register(torch.var_mean)
def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None):
- assert out is None, 'saving to out is not supported yet'
- var = torch.empty(1).squeeze(0).to('meta')
- mean = torch.empty(1).squeeze(0).to('meta')
+ assert out is None, "saving to out is not supported yet"
+ var = torch.empty(1).squeeze(0).to("meta")
+ mean = torch.empty(1).squeeze(0).to("meta")
return var, mean
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/convolution.py b/colossalai/fx/tracer/meta_patch/patched_function/convolution.py
index 8500e5c82508195ca3ec8ebc99a33ad2a1b946ad..26daf32a2afc9d3f9367e0e599ce9d34ba7417dc 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/convolution.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/convolution.py
@@ -8,7 +8,6 @@ from ...registry import meta_patched_function
def _ntuple(n, name="parse"):
-
def parse(x):
if isinstance(x, collections.abc.Iterable):
return tuple(x)
@@ -24,21 +23,21 @@ _triple = _ntuple(3, "_triple")
def _extract_kwargs(kwargs):
- if 'stride' in kwargs:
- stride = kwargs['stride']
+ if "stride" in kwargs:
+ stride = kwargs["stride"]
else:
stride = 1
# TODO: process str type padding
- if 'padding' in kwargs:
- padding = kwargs['padding']
+ if "padding" in kwargs:
+ padding = kwargs["padding"]
else:
padding = 0
- if 'dilation' in kwargs:
- dilation = kwargs['dilation']
+ if "dilation" in kwargs:
+ dilation = kwargs["dilation"]
else:
dilation = 1
- if 'output_padding' in kwargs:
- output_padding = kwargs['output_padding']
+ if "output_padding" in kwargs:
+ output_padding = kwargs["output_padding"]
else:
output_padding = 0
@@ -61,7 +60,7 @@ def torch_nn_functional_conv1d(input, weight, **kwargs):
c_out,
l_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.nn.functional.conv2d)
@@ -82,7 +81,7 @@ def torch_nn_functional_conv2d(input, weight, **kwargs):
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.nn.functional.conv3d)
@@ -105,7 +104,7 @@ def torch_nn_functional_conv3d(input, weight, **kwargs):
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.nn.functional.conv_transpose1d)
@@ -120,13 +119,14 @@ def torch_nn_functional_convtranspose1d(input, weight, **kwargs):
kernel_size = weight.shape[2:]
l_in = input.shape[-1]
c_out = weight.shape[1]
- l_out = math.floor((l_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) +
- output_padding[0] + 1)
+ l_out = math.floor(
+ (l_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1
+ )
result_shape = input.shape[:-2] + (
c_out,
l_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.nn.functional.conv_transpose2d)
@@ -141,16 +141,18 @@ def torch_nn_functional_convtranspose2d(input, weight, **kwargs):
kernel_size = weight.shape[2:]
h_in, w_in = input.shape[-2:]
c_out = weight.shape[1]
- h_out = math.floor((h_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) +
- output_padding[0] + 1)
- w_out = math.floor((w_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) +
- output_padding[1] + 1)
+ h_out = math.floor(
+ (h_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1
+ )
+ w_out = math.floor(
+ (w_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1
+ )
result_shape = input.shape[:-3] + (
c_out,
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.nn.functional.conv_transpose3d)
@@ -165,16 +167,19 @@ def torch_nn_functional_convtranspose3d(input, weight, **kwargs):
kernel_size = weight.shape[2:]
d_in, h_in, w_in = input.shape[-3:]
c_out = weight.shape[1]
- d_out = math.floor((d_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) +
- output_padding[0] + 1)
- h_out = math.floor((h_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) +
- output_padding[1] + 1)
- w_out = math.floor((w_in - 1) * stride[2] - 2 * padding[2] + dilation[2] * (kernel_size[2] - 1) +
- output_padding[2] + 1)
+ d_out = math.floor(
+ (d_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1
+ )
+ h_out = math.floor(
+ (h_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1
+ )
+ w_out = math.floor(
+ (w_in - 1) * stride[2] - 2 * padding[2] + dilation[2] * (kernel_size[2] - 1) + output_padding[2] + 1
+ )
result_shape = input.shape[:-4] + (
c_out,
d_out,
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
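These meta patches never run the convolution; they only evaluate the output-shape formula from the PyTorch docs and return an empty tensor on `meta`. The transposed-conv rule used above, checked against eager execution:

```python
import torch
import torch.nn.functional as F

# l_out = (l_in - 1)*stride - 2*padding + dilation*(kernel - 1) + output_padding + 1
x = torch.randn(1, 4, 10)
w = torch.randn(4, 8, 3)  # (c_in, c_out, kernel_size)
y = F.conv_transpose1d(x, w, stride=2, padding=1, output_padding=1, dilation=2)
l_out = (10 - 1) * 2 - 2 * 1 + 2 * (3 - 1) + 1 + 1
print(y.shape[-1] == l_out, y.shape)  # True torch.Size([1, 8, 22])
```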
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/embedding.py b/colossalai/fx/tracer/meta_patch/patched_function/embedding.py
index 6d8d864ea29acd648f7f7097821c248655b0191e..27a79f18590a8e192909fe867b67391941bd44eb 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/embedding.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/embedding.py
@@ -4,11 +4,7 @@ from ...registry import meta_patched_function
@meta_patched_function.register(torch.nn.functional.embedding)
-def torch_nn_functional_embedding(input,
- weight,
- padding_idx=None,
- max_norm=None,
- norm_type=2.0,
- scale_grad_by_freq=False,
- sparse=False):
+def torch_nn_functional_embedding(
+ input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
+):
return torch.empty(*input.shape, weight.shape[-1], device="meta")
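The embedding patch follows the same pattern: the output shape is the index shape with `embedding_dim` appended. Checked against eager:

```python
import torch

weight = torch.randn(100, 32)
idx = torch.randint(0, 100, (4, 7))
out = torch.nn.functional.embedding(idx, weight)
print(out.shape == idx.shape + (weight.shape[-1],))  # True
```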
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/normalization.py b/colossalai/fx/tracer/meta_patch/patched_function/normalization.py
index e9e7eda6159c88ca3a82320c841360df87540c82..8a62149908306a8ca61dcc40a98bf0cf0b7a6152 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/normalization.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/normalization.py
@@ -5,16 +5,11 @@ from ...registry import meta_patched_function
@meta_patched_function.register(torch.nn.functional.layer_norm)
def torch_nn_func_layernorm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
- return torch.empty(input.shape, device='meta')
+ return torch.empty(input.shape, device="meta")
@meta_patched_function.register(torch.nn.functional.batch_norm)
-def torch_nn_func_batchnorm(input,
- running_mean,
- running_var,
- weight=None,
- bias=None,
- training=False,
- momentum=0.1,
- eps=1e-05):
- return torch.empty(input.shape, device='meta')
+def torch_nn_func_batchnorm(
+ input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05
+):
+ return torch.empty(input.shape, device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py
index 4c171cb1099119de54b70b03097e1781880f2624..7642934a409bcafb57370267969b8a384db079b0 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py
@@ -19,9 +19,9 @@ def operator_getitem(a, b):
return t
def _slice_convert(slice_obj):
- attrs = {'start': slice_obj.start, 'stop': slice_obj.stop, 'step': slice_obj.step}
+ attrs = {"start": slice_obj.start, "stop": slice_obj.stop, "step": slice_obj.step}
new_attrs = _slice_attr_convert(attrs)
- attr_dict_to_tuple = (new_attrs['start'], new_attrs['stop'], new_attrs['step'])
+ attr_dict_to_tuple = (new_attrs["start"], new_attrs["stop"], new_attrs["step"])
return slice(*attr_dict_to_tuple)
def _slice_attr_convert(attrs):
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
index b14ff10ce1377055ed7f3ab3025ee7c05c6a1657..c61e1c4dc9e1131a587a79a00eed740737c8f000 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
@@ -105,14 +105,15 @@ def torch_cat(tensors, dim=None, axis=None, *, out=None):
shapes = [t.shape for t in tensors]
shape = list(shapes[0])
concatenated_dim = sum(shape[dim] for shape in shapes)
- final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1:]
+ final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1 :]
return torch.empty(final_shape, device="meta")
@meta_patched_function.register(torch.repeat_interleave)
def torch_repeat_interleave(input, repeats, dim=None, output_size=None):
- assert isinstance(repeats, int) or isinstance(repeats, torch.Tensor), \
- "Argument 'repeats' should be of type 'torch.Tensor' or 'int'"
+ assert isinstance(repeats, int) or isinstance(
+ repeats, torch.Tensor
+ ), "Argument 'repeats' should be of type 'torch.Tensor' or 'int'"
shape = list(input.shape) if dim is not None else [input.numel()]
dim = dim if dim is not None else 0
@@ -132,36 +133,36 @@ def torch_tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None)
@meta_patched_function.register(torch.roll)
def torch_roll(input, shifts, dims=None):
- return torch.empty(input.shape, device='meta')
+ return torch.empty(input.shape, device="meta")
@meta_patched_function.register(torch.full)
def torch_full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
- assert out is None, 'assigning result to out is not supported yet'
- return torch.empty(size, device='meta', dtype=dtype, layout=layout, requires_grad=requires_grad)
+ assert out is None, "assigning result to out is not supported yet"
+ return torch.empty(size, device="meta", dtype=dtype, layout=layout, requires_grad=requires_grad)
@meta_patched_function.register(torch.max)
def torch_max(input, dim=None, keepdim=False, *, out=None):
- assert out is None, 'assigning value to out is not supported yet'
+ assert out is None, "assigning value to out is not supported yet"
if dim is not None:
if isinstance(dim, int):
shape = list(input.shape)
shape.pop(dim)
if keepdim:
shape.insert(dim, 1)
- return torch.empty(shape, device='meta', dtype=input.dtype), torch.empty(shape,
- device='meta',
- dtype=input.dtype)
+ return torch.empty(shape, device="meta", dtype=input.dtype), torch.empty(
+ shape, device="meta", dtype=input.dtype
+ )
elif isinstance(dim, torch.Tensor):
# when dim is a 0D or 1D tensor, it will maintain the same shape
num_dims = dim.dim()
if num_dims in [0, 1]:
- return torch.empty_like(input, device='meta')
+ return torch.empty_like(input, device="meta")
else:
raise ValueError(f"Expected dim to a 0D or 1D tensor but got {num_dims} dimensions")
else:
- return torch.empty([], device='meta', dtype=input.dtype)
+ return torch.empty([], device="meta", dtype=input.dtype)
@meta_patched_function.register(torch.Tensor.cpu)
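`torch_cat` above derives the result shape by summing sizes along the concatenation dim and splicing the total back into the first input's shape. The same arithmetic, checked against eager:

```python
import torch

a, b = torch.randn(2, 3, 4), torch.randn(2, 5, 4)
shape = list(a.shape)
shape[1] = a.shape[1] + b.shape[1]  # concatenated dim
print(torch.cat([a, b], dim=1).shape == torch.Size(shape))  # True
```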
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/__init__.py b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py
index e28e52585fffc193473a7c8270c103919cc63e0d..3f40ec2a67ee61963c80f15ce6a271ac70d2772b 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/__init__.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py
@@ -4,4 +4,4 @@ from .embedding import *
from .linear import *
from .normalization import *
from .pooling import *
-from .rnn import *
\ No newline at end of file
+from .rnn import *
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py
index d03da6588c1cbf56403dccc5989f4a4987b7e2a9..aa2ede187d37b832bd43324204a098c3176174a6 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py
@@ -10,4 +10,4 @@ from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.ReLU6)
@meta_patched_module.register(torch.nn.PReLU)
def torch_nn_non_linear_act(self, input):
- return torch.empty(input.shape, device='meta')
+ return torch.empty(input.shape, device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/convolution.py b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py
index cf9f3487aac9f31ff799e3245132c40d643de64f..35173a68a0be6edefef8bb45b951733e34fc655e 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/convolution.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py
@@ -11,13 +11,14 @@ def torch_nn_conv1d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d
l_in = input.shape[-1]
c_out = self.out_channels
- l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] *
- (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
+ l_out = math.floor(
+ (l_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
result_shape = input.shape[:-2] + (
c_out,
l_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.Conv2d)
@@ -26,16 +27,18 @@ def torch_nn_conv2d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv2d
h_in, w_in = input.shape[-2:]
c_out = self.out_channels
- h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] *
- (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
- w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] *
- (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
+ h_out = math.floor(
+ (h_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
+ w_out = math.floor(
+ (w_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
+ )
result_shape = input.shape[:-3] + (
c_out,
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.Conv3d)
@@ -44,19 +47,22 @@ def torch_nn_conv3d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv3d
d_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels
- d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] *
- (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
- h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] *
- (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
- w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] *
- (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1)
+ d_out = math.floor(
+ (d_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+ )
+ h_out = math.floor(
+ (h_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
+ )
+ w_out = math.floor(
+ (w_in + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1
+ )
result_shape = input.shape[:-4] + (
c_out,
d_out,
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.ConvTranspose1d)
@@ -65,13 +71,18 @@ def torch_nn_convtranspose1d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
l_in = input.shape[-1]
c_out = self.out_channels
- l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
- (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
+ l_out = math.floor(
+ (l_in - 1) * self.stride[0]
+ - 2 * self.padding[0]
+ + self.dilation[0] * (self.kernel_size[0] - 1)
+ + self.output_padding[0]
+ + 1
+ )
result_shape = input.shape[:-2] + (
c_out,
l_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.ConvTranspose2d)
@@ -80,16 +91,26 @@ def torch_nn_convtranspose2d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
h_in, w_in = input.shape[-2:]
c_out = self.out_channels
- h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
- (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
- w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
- (self.kernel_size[1] - 1) + self.output_padding[1] + 1)
+ h_out = math.floor(
+ (h_in - 1) * self.stride[0]
+ - 2 * self.padding[0]
+ + self.dilation[0] * (self.kernel_size[0] - 1)
+ + self.output_padding[0]
+ + 1
+ )
+ w_out = math.floor(
+ (w_in - 1) * self.stride[1]
+ - 2 * self.padding[1]
+ + self.dilation[1] * (self.kernel_size[1] - 1)
+ + self.output_padding[1]
+ + 1
+ )
result_shape = input.shape[:-3] + (
c_out,
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.ConvTranspose3d)
@@ -98,16 +119,31 @@ def torch_nn_convtranspose3d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
d_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels
- d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
- (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
- h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
- (self.kernel_size[1] - 1) + self.output_padding[1] + 1)
- w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] *
- (self.kernel_size[2] - 1) + self.output_padding[2] + 1)
+ d_out = math.floor(
+ (d_in - 1) * self.stride[0]
+ - 2 * self.padding[0]
+ + self.dilation[0] * (self.kernel_size[0] - 1)
+ + self.output_padding[0]
+ + 1
+ )
+ h_out = math.floor(
+ (h_in - 1) * self.stride[1]
+ - 2 * self.padding[1]
+ + self.dilation[1] * (self.kernel_size[1] - 1)
+ + self.output_padding[1]
+ + 1
+ )
+ w_out = math.floor(
+ (w_in - 1) * self.stride[2]
+ - 2 * self.padding[2]
+ + self.dilation[2] * (self.kernel_size[2] - 1)
+ + self.output_padding[2]
+ + 1
+ )
result_shape = input.shape[:-4] + (
c_out,
d_out,
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/embedding.py b/colossalai/fx/tracer/meta_patch/patched_module/embedding.py
index 999e33b17c1c7b442d2a6db73f957be4413f1fa1..f28647e9caa576cfae44438a287c57bf54efefe1 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/embedding.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/embedding.py
@@ -6,4 +6,4 @@ from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.Embedding)
def torch_nn_embedding(self, input):
result_shape = input.shape + (self.embedding_dim,)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/linear.py b/colossalai/fx/tracer/meta_patch/patched_module/linear.py
index 56f13bf97532e26770a0be7a4226ee69a2124ee5..97e6b0e96e831f7b0af3155edbae73ec04603ce0 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/linear.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/linear.py
@@ -6,5 +6,7 @@ from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.Linear)
def torch_nn_linear(self, input):
last_dim = input.shape[-1]
- assert last_dim == self.in_features, f'Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch'
+ assert (
+ last_dim == self.in_features
+ ), f"Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch"
return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py
index c21ff64cf3dec9baf357771fa0d15b341b413ac1..198e72e342b13d32610a16e2f71a410c70f395b5 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py
@@ -23,6 +23,7 @@ def torch_nn_normalize(self, input):
try:
import apex
+
meta_patched_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize)
meta_patched_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize)
meta_patched_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize)
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/pooling.py b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py
index 7ce23fbf7ac9368f4ec8496b252494e779fb5015..450586d02f8f294730412be0ba0718b0146421bf 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/pooling.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py
@@ -8,7 +8,7 @@ from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.AvgPool1d)
def torch_nn_avgpool1d(self, input):
num_dim = input.dim()
- assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions'
+ assert num_dim in [2, 3], f"expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions"
l_in = input.shape[-1]
@@ -25,13 +25,13 @@ def torch_nn_avgpool1d(self, input):
l_out = math.floor((l_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
result_shape = tuple(input.shape[:-1]) + (l_out,)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.AvgPool2d)
def torch_nn_avgpool2d(self, input):
num_dim = input.dim()
- assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions'
+ assert num_dim in [3, 4], f"expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions"
h_in, w_in = input.shape[-2:]
@@ -52,13 +52,13 @@ def torch_nn_avgpool2d(self, input):
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.AvgPool3d)
def torch_nn_avgpool3d(self, input):
num_dim = input.dim()
- assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions'
+ assert num_dim in [4, 5], f"expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions"
d_in, h_in, w_in = input.shape[-3:]
@@ -81,13 +81,13 @@ def torch_nn_avgpool3d(self, input):
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.MaxPool1d)
def torch_nn_maxpool1d(self, input):
num_dim = input.dim()
- assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions'
+ assert num_dim in [2, 3], f"expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions"
l_in = input.shape[-1]
@@ -105,13 +105,13 @@ def torch_nn_maxpool1d(self, input):
l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
result_shape = tuple(input.shape[:-1]) + (l_out,)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.MaxPool2d)
def torch_nn_maxpool2d(self, input):
num_dim = input.dim()
- assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions'
+ assert num_dim in [3, 4], f"expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions"
h_in, w_in = input.shape[-2:]
@@ -133,13 +133,13 @@ def torch_nn_maxpool2d(self, input):
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.MaxPool3d)
def torch_nn_maxpool3d(self, input):
num_dim = input.dim()
- assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions'
+ assert num_dim in [4, 5], f"expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions"
d_in, h_in, w_in = input.shape[-3:]
@@ -163,7 +163,7 @@ def torch_nn_maxpool3d(self, input):
h_out,
w_out,
)
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.AdaptiveAvgPool1d)
@@ -175,7 +175,7 @@ def torch_nn_adapative_pooling_1d(self, input):
else:
output_size = self.output_size
result_shape = tuple(input.shape[:-1]) + output_size
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.AdaptiveAvgPool2d)
@@ -187,7 +187,7 @@ def torch_nn_adapative_pooling_2d(self, input):
else:
output_size = self.output_size
result_shape = tuple(input.shape[:-2]) + output_size
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.AdaptiveAvgPool3d)
@@ -199,4 +199,4 @@ def torch_nn_adapative_pooling_3d(self, input):
else:
output_size = self.output_size
result_shape = tuple(input.shape[:-3]) + output_size
- return torch.empty(result_shape, device='meta')
+ return torch.empty(result_shape, device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_module/rnn.py b/colossalai/fx/tracer/meta_patch/patched_module/rnn.py
index ee15ca34162e83612eb179e0cff066d9f06faf36..bfb7ed17118604acef1958344635118b87bbf185 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module/rnn.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module/rnn.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
import torch
from ...registry import meta_patched_module
@@ -8,9 +6,11 @@ from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.GRU)
@meta_patched_module.register(torch.nn.RNN)
def torch_nn_rnn(self, input, hx):
- assert input.shape[
- -1] == self.input_size, f'Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch'
- assert hx.shape[
- -1] == self.hidden_size, f'Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch'
+ assert (
+ input.shape[-1] == self.input_size
+ ), f"Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch"
+ assert (
+ hx.shape[-1] == self.hidden_size
+ ), f"Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch"
d = 2 if self.bidirectional else 1
return torch.empty(input.shape[:-1] + (self.hidden_size * d,), device="meta"), hx
diff --git a/colossalai/fx/tracer/registry.py b/colossalai/fx/tracer/registry.py
index 12fc6de73d4435dea8ec58fa50b93a6070fd6254..80b3868bb4feaadbd96d448ce32870bc7e3dad75 100644
--- a/colossalai/fx/tracer/registry.py
+++ b/colossalai/fx/tracer/registry.py
@@ -1,11 +1,9 @@
class PatchRegistry:
-
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
-
def wrapper(func):
self.store[source] = func
return func
@@ -21,8 +19,8 @@ class PatchRegistry:
return source in self.store
-meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution')
-meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution')
-bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition')
-bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition')
-bias_addition_method = PatchRegistry(name='patched_method_for_bias_addition')
+meta_patched_function = PatchRegistry(name="patched_functions_for_meta_execution")
+meta_patched_module = PatchRegistry(name="patched_modules_for_meta_execution")
+bias_addition_function = PatchRegistry(name="patched_function_for_bias_addition")
+bias_addition_module = PatchRegistry(name="patched_module_for_bias_addition")
+bias_addition_method = PatchRegistry(name="patched_method_for_bias_addition")
diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py
index 1ae31f9589756d4552f6c2247e0b71679625832b..d9cb587b5d39165c8cc86231c12a399ab588a533 100644
--- a/colossalai/fx/tracer/tracer.py
+++ b/colossalai/fx/tracer/tracer.py
@@ -29,7 +29,7 @@ from .registry import (
meta_patched_module,
)
-__all__ = ['ColoTracer']
+__all__ = ["ColoTracer"]
class TracerType(enum.Enum):
@@ -92,7 +92,7 @@ class ColoTracer(Tracer):
return proxy
# if graph is traced for auto parallelism module, some extra node will be added during
- # graph construction to deal with the compatability between bias addition and all reduce.
+ # graph construction to deal with the compatibility between bias addition and all reduce.
# if no extra manipulation is applied, we just pass the origin arguments to create_proxy function
# to create node on computation graph
@@ -103,7 +103,7 @@ class ColoTracer(Tracer):
if kind == "call_function":
if bias_addition_function.has(target):
if target == torch.nn.functional.linear:
- if 'bias' in kwargs and kwargs['bias'] is not None:
+ if "bias" in kwargs and kwargs["bias"] is not None:
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
else:
@@ -160,22 +160,27 @@ class ColoTracer(Tracer):
if n not in parameter_proxy_cache:
kwargs = {}
if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
- kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else
- lambda node: ParameterProxy(self, node, n, attr_val))
- val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
+ kwargs["proxy_factory_fn"] = (
+ None
+ if not self.param_shapes_constant
+ else lambda node: ParameterProxy(self, node, n, attr_val)
+ )
+ val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
if isinstance(attr_val, torch.nn.Parameter):
- maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
- parameter_proxy_cache)
+ maybe_parameter_proxy = maybe_get_proxy_for_attr(
+ attr_val, self.root.named_parameters(), parameter_proxy_cache
+ )
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
- maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(),
- parameter_proxy_cache)
+ maybe_buffer_proxy = maybe_get_proxy_for_attr(
+ attr_val, self.root.named_buffers(), parameter_proxy_cache
+ )
if maybe_buffer_proxy is not None:
return maybe_buffer_proxy
@@ -190,7 +195,7 @@ class ColoTracer(Tracer):
# if a customized or third-party module like apex.normalization.FusedRMSNorm is patched,
# we should treat it as leaf module as well
if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name):
- return self.create_proxy('call_module', module_qualified_name, args, kwargs)
+ return self.create_proxy("call_module", module_qualified_name, args, kwargs)
else:
return forward(*args, **kwargs)
@@ -208,10 +213,9 @@ class ColoTracer(Tracer):
self.proxy_cls = ColoProxy
self.tracer_type = TracerType.META
else:
- raise ValueError(f"Unrecognised tracer type {tracer_type}")
+ raise ValueError(f"Unrecognized tracer type {tracer_type}")
def _meta_data_computing(self, kind, target, args, kwargs):
-
if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
meta_out = self.meta_args[target]
return meta_out
@@ -235,8 +239,9 @@ class ColoTracer(Tracer):
# Therefore, I need to record the nn.parameter.Parameter attribute for the operation
# added by the bias addition manipulation following the get_attr node.
convert_to_parameter = False
- if target in (torch.transpose, torch.reshape) and isinstance(args_metas[0],
- torch.nn.parameter.Parameter):
+ if target in (torch.transpose, torch.reshape) and isinstance(
+ args_metas[0], torch.nn.parameter.Parameter
+ ):
convert_to_parameter = True
# fetch patched function
if meta_patched_function.has(target):
@@ -309,10 +314,12 @@ class ColoTracer(Tracer):
return meta_out
- def trace(self,
- root: nn.Module,
- concrete_args: Optional[Dict[str, Tensor]] = None,
- meta_args: Optional[Dict[str, Tensor]] = None) -> Graph:
+ def trace(
+ self,
+ root: nn.Module,
+ concrete_args: Optional[Dict[str, Tensor]] = None,
+ meta_args: Optional[Dict[str, Tensor]] = None,
+ ) -> Graph:
"""
Trace the forward computation graph using `torch.fx.Tracer`. This tracer enables data-dependent control flow.
@@ -341,9 +348,7 @@ class ColoTracer(Tracer):
# update concrete args with default values
non_meta_arg_names = sig_names - meta_arg_names
for k, v in sig.parameters.items():
- if k in non_meta_arg_names and \
- k not in concrete_args and \
- v.default is not inspect.Parameter.empty:
+ if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default
# get non concrete arg names
@@ -354,7 +359,8 @@ class ColoTracer(Tracer):
success, element = is_element_in_list(names, sig_names)
if not success:
raise KeyError(
- f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function")
+ f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function"
+ )
_check_arg_name_valid(meta_arg_names)
_check_arg_name_valid(concrete_arg_names)
@@ -363,11 +369,13 @@ class ColoTracer(Tracer):
def _check_kwargs(kwargs, should_be_meta: bool):
for k, v in kwargs.items():
if not should_be_meta:
- assert not torch.is_tensor(v) or not v.is_meta, \
- f'Expected the {k} not to be a meta tensor, please check the args passed to the tracer'
+ assert (
+ not torch.is_tensor(v) or not v.is_meta
+ ), f"Expected the {k} not to be a meta tensor, please check the args passed to the tracer"
else:
- assert v.is_meta == should_be_meta, \
- f'Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'
+ assert (
+ v.is_meta == should_be_meta
+ ), f"Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer"
_check_kwargs(concrete_args, should_be_meta=False)
_check_kwargs(meta_args, should_be_meta=True)
@@ -442,10 +450,9 @@ class ColoTracer(Tracer):
orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction
class PatchedCheckpointFunction(torch.autograd.Function):
-
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
- # signal that the current tracing occurs within activaton checkpoint part
+ # signal that the current tracing occurs within activation checkpoint part
self.inside_torch_checkpoint_func = True
out = run_function(*args)
self.inside_torch_checkpoint_func = False
@@ -455,7 +462,8 @@ class ColoTracer(Tracer):
@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError(
- "We do not implement the backward pass as we only trace the forward pass.")
+ "We do not implement the backward pass as we only trace the forward pass."
+ )
# override the checkpoint function
torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction
@@ -470,12 +478,11 @@ class ColoTracer(Tracer):
if self.inside_torch_checkpoint_func:
# annotate the activation checkpoint module
- node.meta['activation_checkpoint'] = self.act_ckpt_region_count
+ node.meta["activation_checkpoint"] = self.act_ckpt_region_count
return node
def wrap_tensor_constructor_method(target):
-
def look_for_proxy(*args, **kwargs):
# find in pos vars
for arg in args:
@@ -518,12 +525,10 @@ def wrap_tensor_constructor_method(target):
for method in magic_methods:
def _scope(method):
-
def impl(*args, **kwargs):
-
tracer = args[0].tracer
target = getattr(operator, method)
- proxy = tracer.create_proxy('call_function', target, args, kwargs)
+ proxy = tracer.create_proxy("call_function", target, args, kwargs)
if not isinstance(proxy, ColoProxy):
meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
proxy = ColoProxy(proxy.node)
@@ -542,7 +547,7 @@ def _define_reflectable(orig_method_name):
def impl(self, rhs):
target = getattr(operator, orig_method_name)
- proxy = self.tracer.create_proxy('call_function', target, (rhs, self), {})
+ proxy = self.tracer.create_proxy("call_function", target, (rhs, self), {})
if not isinstance(proxy, ColoProxy):
meta_out = compute_meta_data_for_functions_proxy(target, *(rhs, self), {})
proxy = ColoProxy(proxy.node)
diff --git a/colossalai/global_variables.py b/colossalai/global_variables.py
deleted file mode 100644
index 61b31965e2e63d2119bfadba0d49478537c31fa7..0000000000000000000000000000000000000000
--- a/colossalai/global_variables.py
+++ /dev/null
@@ -1,56 +0,0 @@
-from typing import Optional
-
-
-class TensorParallelEnv(object):
- _instance = None
-
- def __new__(cls, *args, **kwargs):
- if cls._instance is None:
- cls._instance = object.__new__(cls, *args, **kwargs)
- return cls._instance
-
- def __init__(self, *args, **kwargs):
- self.load(*args, **kwargs)
-
- def load(self,
- mode: Optional[str] = None,
- vocab_parallel: bool = False,
- parallel_input_1d: bool = False,
- summa_dim: int = None,
- tesseract_dim: int = None,
- tesseract_dep: int = None,
- depth_3d: int = None,
- input_group_3d=None,
- weight_group_3d=None,
- output_group_3d=None,
- input_x_weight_group_3d=None,
- output_x_weight_group_3d=None):
- self.mode = mode
- self.vocab_parallel = vocab_parallel
- self.parallel_input_1d = parallel_input_1d
- self.summa_dim = summa_dim
- self.tesseract_dim = tesseract_dim
- self.tesseract_dep = tesseract_dep
- self.depth_3d = depth_3d
- self.input_group_3d = input_group_3d
- self.weight_group_3d = weight_group_3d
- self.output_group_3d = output_group_3d
- self.input_x_weight_group_3d = input_x_weight_group_3d
- self.output_x_weight_group_3d = output_x_weight_group_3d
-
- def save(self):
- return dict(mode=self.mode,
- vocab_parallel=self.vocab_parallel,
- parallel_input_1d=self.parallel_input_1d,
- summa_dim=self.summa_dim,
- tesseract_dim=self.tesseract_dim,
- tesseract_dep=self.tesseract_dep,
- depth_3d=self.depth_3d,
- input_group_3d=self.input_group_3d,
- weight_group_3d=self.weight_group_3d,
- output_group_3d=self.output_group_3d,
- input_x_weight_group_3d=self.input_x_weight_group_3d,
- output_x_weight_group_3d=self.output_x_weight_group_3d)
-
-
-tensor_parallel_env = TensorParallelEnv()
diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9a965dc982a4c2c24eea5eacc86e85999f205d7a
--- /dev/null
+++ b/colossalai/inference/README.md
@@ -0,0 +1,117 @@
+# 🚀 Colossal-Inference
+
+## Table of contents
+
+## Introduction
+
+`Colossal Inference` is the inference framework designed by Colossal-AI, featuring high performance, stability, and ease of use. It incorporates the strengths of the latest open-source inference systems, including TGI, vLLM, FasterTransformer, LightLLM, and FlashAttention, while building on the design of Colossal-AI, especially Shardformer, to reduce the learning curve for users.
+
+## Design
+
+Colossal Inference is composed of three main components:
+
+1. High-performance kernels and ops, inspired by existing libraries and modified accordingly.
+2. An efficient memory management mechanism, which includes the key-value cache manager, allowing for zero memory waste during inference.
+    1. `cache manager`: serves as a memory manager for the key-value cache, integrating functions such as memory allocation, indexing, and release.
+    2. `batch_infer_state`: holds all essential elements of a batch inference and is updated every batch.
+3. A high-level inference engine combined with `Shardformer`, which allows our inference framework to easily invoke and utilize various parallel methods (see the sketch after this list).
+    1. `engine.TPInferEngine`: a high-level interface that integrates with Shardformer, especially for multi-card (tensor parallel) inference.
+    2. `modeling.llama.LlamaInferenceForwards`: contains the `forward` methods for llama inference.
+    3. `policies.llama.LlamaModelInferPolicy`: contains the policies for `llama` models, used to call `shardformer` and shard the model forward pass for tensor parallelism.
+
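+The following is a minimal, hedged sketch of how these components fit together (the `model` and `input_ids` objects are assumed to be prepared by the user; exact APIs may evolve):
+
+```python
+# a minimal sketch, assuming `model` is a loaded HuggingFace causal LM
+# and `input_ids` is a batch of tokenized prompts
+from colossalai.inference.tensor_parallel import TPInferEngine
+from colossalai.shardformer import ShardConfig
+
+# shard the model for tensor-parallel inference via Shardformer
+shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
+
+# the engine shards the model and sizes the KV cache (MemoryManager) as
+# max_batch_size * (max_input_len + max_output_len) tokens
+engine = TPInferEngine(model, shard_config, 8, 1024, 128)
+
+outputs = engine.generate(input_ids, max_new_tokens=128)
+```
+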
+## Pipeline of inference
+
+In this section we discuss how Colossal Inference works and integrates with `Shardformer`. The details can be found in our code.
+
+## Roadmap of our implementation
+
+- [x] Design cache manager and batch infer state
+- [x] Design TpInference engine to integrate with `Shardformer`
+- [x] Register corresponding high-performance `kernel` and `ops`
+- [x] Design policies and forwards (e.g. `Llama` and `Bloom`)
+ - [x] policy
+ - [x] context forward
+ - [x] token forward
+- [ ] Replace the kernels with `faster-transformer` in the token-forward stage
+- [ ] Support all models
+ - [x] Llama
+ - [x] Bloom
+ - [ ] Chatglm2
+- [ ] Benchmarking for all models
+
+## Get started
+
+### Installation
+
+```bash
+pip install -e .
+```
+
+### Requirements
+
+Dependencies:
+
+```bash
+pytorch == 1.13.1 (gpu)
+cuda >= 11.6
+transformers == 4.30.2
+triton == 2.0.0.dev20221202
+# to install vllm, please use this branch: https://github.com/tiandiao123/vllm/tree/setup_branch
+vllm
+# to install flash-attention, please use commit hash 67ae6fd74b4bc99c36b2ce524cf139c35663793c
+flash-attention
+```
+
+### Docker
+
+You can use `docker run` to set up the environment in a container:
+
+```bash
+# env: python==3.8, cuda 11.6, pytorch==1.13.1, triton==2.0.0.dev20221202, vllm kernels support, flash-attention-2 kernels support
+docker pull hpcaitech/colossalai-inference:v2
+docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash
+```
+
+### Dive into fast-inference!
+
+Example scripts can be found in the examples directory:
+
+```bash
+cd colossalai.examples
+python xx
+```
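+
+As a hedged illustration (the model id below is a placeholder), `generate` accepts a `BatchEncoding`/dict, a list of token ids, or a `torch.Tensor`:
+
+```python
+# a minimal sketch, assuming `engine` is a TPInferEngine built as in the Design section
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # placeholder model id
+inputs = tokenizer(["Introduce some landmarks in Beijing"], return_tensors="pt")
+
+outputs = engine.generate(inputs, max_new_tokens=64, do_sample=False)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```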
+
+## Performance
+
+### Environment
+
+We conducted multiple benchmark tests to evaluate performance, comparing the inference `latency` and `throughput` of `colossal-inference` against the original `hugging-face torch fp16` baseline.
+
+For each model, experiments were conducted with multiple batch sizes under a consistent configuration of `7 billion (7b)` parameters, `1024` input length, and `128` output length. The results are as follows (due to time constraints, the evaluation has so far been performed only on a single `A100` GPU; multi-GPU performance will be addressed in the future):
+
+### Single GPU Performance
+
+The stats below are measured on a single A100 GPU. Token latency is computed as the average over the context-forward and decoding-forward processes, i.e. both stages are combined when calculating token generation time. We are actively developing new features and methods to further optimize the performance of LLM models. Please stay tuned.
+
+#### Llama
+
+| batch_size | 8 | 16 | 32 |
+| :---------------------: | :----: | :----: | :----: |
+| hugging-face torch fp16 | 199.12 | 246.56 | 278.4 |
+| colossal-inference | 326.4 | 582.72 | 816.64 |
+
+
+
+#### Bloom
+
+| batch_size | 8 | 16 | 32 |
+| :---------------------: | :----: | :----: | :----: |
+| hugging-face torch fp16 | 189.68 | 226.66 | 249.61 |
+| colossal-inference | 323.28 | 538.52 | 611.64 |
+
+
+
+The results of more models are coming soon!
diff --git a/tests/test_layers/test_2p5d/checks_2p5d/__init__.py b/colossalai/inference/__init__.py
similarity index 100%
rename from tests/test_layers/test_2p5d/checks_2p5d/__init__.py
rename to colossalai/inference/__init__.py
diff --git a/colossalai/inference/quant/gptq/__init__.py b/colossalai/inference/quant/gptq/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c035f397923a77d3becfc60d7c554456e545554a
--- /dev/null
+++ b/colossalai/inference/quant/gptq/__init__.py
@@ -0,0 +1,4 @@
+from .cai_gptq import HAS_AUTO_GPTQ
+
+if HAS_AUTO_GPTQ:
+ from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear
diff --git a/colossalai/inference/quant/gptq/cai_gptq/__init__.py b/colossalai/inference/quant/gptq/cai_gptq/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ed76293bd81ccf1e44ac5044e96335d3efa5dbd
--- /dev/null
+++ b/colossalai/inference/quant/gptq/cai_gptq/__init__.py
@@ -0,0 +1,14 @@
+import warnings
+
+HAS_AUTO_GPTQ = False
+try:
+ import auto_gptq
+
+ HAS_AUTO_GPTQ = True
+except ImportError:
+ warnings.warn("please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ")
+ HAS_AUTO_GPTQ = False
+
+if HAS_AUTO_GPTQ:
+ from .cai_quant_linear import CaiQuantLinear, ColCaiQuantLinear, RowCaiQuantLinear
+ from .gptq_op import CaiGPTQLinearOp
diff --git a/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py b/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca12c34ed958506c769a560d020c3f077f43c935
--- /dev/null
+++ b/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py
@@ -0,0 +1,354 @@
+# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ
+
+import math
+import warnings
+from typing import List, Union
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed import ProcessGroup
+
+from colossalai.lazy import LazyInitContext
+from colossalai.shardformer.layer import ParallelModule
+
+from .gptq_op import CaiGPTQLinearOp
+
+HAS_GPTQ_CUDA = False
+try:
+ from colossalai.kernel.op_builder.gptq import GPTQBuilder
+ gptq_cuda = GPTQBuilder().load()
+ HAS_GPTQ_CUDA = True
+except ImportError:
+ warnings.warn('CUDA gptq is not installed')
+ HAS_GPTQ_CUDA = False
+
+
+class CaiQuantLinear(nn.Module):
+
+ def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
+ super().__init__()
+ if bits not in [2, 4, 8]:
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
+ self.infeatures = infeatures
+ self.outfeatures = outfeatures
+ self.bits = bits
+ self.maxq = 2**self.bits - 1
+ self.groupsize = groupsize if groupsize != -1 else infeatures
+
+ self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
+ self.register_buffer(
+ 'qzeros',
+ torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
+ self.register_buffer('scales',
+ torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
+ if row_split:
+ self.register_buffer(
+ 'g_idx',
+ torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)],
+ dtype=torch.int32))
+ else:
+ self.register_buffer('g_idx',
+ torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
+
+ if bias:
+ self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
+ else:
+ self.bias = None
+
+ self.gptq_linear = CaiGPTQLinearOp(groupsize, bits)
+
+ self.q4 = None
+ self.empty_tensor = torch.empty((1, 1), device="meta")
+ self.tp_size = tp_size
+ self.tp_rank = tp_rank
+ self.row_split = row_split
+
+ def pack(self, linear, scales, zeros, g_idx=None):
+
+ g_idx = g_idx.clone() if g_idx is not None else torch.tensor(
+ [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
+
+ scales = scales.t().contiguous()
+ zeros = zeros.t().contiguous()
+ scale_zeros = zeros * scales
+ half_scales = scales.clone().half()
+ # print("scale shape ", scales.shape, scale_zeros.shape, linear.weight.shape)
+ self.scales = scales.clone().half()
+ if linear.bias is not None:
+ self.bias = linear.bias.clone().half()
+
+ wn = 8
+ pbits = 32
+ ptype = torch.int32
+ unsign_type = np.uint32
+ sign_type = np.int32
+
+ intweight = []
+ for idx in range(self.infeatures):
+ intweight.append(
+ torch.round(
+ (linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:,
+ None])
+ intweight = torch.cat(intweight, dim=1)
+ intweight = intweight.t().contiguous()
+ intweight = intweight.numpy().astype(unsign_type)
+ qweight = np.zeros((intweight.shape[0] // pbits * self.bits, intweight.shape[1]), dtype=unsign_type)
+
+ i = 0
+ row = 0
+
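+        # pack 32 // bits quantized weight values into each 32-bit word of qweight, one packed row at a time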
+ while row < qweight.shape[0]:
+ if self.bits in [2, 4, 8]:
+ for j in range(i, i + (pbits // self.bits)):
+ qweight[row] |= intweight[j] << (self.bits * (j - i))
+ i += pbits // self.bits
+ row += 1
+ else:
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
+ qweight = qweight.astype(sign_type)
+ qweight1 = torch.from_numpy(qweight)
+ qweight1 = qweight1.contiguous() #.to("cuda")
+ self.qweight.data.copy_(qweight1)
+
+ qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type)
+ zeros -= 1
+ zeros = zeros.numpy().astype(unsign_type)
+ i = 0
+ col = 0
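+        # pack the zero-points the same way: 32 // bits values per 32-bit word, one column block at a time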
+ while col < qzeros.shape[1]:
+ if self.bits in [2, 4, 8]:
+ for j in range(i, i + (pbits // self.bits)):
+ qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+ i += pbits // self.bits
+ col += 1
+ else:
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
+ qzeros = qzeros.astype(sign_type)
+ qzeros = torch.from_numpy(qzeros)
+ qzeros = qzeros
+ self.qzeros.data.copy_(qzeros)
+
+ if torch.equal(self.g_idx.to(g_idx.device), g_idx):
+ self.g_idx = None
+ else:
+ self.g_idx = g_idx
+
+ def init_q4(self):
+ assert self.qweight.device.type == "cuda"
+ self.q4_width = self.qweight.shape[1]
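+        # drop g_idx when it matches the default group mapping; make_q4 then receives an empty tensor instead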
+ if self.g_idx is not None:
+ if self.row_split and torch.equal(
+ self.g_idx,
+ torch.tensor(
+ [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
+ dtype=torch.int32,
+ device=self.g_idx.device)):
+ self.g_idx = None
+ elif torch.equal(
+ self.g_idx,
+ torch.tensor([i // self.groupsize for i in range(self.infeatures)],
+ dtype=torch.int32,
+ device=self.g_idx.device)):
+ self.g_idx = None
+
+ if self.g_idx is not None:
+ g_idx = self.g_idx.to("cpu")
+ else:
+ g_idx = self.empty_tensor
+
+ self.q4 = gptq_cuda.make_q4(self.qweight, self.qzeros, self.scales, g_idx, torch.cuda.current_device())
+ torch.cuda.synchronize()
+
+ def forward(self, x):
+ outshape = x.shape[:-1] + (self.outfeatures,)
+
+ if HAS_GPTQ_CUDA and self.bits == 4:
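+            # fast path: use the CUDA q4 kernel; the q4 handle is built lazily on first forward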
+
+ if self.q4 is None:
+ self.init_q4()
+
+ x = x.view(-1, x.shape[-1])
+ output = torch.empty((x.shape[0], self.outfeatures), dtype=torch.float16, device=x.device)
+ gptq_cuda.q4_matmul(x.half(), self.q4, output)
+ if self.bias is not None and (not self.row_split or self.tp_size == 1):
+ output.add_(self.bias)
+ else:
+ if self.bias is not None and (not self.row_split or self.tp_size == 1):
+ bias = self.bias
+ else:
+ bias = None
+ output = self.gptq_linear(
+ x,
+ self.qweight,
+ self.scales,
+ self.qzeros,
+ g_idx=self.g_idx,
+ bias=bias,
+ )
+ return output.view(outshape)
+
+
+def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):
+
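+    # column-parallel split: each rank keeps its tp_rank-th slice of every output-dimension block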
+ qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1)
+ qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1)
+ scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1)
+ g_idx = gptq_linear.g_idx
+ if gptq_linear.bias is not None:
+ bias = gptq_linear.bias.split(gptq_linear.out_features // split_num, dim=-1)
+
+ cai_split_out_features = cai_linear.outfeatures // split_num
+ zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num
+
+ for i in range(split_num):
+ cai_linear.qweight[:, i * cai_split_out_features:(i + 1) *
+ cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
+ cai_split_out_features]
+ cai_linear.qzeros[:, i * zero_split_block:(i + 1) *
+ zero_split_block] = qzeros[i][:, tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block]
+ cai_linear.scales[:, i * cai_split_out_features:(i + 1) *
+ cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
+ cai_split_out_features]
+ if cai_linear.bias is not None:
+ cai_linear.bias[i * cai_split_out_features:(i + 1) *
+ cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) *
+ cai_split_out_features]
+
+ cai_linear.g_idx.copy_(g_idx)
+
+
+def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
+
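+    # row-parallel split: each rank keeps its tp_rank-th slice of every input-dimension block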
+ qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0)
+ qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0)
+ scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0)
+ g_idxs = gptq_linear.g_idx.split(gptq_linear.in_features // split_num, dim=0)
+
+ cai_split_in_features = cai_linear.infeatures // (32 // cai_linear.bits) // split_num
+ zero_split_block = cai_linear.infeatures // cai_linear.groupsize // split_num
+ idx_split_features = cai_linear.infeatures // split_num
+
+ for i in range(split_num):
+ cai_linear.qweight[i * cai_split_in_features:(i + 1) *
+ cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) *
+ cai_split_in_features, :]
+ cai_linear.qzeros[i * zero_split_block:(i + 1) *
+ zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) *
+ zero_split_block, :]
+ cai_linear.scales[i * zero_split_block:(i + 1) *
+ zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) *
+ zero_split_block, :]
+ cai_linear.g_idx[i * idx_split_features:(i + 1) *
+ idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) *
+ idx_split_features]
+ if cai_linear.bias is not None:
+ cai_linear.bias.copy_(gptq_linear.bias)
+
+
+class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
+
+ def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
+
+ super().__init__(bits,
+ groupsize,
+ infeatures,
+ outfeatures,
+ bias,
+ tp_size=tp_size,
+ tp_rank=tp_rank,
+ row_split=row_split)
+ self.process_group = None
+
+ @staticmethod
+ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
+ **kwargs) -> ParallelModule:
+ LazyInitContext.materialize(module)
+ # get the attributes
+ in_features = module.in_features
+
+ # ensure only one process group is passed
+ if isinstance(process_group, (list, tuple)):
+ assert len(process_group) == 1, \
+ f'Expected only one process group, got {len(process_group)}.'
+ process_group = process_group[0]
+
+ tp_size = dist.get_world_size(process_group)
+ tp_rank = dist.get_rank(process_group)
+
+ if in_features < tp_size:
+ return module
+
+ if in_features % tp_size != 0:
+ raise ValueError(
+ f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
+ linear_1d = RowCaiQuantLinear(module.bits,
+ module.group_size,
+ module.in_features // tp_size,
+ module.out_features,
+ module.bias is not None,
+ tp_size=tp_size,
+ tp_rank=tp_rank,
+ row_split=True)
+ linear_1d.process_group = process_group
+
+ split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+ return linear_1d
+
+ def forward(self, x):
+ output = super().forward(x)
+ if self.tp_size > 1:
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
+ if self.bias is not None:
+ output.add_(self.bias)
+ return output
+
+
+class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
+
+ def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
+
+ super().__init__(bits,
+ groupsize,
+ infeatures,
+ outfeatures,
+ bias,
+ tp_size=tp_size,
+ tp_rank=tp_rank,
+ row_split=row_split)
+ self.process_group = None
+
+ @staticmethod
+ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
+ **kwargs) -> ParallelModule:
+ LazyInitContext.materialize(module)
+ # get the attributes
+ in_features = module.in_features
+
+ # ensure only one process group is passed
+ if isinstance(process_group, (list, tuple)):
+ assert len(process_group) == 1, \
+ f'Expected only one process group, got {len(process_group)}.'
+ process_group = process_group[0]
+
+ tp_size = dist.get_world_size(process_group)
+ tp_rank = dist.get_rank(process_group)
+
+ if in_features < tp_size:
+ return module
+
+ if in_features % tp_size != 0:
+ raise ValueError(
+ f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
+ linear_1d = ColCaiQuantLinear(module.bits,
+ module.group_size,
+ module.in_features,
+ module.out_features // tp_size,
+ module.bias is not None,
+ tp_size=tp_size,
+ tp_rank=tp_rank)
+ linear_1d.process_group = process_group
+
+ split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+ return linear_1d
diff --git a/colossalai/inference/quant/gptq/cai_gptq/gptq_op.py b/colossalai/inference/quant/gptq/cai_gptq/gptq_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8902eb35cd0737fcda3451d54e6f55a3e313b4b
--- /dev/null
+++ b/colossalai/inference/quant/gptq/cai_gptq/gptq_op.py
@@ -0,0 +1,58 @@
+import torch
+
+from colossalai.kernel.triton import gptq_fused_linear_triton
+
+
+class CaiGPTQLinearOp(torch.nn.Module):
+ def __init__(self, gptq_group_size, gptq_quant_bits):
+ super(CaiGPTQLinearOp, self).__init__()
+ self.group_size = gptq_group_size
+ self.bits = gptq_quant_bits
+ self.maxq = 2**self.bits - 1
+ self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device())
+
+ def forward(
+ self,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ weight_scales: torch.Tensor,
+ weight_zeros: torch.Tensor,
+ g_idx: torch.Tensor = None,
+ act_type=0,
+ bias: torch.Tensor = None,
+ residual: torch.Tensor = None,
+ qkv_fused=False,
+ ):
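+        # when bias/residual are absent, pass a dummy tensor and disable the corresponding fused add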
+ add_bias = True
+ if bias is None:
+ bias = self.empty_tensor
+ add_bias = False
+
+ add_residual = True
+ if residual is None:
+ residual = self.empty_tensor
+ add_residual = False
+ x = input.view(-1, input.shape[-1])
+
+ out = gptq_fused_linear_triton(
+ x,
+ weight,
+ weight_scales,
+ weight_zeros,
+ bias,
+ residual,
+ self.bits,
+ self.maxq,
+ self.group_size,
+ qkv_fused,
+ add_bias,
+ add_residual,
+ act_type=act_type,
+ g_idx=g_idx,
+ )
+ if qkv_fused:
+ out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1])
+ else:
+ out = out.view(input.shape[0], input.shape[1], weight.shape[-1])
+
+ return out
diff --git a/colossalai/inference/tensor_parallel/__init__.py b/colossalai/inference/tensor_parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..112b920ba158144a18a94b5664df7063555a701c
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/__init__.py
@@ -0,0 +1,4 @@
+from .engine import TPInferEngine
+from .kvcache_manager import MemoryManager
+
+__all__ = ["MemoryManager", "TPInferEngine"]
diff --git a/colossalai/inference/tensor_parallel/batch_infer_state.py b/colossalai/inference/tensor_parallel/batch_infer_state.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac185f1b6529376a77901327cf85e2008a95932a
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/batch_infer_state.py
@@ -0,0 +1,56 @@
+# might want to consider combining with InferenceConfig in colossalai/ppinference/inference_config.py later
+from dataclasses import dataclass
+
+import torch
+
+from .kvcache_manager import MemoryManager
+
+
+@dataclass
+class BatchInferState:
+ r"""
+ Information to be passed and used for a batch of inputs during
+ a single model forward
+ """
+ batch_size: int
+ max_len_in_batch: int
+
+ cache_manager: MemoryManager = None
+
+ block_loc: torch.Tensor = None
+ start_loc: torch.Tensor = None
+ seq_len: torch.Tensor = None
+ past_key_values_len: int = None
+
+ is_context_stage: bool = False
+ context_mem_index: torch.Tensor = None
+ decode_is_contiguous: bool = None
+ decode_mem_start: int = None
+ decode_mem_end: int = None
+ decode_mem_index: torch.Tensor = None
+ decode_layer_id: int = None
+
+ device: torch.device = torch.device("cuda")
+
+ @property
+ def total_token_num(self):
+ # return self.batch_size * self.max_len_in_batch
+ assert self.seq_len is not None and self.seq_len.size(0) > 0
+ return int(torch.sum(self.seq_len))
+
+ def set_cache_manager(self, manager: MemoryManager):
+ self.cache_manager = manager
+
+ @staticmethod
+ def init_block_loc(
+ b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor
+ ):
+        """in-place update block loc mapping based on the sequence length of the inputs in current batch"""
+ start_index = 0
+ seq_len_numpy = seq_len.cpu().numpy()
+ for i, cur_seq_len in enumerate(seq_len_numpy):
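+            # right-align this sequence's allocated cache indices within the max_len_in_batch window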
+ b_loc[i, max_len_in_batch - cur_seq_len : max_len_in_batch] = alloc_mem_index[
+ start_index : start_index + cur_seq_len
+ ]
+ start_index += cur_seq_len
+ return
diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5ef37fee420927848cd4034257b619f510c2ecf
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/engine.py
@@ -0,0 +1,380 @@
+import warnings
+from typing import Any, Callable, List, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from transformers import BloomForCausalLM, LlamaForCausalLM
+from transformers.generation import GenerationConfig
+from transformers.generation.stopping_criteria import StoppingCriteriaList
+from transformers.tokenization_utils_base import BatchEncoding
+
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.auto_policy import get_autopolicy
+
+from .batch_infer_state import BatchInferState
+from .kvcache_manager import MemoryManager
+
+DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
+
+_supported_models = [
+ "LlamaForCausalLM",
+ "LlamaModel",
+ "BloomForCausalLM",
+ "ChatGLMModel",
+ "ChatGLMForConditionalGeneration",
+]
+
+
+class TPInferEngine:
+ """Engine class for tensor parallel inference.
+
+ Args:
+ model (Module): original model, e.g. huggingface CausalLM
+ shard_config (ShardConfig): The config for sharding original model
+ max_batch_size (int): maximum batch size
+ max_input_len (int): maximum input length of sequence
+ max_output_len (int): maximum output length of output tokens
+ dtype (torch.dtype): datatype used to init KV cache space
+ device (str): device the KV cache of engine to be initialized on
+
+ Examples:
+ >>> # define model and shard config for your inference
+ >>> model = ...
+ >>> generate_kwargs = ...
+ >>> shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
+ >>> infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+ >>> outputs = infer_engine.generate(input_ids, **generate_kwargs)
+ """
+
+ def __init__(
+ self,
+ model: nn.Module,
+ shard_config: ShardConfig,
+ max_batch_size: int,
+ max_input_len: int,
+ max_output_len: int,
+ dtype: torch.dtype = torch.float16,
+ device: str = "cuda",
+ ) -> None:
+ self.max_batch_size = max_batch_size
+ self.max_input_len = max_input_len
+ self.max_output_len = max_output_len
+ self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len)
+
+        # Constraints related to the specs of the devices and the model
+ # This may change into an optional arg in the future
+ assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
+ assert self.max_input_len + self.max_output_len <= 4096, "Max length exceeds the constraint"
+
+ self.dtype = dtype
+
+ self.head_dim = model.config.hidden_size // model.config.num_attention_heads
+ self.head_num = model.config.num_attention_heads
+ num_hidden_layers = (
+ model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
+ )
+ self.layer_num = num_hidden_layers
+ self.multi_query_group_num = (
+ model.config.multi_query_group_num if hasattr(model.config, "multi_query_group_num") else 0
+ )
+
+ self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
+ self.cache_manager = None
+
+ self.max_dq_buffer_size = 1
+ self.max_inner_outer_dim = 1
+ self.gptq_temp_state_buffer = None
+ self.gptq_temp_dq_buffer = None
+ self.bits = -1
+ self.use_act_order = False
+
+ self.shard_config = shard_config
+ self.model = None
+ # optimize the original model by sharding with ShardFormer
+ self._optimize_model(model=model.to(device))
+
+ def _init_manager(self) -> None:
+        assert self.tp_size >= 1, "TP size must be initialized by providing a valid ShardConfig"
+ assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
+ self.head_num //= self.tp_size # update sharded number of heads
+ if self.multi_query_group_num:
+ # NOTE the logic of MQA tensor parallelism should be specified.
+ assert (
+ self.multi_query_group_num % self.tp_size == 0
+ ), f"Cannot shard {self.multi_query_group_num} query groups with tp size {self.tp_size}"
+ self.cache_manager = MemoryManager(
+ self.max_total_token_num,
+ self.dtype,
+ self.multi_query_group_num // self.tp_size,
+ self.head_dim,
+ self.layer_num,
+ )
+ else:
+ self.cache_manager = MemoryManager(
+ self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num
+ )
+
+ def _post_init_gptq_buffer(self, model: nn.Module) -> None:
+ from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear
+ HAS_GPTQ_CUDA = False
+ try:
+ from colossalai.kernel.op_builder.gptq import GPTQBuilder
+ gptq_cuda = GPTQBuilder().load()
+ HAS_GPTQ_CUDA = True
+ except ImportError:
+ warnings.warn('CUDA gptq is not installed')
+ HAS_GPTQ_CUDA = False
+
+ for name, submodule in model.named_modules():
+ if isinstance(submodule, CaiQuantLinear):
+ self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)
+
+ if self.use_act_order:
+ self.max_inner_outer_dim = max(self.max_inner_outer_dim, submodule.infeatures,
+ submodule.outfeatures)
+ self.bits = submodule.bits
+ if not (HAS_GPTQ_CUDA and self.bits == 4):
+ return
+
+ max_input_len = 1
+ if self.use_act_order:
+ max_input_len = self.max_input_len
+ # The temp_state buffer is required to reorder X in the act-order case.
+ # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
+ self.gptq_temp_state_buffer = torch.zeros((max_input_len, self.max_inner_outer_dim),
+ dtype=torch.float16,
+ device=torch.cuda.current_device())
+ self.gptq_temp_dq_buffer = torch.zeros((1, self.max_dq_buffer_size),
+ dtype=torch.float16,
+ device=torch.cuda.current_device())
+
+ gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer,
+ self.gptq_temp_dq_buffer)
+ # Using the default from exllama repo here.
+ matmul_recons_thd = 8
+ matmul_fused_remap = False
+ matmul_no_half2 = False
+ gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
+
+ torch.cuda.empty_cache()
+
+ def _optimize_model(self, model: nn.Module) -> None:
+ """
+ Optimize the original model by sharding with ShardFormer.
+        In further generation, use the sharded model instead of the original model.
+ """
+ # NOTE we will change to use an inference config later with additional attrs we want
+ assert self.shard_config.inference_only is True
+ shardformer = ShardFormer(shard_config=self.shard_config)
+ self._prepare_with_shard_config(shard_config=self.shard_config)
+ self._shard_model_by(shardformer, model)
+
+ def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig:
+ """Prepare the engine with a given ShardConfig.
+
+ Args:
+ shard_config (ShardConfig): shard config given to specify settings of the engine.
+ If not provided, a default ShardConfig with tp size 1 will be created.
+ """
+ self.tp_size = 1
+ if shard_config is None:
+ shard_config = ShardConfig(
+ tensor_parallel_process_group=None,
+ pipeline_stage_manager=None,
+ enable_tensor_parallelism=False,
+ enable_fused_normalization=False,
+ enable_all_optimization=False,
+ enable_flash_attention=False,
+ enable_jit_fused=False,
+ inference_only=True,
+ )
+ else:
+ shard_config.inference_only = True
+ shard_config.pipeline_stage_manager = None
+ if shard_config.enable_tensor_parallelism:
+ self.tp_size = shard_config.tensor_parallel_size
+ self._init_manager()
+
+ return shard_config
+
+ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
+ """Shard original model by the given ShardFormer and store the sharded model."""
+ assert (
+ self.tp_size == shardformer.shard_config.tensor_parallel_size
+ ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
+ model_name = model.__class__.__name__
+ assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
+ policy = get_autopolicy(model, inference_only=True)
+ self.model, _ = shardformer.optimize(model, policy)
+
+ if self.shard_config.inference_gptq:
+ self._post_init_gptq_buffer(model)
+
+ self.model = self.model.cuda()
+
+ @property
+ def supported_models(self) -> List[str]:
+ return _supported_models
+
+ def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor], **generate_kwargs) -> torch.Tensor:
+ """Generate token sequence.
+
+ Args:
+ input_tokens: could be one of the following types
+ 1. BatchEncoding or dict (e.g. tokenizer batch_encode)
+ 2. list of input token ids (e.g. appended result of tokenizer encode)
+ 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt')
+ Returns:
+ torch.Tensor: The returned sequence is given inputs + generated_tokens.
+ """
+ if isinstance(input_tokens, torch.Tensor):
+ input_tokens = dict(input_ids=input_tokens, attention_mask=torch.ones_like(input_tokens, dtype=torch.bool))
+ for t in input_tokens:
+ if torch.is_tensor(input_tokens[t]):
+ input_tokens[t] = input_tokens[t].cuda()
+ if "max_new_tokens" not in generate_kwargs:
+ generate_kwargs.update(max_new_tokens=self.max_output_len)
+
+ return self._generate_by_set_infer_state(input_tokens, **generate_kwargs)
+
+ def prepare_batch_state(self, inputs) -> BatchInferState:
+ """
+        Create and prepare BatchInferState used for inference during model forward,
+ by processing each sequence of the given inputs.
+
+ Args:
+ inputs: should be one of the following types
+ 1. BatchEncoding or dict (e.g. tokenizer batch_encode)
+ 2. list of input token ids (e.g. appended result of tokenizer encode)
+ 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt')
+ NOTE For torch.Tensor inputs representing a batch of inputs, we are unable to retrieve
+                the actual length (e.g. number of tokens) of each input without an attention mask.
+                Hence, for a torch.Tensor with shape [bs, l] where bs > 1, we will assume
+                all the inputs in the batch have the maximum length l.
+ Returns:
+ BatchInferState: the states for the current batch during inference
+ """
+ if not isinstance(inputs, (BatchEncoding, dict, list, torch.Tensor)):
+ raise TypeError(f"inputs type {type(inputs)} is not supported in prepare_batch_state")
+
+ input_ids_list = None
+ attention_mask = None
+
+ if isinstance(inputs, (BatchEncoding, dict)):
+ input_ids_list = inputs["input_ids"]
+ attention_mask = inputs["attention_mask"]
+ else:
+ input_ids_list = inputs
+ if isinstance(input_ids_list[0], int): # for a single input
+ input_ids_list = [input_ids_list]
+ attention_mask = [attention_mask] if attention_mask is not None else attention_mask
+
+ batch_size = len(input_ids_list)
+
+ seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+ seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+ start_index = 0
+
+ max_len_in_batch = -1
+ if isinstance(inputs, (BatchEncoding, dict)):
+ for i, attn_mask in enumerate(attention_mask):
+ curr_seq_len = len(attn_mask)
+ # if isinstance(attn_mask, torch.Tensor):
+ # curr_seq_len = int(torch.sum(attn_mask))
+ # else:
+ # curr_seq_len = int(sum(attn_mask))
+ seq_lengths[i] = curr_seq_len
+ seq_start_indexes[i] = start_index
+ start_index += curr_seq_len
+ max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
+ else:
+ length = max(len(input_id) for input_id in input_ids_list)
+ for i, input_ids in enumerate(input_ids_list):
+ curr_seq_len = length
+ seq_lengths[i] = curr_seq_len
+ seq_start_indexes[i] = start_index
+ start_index += curr_seq_len
+ max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
+ block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda")
+ batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
+ batch_infer_state.seq_len = seq_lengths.to("cuda")
+ batch_infer_state.start_loc = seq_start_indexes.to("cuda")
+ batch_infer_state.block_loc = block_loc
+ batch_infer_state.decode_layer_id = 0
+ batch_infer_state.past_key_values_len = 0
+ batch_infer_state.is_context_stage = True
+ batch_infer_state.set_cache_manager(self.cache_manager)
+ return batch_infer_state
+
+ @torch.no_grad()
+ def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch.Tensor:
+ """
+        Generate output tokens by setting BatchInferState as an attribute of the model and calling model.generate
+
+        Args:
+            input_tokens: should be one of the following types
+ 1. BatchEncoding or dict (e.g. tokenizer batch_encode)
+ 2. list of input token ids (e.g. appended result of tokenizer encode)
+ 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt')
+ """
+
+ # for testing, always use sharded model
+ assert self.model is not None, "sharded model does not exist"
+
+ batch_infer_state = self.prepare_batch_state(input_tokens)
+ assert batch_infer_state.max_len_in_batch <= self.max_input_len, "max length in batch exceeds limit"
+
+        # set BatchInferState for the current batch as an attr of the model
+        # NOTE this is not the preferred way to pass BatchInferState during inference;
+        # we might want to rewrite the generate function (e.g. _generate_by_pass_infer_state)
+        # and pass BatchInferState via model forward instead
+ model = self.model
+ if isinstance(model, LlamaForCausalLM):
+ model = self.model.model
+ elif isinstance(model, BloomForCausalLM):
+ model = self.model.transformer
+ setattr(model, "infer_state", batch_infer_state)
+
+ outputs = self.model.generate(**input_tokens, **generate_kwargs, early_stopping=False)
+
+ # NOTE In future development, we're going to let the scheduler to handle the cache,
+ # instead of freeing space explicitly at the end of generation
+ self.cache_manager.free_all()
+
+ return outputs
+
+ # TODO might want to implement the func that generates output tokens by passing BatchInferState
+ # as an arg into model.forward.
+ # It requires rewriting model generate and replacing model forward.
+ @torch.no_grad()
+ def _generate_by_pass_infer_state(
+ self,
+ input_tokens,
+ max_out_length: int,
+ generation_config: Optional[GenerationConfig] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ **model_kwargs,
+ ) -> torch.Tensor:
+ raise NotImplementedError("generate by passing BatchInferState is not implemented.")
+
+    # Might be used in a rewritten generate method, after model.forward:
+    # BatchInferState is created and kept during generation;
+    # after each iteration of model forward, we should update it
+ def _update_batch_state(self, infer_state: Optional[BatchInferState]) -> None:
+ batch_size = infer_state.batch_size
+ device = infer_state.start_loc.device
+ infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device)
+ infer_state.seq_len += 1
+
+    # might want to create a sequence pool:
+    # add a single request/sequence/input text at a time and record its length;
+    # in other words, store the actual length of the input tokens representing a single input text
+    # E.g. "Introduce landmarks in Beijing"
+    # => add request
+    # => record token length and other necessary information to be used
+    # => the engine holds all this necessary information until `generate` (or another name) is called,
+    # => puts the information already recorded in a BatchInferState and passes it to model forward
+    # => clears the records in the engine
+    def add_request(self):
+ raise NotImplementedError()
diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..e74a3a491a7bb87b09d8a93d29d67eea0593d749
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/kvcache_manager.py
@@ -0,0 +1,104 @@
+# Adapted from lightllm/common/mem_manager.py
+# of the ModelTC/lightllm GitHub repository
+# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
+
+import torch
+from transformers.utils import logging
+
+
+class MemoryManager:
+ r"""
+ Manage token block indexes and allocate physical memory for key and value cache
+
+ Args:
+ size: maximum token number used as the size of key and value buffer
+ dtype: data type of cached key and value
+ head_num: number of heads the memory manager is responsible for
+ head_dim: embedded size per head
+ layer_num: the number of layers in the model
+ device: device used to store the key and value cache
+ """
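+    # A minimal usage sketch with hypothetical sizes (a 32-layer model with
+    # 32 heads of dim 128 and room for 2048 cached tokens):
+    #   manager = MemoryManager(2048, torch.float16, 32, 128, 32)
+    #   idx = manager.alloc(16)   # indexes of 16 free token slots, or None
+    #   manager.free(idx)         # hand the slots back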
+
+ def __init__(
+ self,
+ size: int,
+ dtype: torch.dtype,
+ head_num: int,
+ head_dim: int,
+ layer_num: int,
+ device: torch.device = torch.device("cuda"),
+ ):
+ self.logger = logging.get_logger(__name__)
+ self.available_size = size
+ self.past_key_values_length = 0
+ self._init_mem_states(size, device)
+ self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num)
+
+ def _init_mem_states(self, size, device):
+ """Initialize tensors used to manage memory states"""
+ self.mem_state = torch.ones((size,), dtype=torch.bool, device=device)
+ self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device)
+ self.indexes = torch.arange(0, size, dtype=torch.long, device=device)
+
+ def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num):
+ """Initialize key buffer and value buffer on specified device"""
+ self.key_buffer = [
+ torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
+ ]
+ self.value_buffer = [
+ torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
+ ]
+
+ @torch.no_grad()
+ def alloc(self, required_size):
+ """allocate space of required_size by providing indexes representing available physical spaces"""
+ if required_size > self.available_size:
+            self.logger.warning(f"Not enough cache: required_size {required_size}, left_size {self.available_size}")
+ return None
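+        # running prefix sum over the free-slot mask: the positions whose
+        # cumulative count of free slots is still <= required_size, and which
+        # are themselves free, form exactly the first `required_size` free slots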
+ torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
+ select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1)
+ select_index = self.indexes[select_index]
+ self.mem_state[select_index] = 0
+ self.available_size -= len(select_index)
+ return select_index
+
+ @torch.no_grad()
+ def alloc_contiguous(self, required_size):
+ """allocate contiguous space of required_size"""
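+        # returns (select_index, start, end) on success, or None when no
+        # contiguous run of `required_size` free slots exists; callers fall
+        # back to the non-contiguous alloc() above in that case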
+ if required_size > self.available_size:
+            self.logger.warning(f"Not enough cache: required_size {required_size}, left_size {self.available_size}")
+ return None
+ torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
+ sum_size = len(self.mem_cum_sum)
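+        # loc_sums[i] is the number of free slots in the window [i, i + required_size):
+        # a cumsum difference plus the state of the window's first slot; a window
+        # is entirely free exactly when loc_sums[i] == required_size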
+ loc_sums = (
+ self.mem_cum_sum[required_size - 1 :]
+ - self.mem_cum_sum[0 : sum_size - required_size + 1]
+ + self.mem_state[0 : sum_size - required_size + 1]
+ )
+ can_used_loc = self.indexes[0 : sum_size - required_size + 1][loc_sums == required_size]
+ if can_used_loc.shape[0] == 0:
+            self.logger.info(
+                f"Not enough contiguous cache: required_size {required_size}, left_size {self.available_size}"
+            )
+ return None
+ start_loc = can_used_loc[0]
+ select_index = self.indexes[start_loc : start_loc + required_size]
+ self.mem_state[select_index] = 0
+ self.available_size -= len(select_index)
+ start = start_loc.item()
+ end = start + required_size
+ return select_index, start, end
+
+ @torch.no_grad()
+ def free(self, free_index):
+ """free memory by updating memory states based on given indexes"""
+ self.available_size += free_index.shape[0]
+ self.mem_state[free_index] = 1
+
+ @torch.no_grad()
+ def free_all(self):
+ """free all memory by updating memory states"""
+ self.available_size = len(self.mem_state)
+ self.mem_state[:] = 1
+ self.past_key_values_length = 0
+ self.logger.info("freed all space of memory manager")
diff --git a/colossalai/inference/tensor_parallel/modeling/__init__.py b/colossalai/inference/tensor_parallel/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4662368b17b42098b572777ed4d3b216d15d3983
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/modeling/__init__.py
@@ -0,0 +1,5 @@
+from .bloom import BloomInferenceForwards
+from .chatglm2 import ChatGLM2InferenceForwards
+from .llama import LlamaInferenceForwards
+
+__all__ = ["BloomInferenceForwards", "LlamaInferenceForwards", "ChatGLM2InferenceForwards"]
diff --git a/colossalai/inference/tensor_parallel/modeling/_utils.py b/colossalai/inference/tensor_parallel/modeling/_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e476c313253835232311c81e258c7a169579e189
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/modeling/_utils.py
@@ -0,0 +1,67 @@
+"""
+Utils for model inference
+"""
+import os
+
+import torch
+
+from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
+
+
+def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
+ """
+ This function copies the key and value cache to the memory cache
+ Args:
+ layer_id : id of current layer
+ key_buffer : key cache
+ value_buffer : value cache
+ context_mem_index : index of memory cache in kv cache manager
+ mem_manager : cache manager
+ """
+ copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
+ copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
+
+
+def init_to_get_rotary(self, base=10000, use_elem=False):
+ """
+    This function initializes the rotary positional embedding; it is compatible with all models and is called in ShardFormer
+    Args:
+        self : model that holds the rotary positional embedding
+        base : base used to compute the rotary frequencies
+        use_elem : activated when using chatglm-based models
+ """
+ self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
+ if not hasattr(self.config, "rope_scaling"):
+ rope_scaling_factor = 1.0
+ else:
+ rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
+
+ if hasattr(self.config, "max_sequence_length"):
+ max_seq_len = self.config.max_sequence_length
+ elif hasattr(self.config, "max_position_embeddings"):
+ max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
+ else:
+ max_seq_len = 2048 * rope_scaling_factor
+ base = float(base)
+
+ # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+    ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)
+
+ if ntk_alpha is not None:
+ ntk_alpha = float(ntk_alpha)
+ assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1"
+ if ntk_alpha > 1:
+ print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
+ max_seq_len *= ntk_alpha
+        base = base * (ntk_alpha ** (self.config.head_dim_ / (self.config.head_dim_ - 2)))  # Base change formula
+
+ n_elem = self.config.head_dim_
+ if use_elem:
+ n_elem //= 2
+
+ inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
+ t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
+ freqs = torch.outer(t, inv_freq)
+
+ self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
+ self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
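+
+# A rough illustration with a hypothetical config: for hidden_size=4096 and
+# num_attention_heads=32, head_dim_ is 128, `inv_freq` has 64 entries, and the
+# cached cos/sin tables cover max_seq_len + 1024 * 64 positions in fp16 on GPU.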
diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py
new file mode 100644
index 0000000000000000000000000000000000000000..27a26caabefa96aef4e6ca8a82429fa881f0dd6e
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/modeling/bloom.py
@@ -0,0 +1,540 @@
+import math
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+from torch.nn import CrossEntropyLoss
+from torch.nn import functional as F
+from transformers.models.bloom.modeling_bloom import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BloomAttention,
+ BloomBlock,
+ BloomForCausalLM,
+ BloomModel,
+ CausalLMOutputWithCrossAttentions,
+)
+from transformers.utils import logging
+
+from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
+from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd
+
+
+def generate_alibi(n_head, dtype=torch.float16):
+ """
+ This method is adapted from `_generate_alibi` function
+ in `lightllm/models/bloom/layer_weights/transformer_layer_weight.py`
+ of the ModelTC/lightllm GitHub repository.
+ This method is originally the `build_alibi_tensor` function
+ in `transformers/models/bloom/modeling_bloom.py`
+ of the huggingface/transformers GitHub repository.
+ """
+
+ def get_slopes_power_of_2(n):
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+ return [start * start**i for i in range(n)]
+
+ def get_slopes(n):
+ if math.log2(n).is_integer():
+ return get_slopes_power_of_2(n)
+ else:
+ closest_power_of_2 = 2 ** math.floor(math.log2(n))
+ slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2)
+ slopes_double = get_slopes(2 * closest_power_of_2)
+ slopes_combined = slopes_power_of_2 + slopes_double[0::2][: n - closest_power_of_2]
+ return slopes_combined
+
+ slopes = get_slopes(n_head)
+ return torch.tensor(slopes, dtype=dtype)
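+
+# For example, generate_alibi(4) yields slopes [2**-2, 2**-4, 2**-6, 2**-8],
+# i.e. tensor([0.25, 0.0625, 0.015625, 0.00390625]) in float16.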
+
+
+class BloomInferenceForwards:
+ """
+    This class serves as a micro-library for bloom inference forwards.
+ We intend to replace the forward methods for BloomForCausalLM, BloomModel, BloomBlock, and BloomAttention,
+ as well as prepare_inputs_for_generation method for BloomForCausalLM.
+ For future improvement, we might want to skip replacing methods for BloomForCausalLM,
+ and call BloomModel.forward iteratively in TpInferEngine
+ """
+
+ @staticmethod
+ def bloom_model_forward(
+ self: BloomModel,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ infer_state: Optional[BatchInferState] = None,
+ **deprecated_arguments,
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+ logger = logging.get_logger(__name__)
+
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ # still need to keep past_key_values to fit original forward flow
+ if past_key_values is None:
+ past_key_values = tuple([None] * len(self.h))
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape batch_size x num_heads x N x N
+ # head_mask has shape n_layer x batch x num_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+        # NOTE determine if BatchInferState is passed in via arg;
+        # if not, get the attr bound to the model.
+        # We might want to remove setattr later
+ if infer_state is None:
+ assert hasattr(self, "infer_state")
+ infer_state = self.infer_state
+
+ # Compute alibi tensor: check build_alibi_tensor documentation
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+ # if self.cache_manager.past_key_values_length > 0:
+ if infer_state.cache_manager.past_key_values_length > 0:
+            # update the past key values length in the cache manager,
+            # NOTE use BatchInferState.past_key_values_length instead of the one in the cache manager
+ past_key_values_length = infer_state.cache_manager.past_key_values_length
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ # infer_state.cache_manager = self.cache_manager
+
+ if use_cache and seq_length != 1:
+ # prefill stage
+ infer_state.is_context_stage = True # set prefill stage, notify attention layer
+ infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
+ BatchInferState.init_block_loc(
+ infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
+ )
+ else:
+ infer_state.is_context_stage = False
+ alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
+ if alloc_mem is not None:
+ infer_state.decode_is_contiguous = True
+ infer_state.decode_mem_index = alloc_mem[0]
+ infer_state.decode_mem_start = alloc_mem[1]
+ infer_state.decode_mem_end = alloc_mem[2]
+ infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+ else:
+                print(" *** Encountered non-contiguous allocation")
+ print(
+ f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
+ )
+ infer_state.decode_is_contiguous = False
+ alloc_mem = infer_state.cache_manager.alloc(batch_size)
+ infer_state.decode_mem_index = alloc_mem
+ # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+ # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+ infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+ else:
+ attention_mask = attention_mask.to(hidden_states.device)
+
+        # NOTE revise: we might want to store a single 1D alibi (length = num_heads) in the model,
+        # or store it in BatchInferState to avoid re-calculating it
+ # When we have multiple process group (e.g. dp together with tp), we need to pass the pg to here
+ # alibi = generate_alibi(self.num_heads).contiguous().cuda()
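+        # each TP rank slices out the alibi slopes of the heads it owns from the
+        # full (unsharded) set, keeping slopes consistent with the original head order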
+ tp_size = dist.get_world_size()
+ curr_tp_rank = dist.get_rank()
+ alibi = (
+ generate_alibi(self.num_heads * tp_size)
+ .contiguous()[curr_tp_rank * self.num_heads : (curr_tp_rank + 1) * self.num_heads]
+ .cuda()
+ )
+ causal_mask = self._prepare_attn_mask(
+ attention_mask,
+ input_shape=(batch_size, seq_length),
+ past_key_values_length=past_key_values_length,
+ )
+
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ # NOTE: currently our KV cache manager does not handle this condition
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
+
+ return custom_forward
+
+ outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ alibi,
+ causal_mask,
+ layer_past,
+ head_mask[i],
+ )
+ else:
+ outputs = block(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=causal_mask,
+ head_mask=head_mask[i],
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ alibi=alibi,
+ infer_state=infer_state,
+ )
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+
+ # Add last hidden state
+ hidden_states = self.ln_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # update indices of kv cache block
+ # NOT READY FOR PRIME TIME
+        # might want to remove this part; better to pass the BatchInferState from model forward
+        # and update this information in engine.generate after model forward is called
+ infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.seq_len += 1
+ infer_state.decode_layer_id = 0
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents, # should always be (None, None, ..., None)
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ @staticmethod
+ def bloom_for_causal_lm_forward(
+ self: BloomForCausalLM,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ infer_state: Optional[BatchInferState] = None,
+ **deprecated_arguments,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+ """
+ logging.get_logger(__name__)
+
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = BloomInferenceForwards.bloom_model_forward(
+ self.transformer,
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ infer_state=infer_state,
+ )
+ hidden_states = transformer_outputs[0]
+
+ lm_logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(lm_logits.device)
+ # Shift so that tokens < n predict n
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ batch_size, seq_length, vocab_size = shift_logits.shape
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(
+ shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
+ )
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ @staticmethod
+ def bloom_for_causal_lm_prepare_inputs_for_generation(
+ self: BloomForCausalLM,
+ input_ids: torch.LongTensor,
+ past_key_values: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> dict:
+ # only last token for input_ids if past is not None
+ if past_key_values:
+ input_ids = input_ids[:, -1].unsqueeze(-1)
+
+ # NOTE we won't use past key values here
+        # the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
+ # if past_key_values[0][0].shape[0] == input_ids.shape[0]:
+ # past_key_values = self._convert_to_bloom_cache(past_key_values)
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ @staticmethod
+ def bloom_block_forward(
+ self: BloomBlock,
+ hidden_states: torch.Tensor,
+ alibi: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ infer_state: Optional[BatchInferState] = None,
+ ):
+ # hidden_states: [batch_size, seq_length, hidden_size]
+
+ # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+
+ # Layer norm post the self attention.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+
+ # Self attention.
+ attn_outputs = self.self_attention(
+ layernorm_output,
+ residual,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ alibi=alibi,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ infer_state=infer_state,
+ )
+
+ attention_output = attn_outputs[0]
+
+ outputs = attn_outputs[1:]
+
+ layernorm_output = self.post_attention_layernorm(attention_output)
+
+ # Get residual
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = attention_output
+
+ # MLP.
+ output = self.mlp(layernorm_output, residual)
+
+ if use_cache:
+ outputs = (output,) + outputs
+ else:
+ outputs = (output,) + outputs[1:]
+
+ return outputs # hidden_states, present, attentions
+
+ @staticmethod
+ def bloom_attention_forward(
+ self: BloomAttention,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor,
+ alibi: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ infer_state: Optional[BatchInferState] = None,
+ ):
+ fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
+
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
+ (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
+ batch_size, q_length, H, D_HEAD = query_layer.shape
+        k = key_layer.reshape(-1, H, D_HEAD)  # batch_size * q_length, H, D_HEAD, q_length == 1
+        v = value_layer.reshape(-1, H, D_HEAD)  # batch_size * q_length, H, D_HEAD, q_length == 1
+
+ mem_manager = infer_state.cache_manager
+ layer_id = infer_state.decode_layer_id
+
+ if layer_id == 0: # once per model.forward
+ infer_state.cache_manager.past_key_values_length += q_length # += 1
+
+ if infer_state.is_context_stage:
+ # context process
+ max_input_len = q_length
+ b_start_loc = infer_state.start_loc
+ b_seq_len = infer_state.seq_len[:batch_size]
+ q = query_layer.reshape(-1, H, D_HEAD)
+
+ copy_kv_cache_to_dest(k, infer_state.context_mem_index, mem_manager.key_buffer[layer_id])
+ copy_kv_cache_to_dest(v, infer_state.context_mem_index, mem_manager.value_buffer[layer_id])
+
+ # output = self.output[:batch_size*q_length, :, :]
+ output = torch.empty_like(q)
+
+ bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi)
+
+ context_layer = output.view(batch_size, q_length, H * D_HEAD)
+ else:
+ # query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
+        # needed shape: (batch_size, H, D_HEAD) with q_length == 1; input q shape: (batch_size, q_length(1), H, D_HEAD)
+ assert q_length == 1, "for non-context process, we only support q_length == 1"
+ q = query_layer.reshape(-1, H, D_HEAD)
+
+ if infer_state.decode_is_contiguous:
+ # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
+ cache_k = infer_state.cache_manager.key_buffer[layer_id][
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
+ cache_v = infer_state.cache_manager.value_buffer[layer_id][
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
+ cache_k.copy_(k)
+ cache_v.copy_(v)
+ else:
+ # if decode is not contiguous, use triton kernel to copy key and value cache
+ # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
+ copy_kv_cache_to_dest(k, infer_state.decode_mem_index, mem_manager.key_buffer[layer_id])
+ copy_kv_cache_to_dest(v, infer_state.decode_mem_index, mem_manager.value_buffer[layer_id])
+
+ b_start_loc = infer_state.start_loc
+ b_loc = infer_state.block_loc
+ b_seq_len = infer_state.seq_len
+ output = torch.empty_like(q)
+ token_attention_fwd(
+ q,
+ mem_manager.key_buffer[layer_id],
+ mem_manager.value_buffer[layer_id],
+ output,
+ b_loc,
+ b_start_loc,
+ b_seq_len,
+ infer_state.cache_manager.past_key_values_length,
+ alibi,
+ )
+
+ context_layer = output.view(batch_size, q_length, H * D_HEAD)
+
+ # update layer id
+ infer_state.decode_layer_id += 1
+
+ # NOTE: always set present as none for now, instead of returning past key value to the next decoding,
+ # we create the past key value pair from the cache manager
+ present = None
+
+ # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
+ if self.pretraining_tp > 1 and self.slow_but_exact:
+ slices = self.hidden_size / self.pretraining_tp
+ output_tensor = torch.zeros_like(context_layer)
+ for i in range(self.pretraining_tp):
+ output_tensor = output_tensor + F.linear(
+ context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
+ self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
+ )
+ else:
+ output_tensor = self.dense(context_layer)
+
+ # dropout is not required here during inference
+ output_tensor = residual + output_tensor
+
+ outputs = (output_tensor, present)
+ assert output_attentions is False, "we do not support output_attentions at this time"
+
+ return outputs
diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b1bc601f436ac2c62079c924f6ac1851482ae10
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py
@@ -0,0 +1,540 @@
+import os
+from typing import Optional, Tuple
+
+import torch
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+
+from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
+from colossalai.kernel.triton.context_attention import llama2_context_attn_fwd
+from colossalai.kernel.triton.rotary_embedding_kernel import Llama2Forwards
+from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
+from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
+ ChatGLMForConditionalGeneration,
+ ChatGLMModel,
+ GLMBlock,
+ GLMTransformer,
+ SelfAttention,
+ split_tensor_along_last_dim,
+)
+
+from ._utils import copy_kv_to_mem_cache
+
+
+# This function is the same as the llama model's init_to_get_rotary; we should move both into _utils.py
+def _init_to_get_rotary(self, base=10000):
+ self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
+ if not hasattr(self.config, "rope_scaling"):
+ rope_scaling_factor = 1.0
+ else:
+ rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
+ if hasattr(self.config, "max_sequence_length"):
+ max_seq_len = self.config.max_sequence_length
+ elif hasattr(self.config, "max_position_embeddings"):
+ max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
+ else:
+ max_seq_len = 2048 * rope_scaling_factor
+ base = float(base)
+
+ # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+    try:
+        ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1))
+        assert ntk_alpha >= 1
+        if ntk_alpha > 1:
+            print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
+        max_seq_len *= ntk_alpha
+        base = base * (ntk_alpha ** (self.config.head_dim_ / (self.config.head_dim_ - 2)))  # Base change formula
+    except (ValueError, AssertionError):
+        pass
+ n_elem = self.config.head_dim_ // 2
+ inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
+ t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
+ freqs = torch.outer(t, inv_freq)
+
+ self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
+ self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
+ return
+
+
+def get_masks(self, input_ids, past_length, padding_mask=None):
+ batch_size, seq_length = input_ids.shape
+ full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
+ full_attention_mask.tril_()
+ if past_length:
+ full_attention_mask = torch.cat(
+ (
+ torch.ones(batch_size, seq_length, past_length, device=input_ids.device),
+ full_attention_mask,
+ ),
+ dim=-1,
+ )
+
+ if padding_mask is not None:
+ full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
+ if not past_length and padding_mask is not None:
+ full_attention_mask -= padding_mask.unsqueeze(-1) - 1
+ full_attention_mask = (full_attention_mask < 0.5).bool()
+ full_attention_mask.unsqueeze_(1)
+ return full_attention_mask
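+
+# NOTE in the mask returned above, True marks positions that must NOT be attended
+# to: the final `< 0.5` comparison inverts the multiplicative mask built before it.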
+
+
+class ChatGLM2InferenceForwards:
+ """
+    This class holds forwards for ChatGLM2 inference.
+    We intend to replace the forward methods of ChatGLMModel, GLMTransformer, GLMBlock, and SelfAttention.
+ """
+
+ @staticmethod
+ def chatglm_for_conditional_generation_forward(
+ self: ChatGLMForConditionalGeneration,
+ input_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ return_last_logit: Optional[bool] = False,
+ ):
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ infer_state = self.infer_state
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ past_key_values_length = 0
+
+ # NOT READY FOR PRIME TIME
+        # dummy but works; revise it
+ past_key_values_length = infer_state.cache_manager.past_key_values_length
+ seq_length_with_past = seq_length + past_key_values_length
+ infer_state.seq_length_with_past = seq_length_with_past
+
+ # prefill stage at first
+ if use_cache and seq_length != 1:
+ infer_state.is_context_stage = True
+ infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
+ infer_state.init_block_loc(
+ infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
+ )
+ else:
+ infer_state.is_context_stage = False
+ alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
+ if alloc_mem is not None:
+ infer_state.decode_is_contiguous = True
+ infer_state.decode_mem_index = alloc_mem[0]
+ infer_state.decode_mem_start = alloc_mem[1]
+ infer_state.decode_mem_end = alloc_mem[2]
+ infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+ else:
+                print(" *** Encountered non-contiguous allocation")
+ print(
+ f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
+ )
+ infer_state.decode_is_contiguous = False
+ alloc_mem = infer_state.cache_manager.alloc(batch_size)
+ infer_state.decode_mem_index = alloc_mem
+ # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+ # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+ infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+
+ # related to rotary embedding
+ if infer_state.is_context_stage:
+ infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
+ position_ids.view(-1).shape[0], -1
+ )
+ infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
+ position_ids.view(-1).shape[0], -1
+ )
+ else:
+ seq_len = infer_state.seq_len
+ infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+ infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+ infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item()
+
+ transformer_outputs = self.transformer(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ infer_state=infer_state,
+ )
+
+ hidden_states = transformer_outputs[0]
+ if return_last_logit:
+ hidden_states = hidden_states[-1:]
+ lm_logits = self.transformer.output_layer(hidden_states)
+ lm_logits = lm_logits.transpose(0, 1).contiguous()
+
+ loss = None
+ if labels is not None:
+ lm_logits = lm_logits.to(torch.float32)
+
+ # Shift so that tokens < n predict n
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+ lm_logits = lm_logits.to(hidden_states.dtype)
+ loss = loss.to(hidden_states.dtype)
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ @staticmethod
+ def chatglm_model_forward(
+ self: ChatGLMModel,
+ input_ids,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
+ full_attention_mask: Optional[torch.BoolTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ infer_state: BatchInferState = None,
+ ):
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ batch_size, seq_length = input_ids.shape
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embedding(input_ids)
+
+ if self.pre_seq_len is not None:
+ if past_key_values is None:
+ past_key_values = self.get_prompt(
+ batch_size=batch_size,
+ device=input_ids.device,
+ dtype=inputs_embeds.dtype,
+ )
+ if attention_mask is not None:
+ attention_mask = torch.cat(
+ [
+ attention_mask.new_ones((batch_size, self.pre_seq_len)),
+ attention_mask,
+ ],
+ dim=-1,
+ )
+ if full_attention_mask is None:
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
+ full_attention_mask = get_masks(
+ self, input_ids, infer_state.cache_manager.past_key_values_length, padding_mask=attention_mask
+ )
+
+ # Run encoder.
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
+ inputs_embeds,
+ full_attention_mask,
+ kv_caches=past_key_values,
+ use_cache=use_cache,
+ output_hidden_states=output_hidden_states,
+ infer_state=infer_state,
+ )
+
+ # update indices
+ # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.seq_len += 1
+ infer_state.max_len_in_batch += 1
+ infer_state.cache_manager.past_key_values_length += seq_length
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ presents,
+ all_hidden_states,
+ all_self_attentions,
+ ]
+ if v is not None
+ )
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ @staticmethod
+ def chatglm_encoder_forward(
+ self: GLMTransformer,
+ hidden_states,
+ attention_mask,
+ kv_caches=None,
+ use_cache: Optional[bool] = True,
+ output_hidden_states: Optional[bool] = False,
+ infer_state: Optional[BatchInferState] = None,
+ ):
+ hidden_states = hidden_states.transpose(0, 1).contiguous()
+ if not kv_caches:
+ kv_caches = [None for _ in range(self.num_layers)]
+ presents = () if use_cache else None
+ all_self_attentions = None
+ all_hidden_states = () if output_hidden_states else None
+
+ infer_state.decode_layer_id = 0
+ for index in range(self.num_layers):
+ layer = self.layers[index]
+
+ layer_ret = layer(
+ hidden_states,
+ attention_mask,
+ kv_cache=kv_caches[index],
+ use_cache=use_cache,
+ infer_state=infer_state,
+ )
+
+ infer_state.decode_layer_id += 1
+
+ hidden_states, kv_cache = layer_ret
+ if use_cache:
+ presents = presents + (kv_cache,)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # Final layer norm.
+ hidden_states = hidden_states.transpose(0, 1).contiguous()
+
+ if self.post_layer_norm:
+ hidden_states = self.final_layernorm(hidden_states)
+
+ return hidden_states, presents, all_hidden_states, all_self_attentions
+
+ @staticmethod
+ def chatglm_glmblock_forward(
+ self: GLMBlock,
+ hidden_states,
+ attention_mask,
+ kv_cache=None,
+ use_cache=True,
+ infer_state: Optional[BatchInferState] = None,
+ ):
+ # hidden_states: [s, b, h]
+
+ # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+ # Self attention.
+ attention_output, kv_cache = self.self_attention(
+ layernorm_output,
+ attention_mask,
+ kv_cache=kv_cache,
+ use_cache=use_cache,
+ infer_state=infer_state,
+ )
+ # Residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+ layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
+ layernorm_input = residual + layernorm_input
+ # Layer norm post the self attention.
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
+ # MLP.
+ mlp_output = self.mlp(layernorm_output)
+
+ # Second residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = layernorm_input
+
+ output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
+ output = residual + output
+ return output, kv_cache
+
+ @staticmethod
+ def chatglm_flash_attn_kvcache_forward(
+ self: SelfAttention,
+ hidden_states,
+ attention_mask,
+ kv_cache=None,
+ use_cache=True,
+ infer_state: Optional[BatchInferState] = None,
+ ):
+        assert use_cache is True, "use_cache should be set to True when using this chatglm attention"
+        # hidden_states: originally [sq, b, h], here [b, sq, h]
+ batch_size = hidden_states.shape[0]
+ # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+ mixed_x_layer = self.query_key_value(hidden_states)
+
+ if self.multi_query_attention:
+ (query_layer, key_layer, value_layer) = mixed_x_layer.split(
+ [
+ self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+ ],
+ dim=-1,
+ )
+ query_layer = query_layer.view(
+ query_layer.size()[:-1]
+ + (
+ self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head,
+ )
+ )
+ key_layer = key_layer.view(
+ key_layer.size()[:-1]
+ + (
+ self.num_multi_query_groups_per_partition,
+ self.hidden_size_per_attention_head,
+ )
+ )
+ value_layer = value_layer.view(
+ value_layer.size()[:-1]
+ + (
+ self.num_multi_query_groups_per_partition,
+ self.hidden_size_per_attention_head,
+ )
+ )
+
+ else:
+ new_tensor_shape = mixed_x_layer.size()[:-1] + (
+ self.num_attention_heads_per_partition,
+ 3 * self.hidden_size_per_attention_head,
+ )
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+ # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
+ (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
+
+ cos, sin = infer_state.position_cos, infer_state.position_sin
+
+ Llama2Forwards.rotary_emb_fwd(
+ query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin
+ )
+ if self.multi_query_attention:
+ Llama2Forwards.rotary_emb_fwd(
+ key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head),
+ cos,
+ sin,
+ )
+ else:
+ Llama2Forwards.rotary_emb_fwd(
+ key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
+ cos,
+ sin,
+ )
+
+        # reshape q, k, v to [bsz * seq_len, num_heads, head_dim] (e.g. 2 * 1, 32 / 2, 128)
+ query_layer = query_layer.reshape(
+ -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
+ )
+ key_layer = key_layer.reshape(
+ -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
+ )
+ value_layer = value_layer.reshape(
+ -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
+ )
+ if infer_state.is_context_stage:
+ # first token generation:
+ # copy key and value calculated in current step to memory manager
+
+ copy_kv_to_mem_cache(
+ infer_state.decode_layer_id,
+ key_layer,
+ value_layer,
+ infer_state.context_mem_index,
+ infer_state.cache_manager,
+ )
+
+ attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))
+
+ llama2_context_attn_fwd(
+ query_layer,
+ key_layer,
+ value_layer,
+ attn_output.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.seq_length_with_past,
+ )
+
+ else:
+ if infer_state.decode_is_contiguous:
+ # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
+ cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
+ cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
+ cache_k.copy_(key_layer)
+ cache_v.copy_(value_layer)
+ else:
+ # if decode is not contiguous, use triton kernel to copy key and value cache
+                # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
+ copy_kv_to_mem_cache(
+ infer_state.decode_layer_id,
+ key_layer,
+ value_layer,
+ infer_state.decode_mem_index,
+ infer_state.cache_manager,
+ )
+
+ # second token and follows
+ attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))
+ cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
+ : infer_state.decode_mem_end, :, :
+ ]
+ cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
+ : infer_state.decode_mem_end, :, :
+ ]
+
+ # ==================================
+ # core attention computation is replaced by triton kernel
+ # ==================================
+ Llama2TokenAttentionForwards.token_attn(
+ query_layer,
+ cache_k,
+ cache_v,
+ attn_output,
+ infer_state.block_loc,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.max_len_in_batch,
+ infer_state.other_kv_index,
+ )
+
+ # =================
+ # Output:[b,sq, h]
+ # =================
+
+ output = self.dense(attn_output).reshape(batch_size, -1, self.projection_size)
+ return output, kv_cache
diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7661cee1128d010765cf51a92796e606b2a21d1
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/modeling/llama.py
@@ -0,0 +1,369 @@
+from typing import List, Optional, Tuple
+
+import torch
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
+
+from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
+from colossalai.kernel.triton import llama_context_attn_fwd, rotary_embedding_fwd, token_attention_fwd
+
+from ._utils import copy_kv_to_mem_cache
+
+try:
+ from vllm import layernorm_ops, pos_encoding_ops
+
+ rms_norm = layernorm_ops.rms_norm
+ rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
+ HAS_VLLM_KERNERL = True
+except ImportError:
+    print("falling back to the original huggingface rotary_embedding_neox")
+    print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
+    print(
+        "if installing vllm fails, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch"
+    )
+ HAS_VLLM_KERNERL = False
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
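+# e.g. rotate_half of [a, b, c, d] along the last dim is [-c, -d, a, b]; combined
+# with cos/sin below, this realizes the 2D rotations of rotary position embedding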
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class LlamaInferenceForwards:
+ """
+ This class holds forwards for llama inference.
+    We intend to replace the forward methods of LlamaModel, LlamaDecoderLayer, and LlamaAttention used by LlamaForCausalLM.
+ """
+
+ @staticmethod
+ def llama_model_forward(
+ self: LlamaModel,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+        batch_size = input_ids.shape[0]
+
+ infer_state = self.infer_state
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ # NOT READY FOR PRIME TIME
+            # dummy but works; revise it
+ past_key_values_length = infer_state.cache_manager.past_key_values_length
+ # past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ # NOTE: differentiate with prefill stage
+ # block_loc require different value-assigning method for two different stage
+ if use_cache and seq_length != 1:
+ # NOTE assume prefill stage
+ # allocate memory block
+ infer_state.is_context_stage = True # set prefill stage, notify attention layer
+ infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
+ infer_state.init_block_loc(
+ infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
+ )
+ else:
+ infer_state.is_context_stage = False
+ alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
+ if alloc_mem is not None:
+ infer_state.decode_is_contiguous = True
+ infer_state.decode_mem_index = alloc_mem[0]
+ infer_state.decode_mem_start = alloc_mem[1]
+ infer_state.decode_mem_end = alloc_mem[2]
+ infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+ else:
+                print(" *** Encountered non-contiguous allocation")
+ print(
+ f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
+ )
+ infer_state.decode_is_contiguous = False
+ alloc_mem = infer_state.cache_manager.alloc(batch_size)
+ infer_state.decode_mem_index = alloc_mem
+ # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+ # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+ infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if infer_state.is_context_stage:
+ infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
+ position_ids.view(-1).shape[0], -1
+ )
+ infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
+ position_ids.view(-1).shape[0], -1
+ )
+ else:
+ seq_len = infer_state.seq_len
+ infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+ infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
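+            # The next token of sequence i sits at absolute position seq_len[i] - 1, so
+            # index_select with seq_len - 1 picks one cos/sin row per sequence rather
+            # than one row per token as in the prefill branch above.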
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+        # build the attention mask
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+ )
+
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
+
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ infer_state.decode_layer_id = 0
+
+ for idx, decoder_layer in enumerate(self.layers):
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+ # NOTE: modify here for passing args to decoder layer
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ infer_state=infer_state,
+ )
+ infer_state.decode_layer_id += 1
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ hidden_states = self.norm(hidden_states)
+ next_cache = next_decoder_cache if use_cache else None
+
+ # update indices
+ # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.seq_len += 1
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ @staticmethod
+ def llama_decoder_layer_forward(
+ self: LlamaDecoderLayer,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ infer_state: Optional[BatchInferState] = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ infer_state=infer_state,
+ )
+
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+ @staticmethod
+ def llama_flash_attn_kvcache_forward(
+ self: LlamaAttention,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ infer_state: Optional[BatchInferState] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        assert use_cache is True, "use_cache should be set to True when using this llama attention"
+
+ bsz, q_len, _ = hidden_states.size()
+
+        # NOTE: consider a better way to handle the transposed k and v
+ # key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head]
+ # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+
+        # NOTE: may need revision. We need some way to record the length of the past
+        # key-values cache, since we do not return past_key_value right now.
+ if infer_state.decode_layer_id == 0: # once per model.forward
+ infer_state.cache_manager.past_key_values_length += q_len # seq_len
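+            # decode_layer_id is 0 only while the first decoder layer runs, so the length
+            # counter is bumped once per model.forward: by the prompt length in prefill
+            # and by 1 in decode.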
+
+ cos, sin = infer_state.position_cos, infer_state.position_sin
+ # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )
+
+ rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
+ rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
+
+ query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
+ key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
+ value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
+
+ if infer_state.is_context_stage:
+ # first token generation
+
+ # copy key and value calculated in current step to memory manager
+ copy_kv_to_mem_cache(
+ infer_state.decode_layer_id,
+ key_states,
+ value_states,
+ infer_state.context_mem_index,
+ infer_state.cache_manager,
+ )
+
+ attn_output = torch.empty_like(query_states)
+
+ llama_context_attn_fwd(
+ query_states,
+ key_states,
+ value_states,
+ attn_output,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.cache_manager.past_key_values_length,
+ )
+ else:
+ if infer_state.decode_is_contiguous:
+ # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
+ cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
+ cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
+ cache_k.copy_(key_states)
+ cache_v.copy_(value_states)
+ else:
+ # if decode is not contiguous, use triton kernel to copy key and value cache
+                # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
+ copy_kv_to_mem_cache(
+ infer_state.decode_layer_id,
+ key_states,
+ value_states,
+ infer_state.decode_mem_index,
+ infer_state.cache_manager,
+ )
+
+            # second and subsequent tokens
+ # kv = torch.stack((key_states, value_states), dim=2)
+ # (batch_size, seqlen, nheads, headdim)
+ attn_output = torch.empty_like(query_states)
+
+ token_attention_fwd(
+ query_states,
+ infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+ infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+ attn_output,
+ infer_state.block_loc,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.cache_manager.past_key_values_length,
+ )
+
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ # return past_key_value as None
+ return attn_output, None, None
+
+
+def get_llama_vllm_rmsnorm_forward():
+ if HAS_VLLM_KERNERL:
+
+ def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
+ x = hidden_states
+ out = torch.empty_like(x)
+ rms_norm(
+ out,
+ x,
+ self.weight.data,
+ self.variance_epsilon,
+ )
+
+ return out
+
+ return _vllm_rmsnorm_forward
+ else:
+ return None
diff --git a/colossalai/inference/tensor_parallel/policies/__init__.py b/colossalai/inference/tensor_parallel/policies/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..776c4e85056542a1a35dbbf23650bbc01e5b8ffd
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/policies/__init__.py
@@ -0,0 +1,5 @@
+from .bloom import BloomModelInferPolicy
+from .chatglm2 import ChatGLM2InferPolicy
+from .llama import LlamaModelInferPolicy
+
+__all__ = ["BloomModelInferPolicy", "LlamaModelInferPolicy", "ChatGLM2InferPolicy"]
diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d6df20970005f184dfc96d420b8dfe829e2e176
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/policies/bloom.py
@@ -0,0 +1,99 @@
+from functools import partial
+
+import torch
+from torch.nn import LayerNorm
+
+import colossalai.shardformer.layer as col_nn
+from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
+from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
+
+from ..modeling.bloom import BloomInferenceForwards
+
+try:
+ from colossalai.kernel.triton import layer_norm
+
+ HAS_TRITON_NORM = True
+except ImportError:
+ print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton")
+ HAS_TRITON_NORM = False
+
+
+def get_triton_layernorm_forward():
+ if HAS_TRITON_NORM:
+
+ def _triton_layernorm_forward(self: LayerNorm, hidden_states: torch.Tensor):
+ return layer_norm(hidden_states, self.weight.data, self.bias, self.eps)
+
+ return _triton_layernorm_forward
+ else:
+ return None
+
+
+class BloomModelInferPolicy(BloomForCausalLMPolicy):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel
+
+ policy = super().module_policy()
+ if self.shard_config.inference_gptq:
+ from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
+            policy[BloomBlock] = ModulePolicyDescription(
+                attribute_replacement={
+                    "self_attention.hidden_size": self.model.config.hidden_size
+                    // self.shard_config.tensor_parallel_size,
+                    "self_attention.split_size": self.model.config.hidden_size
+                    // self.shard_config.tensor_parallel_size,
+                    "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
+                },
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="self_attention.query_key_value",
+                        target_module=ColCaiQuantLinear,
+                        kwargs={"split_num": 3},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attention.dense",
+                        target_module=RowCaiQuantLinear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attention.attention_dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.dense_h_to_4h",
+                        target_module=ColCaiQuantLinear,
+                        kwargs={"split_num": 1},
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.dense_4h_to_h",
+                        target_module=RowCaiQuantLinear,
+                        kwargs={"split_num": 1},
+                    ),
+                ],
+            )
+            # NOTE: set inference mode on the shard config
+ self.shard_config._infer()
+
+ method_replacement = {
+ "forward": BloomInferenceForwards.bloom_for_causal_lm_forward,
+ "prepare_inputs_for_generation": BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation,
+ }
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=BloomForCausalLM
+ )
+
+ method_replacement = {"forward": BloomInferenceForwards.bloom_model_forward}
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomModel)
+
+ method_replacement = {"forward": BloomInferenceForwards.bloom_block_forward}
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomBlock)
+
+ method_replacement = {"forward": BloomInferenceForwards.bloom_attention_forward}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=BloomAttention
+ )
+
+ if HAS_TRITON_NORM:
+ infer_method = get_triton_layernorm_forward()
+ method_replacement = {"forward": partial(infer_method)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=LayerNorm
+ )
+
+ return policy
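+
+
+# Usage sketch (assumed API, following colossalai.shardformer conventions):
+#
+#   from colossalai.shardformer import ShardConfig, ShardFormer
+#
+#   shard_config = ShardConfig()  # plus any tensor-parallel settings
+#   shardformer = ShardFormer(shard_config=shard_config)
+#   model, _ = shardformer.optimize(bloom_model, BloomModelInferPolicy())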
diff --git a/colossalai/inference/tensor_parallel/policies/chatglm2.py b/colossalai/inference/tensor_parallel/policies/chatglm2.py
new file mode 100644
index 0000000000000000000000000000000000000000..90f8b4fd2d7eb8fdb6a3fd52c275bffea5c3ab72
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/policies/chatglm2.py
@@ -0,0 +1,74 @@
+from functools import partial
+
+from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
+ ChatGLMForConditionalGeneration,
+ ChatGLMModel,
+ GLMBlock,
+ GLMTransformer,
+ SelfAttention,
+)
+
+from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy
+
+from ..modeling._utils import init_to_get_rotary
+from ..modeling.chatglm2 import ChatGLM2InferenceForwards
+
+try:
+    # mirror the llama policy's triton availability check
+    from colossalai.kernel.triton import rmsnorm_forward
+
+    HAS_TRITON_RMSNORM = True
+except ImportError:
+    print("You should install triton from https://github.com/openai/triton")
+    HAS_TRITON_RMSNORM = False
+
+
+class ChatGLM2InferPolicy(ChatGLMModelPolicy):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = super().module_policy()
+ self.shard_config._infer()
+
+ model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward
+ method_replacement = {"forward": model_infer_forward}
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel)
+
+ encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward
+ method_replacement = {"forward": encoder_infer_forward}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=GLMTransformer
+ )
+
+ encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward
+ method_replacement = {"forward": encoder_layer_infer_forward}
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock)
+
+ attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward
+ method_replacement = {"forward": attn_infer_forward}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=SelfAttention
+ )
+
+ # for rmsnorm and others, we need to check the shape
+ return policy
+
+ def postprocess(self):
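+        # init_to_get_rotary precomputes the rotary cos/sin caches consumed by the
+        # inference forwards (assumption: same role as in the llama inference policy).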
+ init_to_get_rotary(self.model)
+ return self.model
+
+
+class ChatGLM2ForConditionalGenerationInferPolicy(ChatGLM2InferPolicy):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = super().module_policy()
+ model_infer_forward = ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward
+ method_replacement = {"forward": partial(model_infer_forward)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=ChatGLMForConditionalGeneration
+ )
+ return policy
+
+ def postprocess(self):
+ return super().postprocess()
diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..507c1203dd6bd108e522560738ae604cd66060e3
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/policies/llama.py
@@ -0,0 +1,124 @@
+from functools import partial
+
+import torch
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
+
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
+
+from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
+
+from ..modeling._utils import init_to_get_rotary
+from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward
+
+try:
+ from colossalai.kernel.triton import rmsnorm_forward
+
+ HAS_TRITON_RMSNORM = True
+except ImportError:
+    print("You should install triton from https://github.com/openai/triton")
+ HAS_TRITON_RMSNORM = False
+
+
+def get_triton_rmsnorm_forward():
+ if HAS_TRITON_RMSNORM:
+
+ def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
+ return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
+
+ return _triton_rmsnorm_forward
+ else:
+ return None
+
+
+class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = super().module_policy()
+
+ if self.shard_config.inference_gptq:
+ from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
+
+ decoder_attribute_replacement = {
+ "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ }
+ policy[LlamaDecoderLayer] = ModulePolicyDescription(
+ attribute_replacement=decoder_attribute_replacement,
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="self_attn.q_proj",
+ target_module=ColCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.k_proj",
+ target_module=ColCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.v_proj",
+ target_module=ColCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.o_proj",
+ target_module=RowCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.gate_proj",
+ target_module=ColCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.up_proj",
+ target_module=ColCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.down_proj",
+ target_module=RowCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ ],
+ )
+
+ self.shard_config._infer()
+
+ infer_forward = LlamaInferenceForwards.llama_model_forward
+ method_replacement = {"forward": partial(infer_forward)}
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
+
+ infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward
+ method_replacement = {"forward": partial(infer_forward)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
+ )
+
+ infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward
+ method_replacement = {"forward": partial(infer_forward)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=LlamaAttention
+ )
+
+ infer_forward = None
+ if HAS_TRITON_RMSNORM:
+ infer_forward = get_triton_rmsnorm_forward()
+ else:
+            # NOTE: adding rms_norm from the cuda kernels caused a precision issue; to be fixed (cc @tiandiao123)
+ infer_forward = get_llama_vllm_rmsnorm_forward()
+
+ if infer_forward is not None:
+ method_replacement = {"forward": partial(infer_forward)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=LlamaRMSNorm
+ )
+
+ return policy
+
+ def postprocess(self):
+ init_to_get_rotary(self.model.model)
+ return self.model
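+
+
+# Note: the policy wraps LlamaForCausalLM, whose base transformer lives at
+# self.model.model, hence init_to_get_rotary(self.model.model) in postprocess.
+# Usage sketch (assumed API, following colossalai.shardformer conventions):
+#
+#   shardformer = ShardFormer(shard_config=ShardConfig())
+#   model, _ = shardformer.optimize(llama_model, LlamaModelInferPolicy())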
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index 5d3f3e5530cbb71a6e88cd3726615c1b6b61a164..aac57d34a2c1f1eebeb95802937672c95819a963 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -1,69 +1,30 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
-import argparse
import os
-import pprint
+import warnings
from pathlib import Path
-from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Dict, Union
import torch
-import torch.nn as nn
-from torch.nn.modules.loss import _Loss
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim.lr_scheduler import _LRScheduler
-from torch.optim.optimizer import Optimizer
-from torch.utils.data import DataLoader
+import torch.distributed as dist
-from colossalai.amp import AMP_TYPE, convert_to_amp
-from colossalai.amp.naive_amp import NaiveAMPModel
-from colossalai.builder.builder import build_gradient_handler
-from colossalai.context import Config, ConfigException, ParallelMode
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
-from colossalai.engine.gradient_accumulation import accumulate_gradient
-from colossalai.engine.schedule import (
- InterleavedPipelineSchedule,
- NonPipelineSchedule,
- PipelineSchedule,
- get_tensor_shape,
-)
+from colossalai.context import Config
from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
-from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param
-from colossalai.utils.moe import sync_moe_model_param
-from colossalai.zero.legacy import ShardedOptimizerV2, convert_to_zero_v2
-from colossalai.zero.legacy.gemini.ophooks import BaseOpHook
-
-
-def get_default_parser():
- """Reads user command line and uses an argument parser to parse the input arguments.
- Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.
-
- Returns:
- Namespace: Returns the parser with the default arguments, the user may add customized arguments into this parser.
- """
- parser = argparse.ArgumentParser()
- parser.add_argument('--config', type=str, help='path to the config file')
- parser.add_argument('--host', type=str, help='the master address for distributed training')
- parser.add_argument('--port', type=int, help='the master port for distributed training')
- parser.add_argument('--world_size', type=int, help='world size for distributed training')
- parser.add_argument('--rank', type=int, help='rank for the default process group')
- parser.add_argument('--local_rank', type=int, help='local rank on the node')
- parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication')
- return parser
-
-
-def launch(config: Union[str, Path, Config, Dict],
- rank: int,
- world_size: int,
- host: str,
- port: int,
- backend: str = 'nccl',
- local_rank: int = None,
- seed: int = 1024,
- verbose: bool = True):
+from colossalai.utils import set_device, set_seed
+
+
+def launch(
+ config: Union[str, Path, Config, Dict],
+ rank: int,
+ world_size: int,
+ host: str,
+ port: int,
+ backend: str = "nccl",
+ local_rank: int = None,
+ seed: int = 1024,
+ verbose: bool = True,
+):
"""This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input
arguments are not given. Then initialize and set distributed environment by calling global_context's functions.
@@ -83,48 +44,33 @@ def launch(config: Union[str, Path, Config, Dict],
Raises:
Exception: Raise exception when config type is wrong
"""
- gpc.verbose = verbose
-
- # set config
- assert isinstance(config, (Config, str, Path, dict)), \
- f'expected argument config to be Config, str or Path, but got {type(config)}'
- if not isinstance(config, Config) and isinstance(config, dict):
- config = Config(config)
- if isinstance(config, (str, Path)):
- config = Config.from_file(config)
- gpc.load_config(config)
+ if rank == 0:
+ warnings.warn("`config` is deprecated and will be removed soon.")
# init default process group
- gpc.init_global_dist(rank, world_size, backend, host, port)
-
- # init process groups for different parallel modes from config
- gpc.init_parallel_groups()
+ init_method = f"tcp://[{host}]:{port}"
+ dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
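+    # e.g. host="127.0.0.1", port=29500 yields "tcp://[127.0.0.1]:29500"; the brackets
+    # keep IPv6 addresses parseable inside the init_method URL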
# set cuda device
if torch.cuda.is_available():
# if local rank is not given, calculate automatically
- gpc.set_device(local_rank)
+ set_device(local_rank)
- # set the number of processes running on the same node
- gpc.detect_num_processes_on_current_node()
-
- gpc.set_seed(seed)
+ set_seed(seed)
if verbose:
logger = get_dist_logger()
- logger.info(
- f'Distributed environment is initialized, '
- f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
- f'tensor parallel size: {gpc.tensor_parallel_size}',
- ranks=[0])
+ logger.info(f"Distributed environment is initialized, world size: {dist.get_world_size()}", ranks=[0])
-def launch_from_slurm(config: Union[str, Path, Config, Dict],
- host: str,
- port: int,
- backend: str = 'nccl',
- seed: int = 1024,
- verbose: bool = True):
+def launch_from_slurm(
+ config: Union[str, Path, Config, Dict],
+ host: str,
+ port: int,
+ backend: str = "nccl",
+ seed: int = 1024,
+ verbose: bool = True,
+):
"""A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables
set by SLURM
@@ -137,29 +83,33 @@ def launch_from_slurm(config: Union[str, Path, Config, Dict],
verbose (bool, optional): Whether to print logs. Defaults to True.
"""
try:
- rank = int(os.environ['SLURM_PROCID'])
- world_size = int(os.environ['SLURM_NPROCS'])
+ rank = int(os.environ["SLURM_PROCID"])
+ world_size = int(os.environ["SLURM_NPROCS"])
except KeyError as e:
raise RuntimeError(
f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM"
)
- launch(config=config,
- rank=rank,
- world_size=world_size,
- host=host,
- port=port,
- backend=backend,
- seed=seed,
- verbose=verbose)
-
-
-def launch_from_openmpi(config: Union[str, Path, Config, Dict],
- host: str,
- port: int,
- backend: str = 'nccl',
- seed: int = 1024,
- verbose: bool = True):
+ launch(
+ config=config,
+ rank=rank,
+ world_size=world_size,
+ host=host,
+ port=port,
+ backend=backend,
+ seed=seed,
+ verbose=verbose,
+ )
+
+
+def launch_from_openmpi(
+ config: Union[str, Path, Config, Dict],
+ host: str,
+ port: int,
+ backend: str = "nccl",
+ seed: int = 1024,
+ verbose: bool = True,
+):
"""A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables
set by OpenMPI
@@ -172,29 +122,30 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict],
verbose (bool, optional): Whether to print logs. Defaults to True.
"""
try:
- rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
- local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
- world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+ rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
+ local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
+ world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
except KeyError as e:
raise RuntimeError(
f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI"
)
- launch(config=config,
- local_rank=local_rank,
- rank=rank,
- world_size=world_size,
- host=host,
- port=port,
- backend=backend,
- seed=seed,
- verbose=verbose)
-
-
-def launch_from_torch(config: Union[str, Path, Config, Dict],
- backend: str = 'nccl',
- seed: int = 1024,
- verbose: bool = True):
+ launch(
+ config=config,
+ local_rank=local_rank,
+ rank=rank,
+ world_size=world_size,
+ host=host,
+ port=port,
+ backend=backend,
+ seed=seed,
+ verbose=verbose,
+ )
+
+
+def launch_from_torch(
+ config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024, verbose: bool = True
+):
"""A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
from the environment variables set by PyTorch
@@ -205,266 +156,24 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
verbose (bool, optional): Whether to print logs. Defaults to True.
"""
try:
- rank = int(os.environ['RANK'])
- local_rank = int(os.environ['LOCAL_RANK'])
- world_size = int(os.environ['WORLD_SIZE'])
- host = os.environ['MASTER_ADDR']
- port = int(os.environ['MASTER_PORT'])
+ rank = int(os.environ["RANK"])
+ local_rank = int(os.environ["LOCAL_RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+ host = os.environ["MASTER_ADDR"]
+ port = int(os.environ["MASTER_PORT"])
except KeyError as e:
raise RuntimeError(
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
)
- launch(config=config,
- local_rank=local_rank,
- rank=rank,
- world_size=world_size,
- host=host,
- port=port,
- backend=backend,
- seed=seed,
- verbose=verbose)
-
-
-def initialize(model: nn.Module,
- optimizer: Optimizer,
- criterion: Optional[_Loss] = None,
- train_dataloader: Optional[Iterable] = None,
- test_dataloader: Optional[Iterable] = None,
- lr_scheduler: Optional[_LRScheduler] = None,
- ophooks: Optional[List[BaseOpHook]] = None,
- verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
- """Core function to wrap the essential training components with our functionality based on the config which is
- loaded into gpc.config.
-
- Args:
- model (:class:`torch.nn.Module` or Callbale): Your model instance or a function to build the model.
- optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`):
- Your optimizer instance.
- criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
- train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
- test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
- lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional.
- verbose (bool, optional): Whether to print logs.
-
- Returns:
- Tuple (engine, train_dataloader, test_dataloader, lr_scheduler):
- A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)``
- where only ``engine`` could not be None.
- """
- # get logger
- logger = get_dist_logger()
- gpc.verbose = verbose
-
- # get config from gpc
- config = gpc.config
-
- # print config
- if verbose:
- logger.info(
- f"\n========== Your Config ========\n"
- f"{pprint.pformat(gpc.config)}\n"
- f"================================\n",
- ranks=[0])
-
- # cudnn
- cudnn_benchmark = config.get('cudnn_benchmark', False)
- cudnn_deterministic = config.get('cudnn_deterministic', False)
- torch.backends.cudnn.benchmark = cudnn_benchmark
- torch.backends.cudnn.deterministic = cudnn_deterministic
- if verbose:
- logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])
-
- # zero
- use_zero = hasattr(gpc.config, 'zero')
- if use_zero:
- zero_cfg = gpc.config.get('zero', None)
- if zero_cfg is not None:
- cfg_ = zero_cfg.copy()
- else:
- cfg_ = {}
- optimizer_config = zero_cfg.get('optimizer_config', None)
- model_config = zero_cfg.get('model_config', None)
- model, optimizer = convert_to_zero_v2(model,
- optimizer,
- model_config=model_config,
- optimizer_config=optimizer_config)
-
- logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
- else:
- if isinstance(model, nn.Module):
- # first sync model across dp ranks
- model.to(get_current_device())
- elif isinstance(model, Callable):
- model = model().to(get_current_device())
-
- # optimizer maybe a optimizer_cls
- if isinstance(optimizer, Callable):
- optimizer = optimizer(model.parameters())
- logger.warning("Initializing an non ZeRO model with optimizer class")
-
- if not use_zero:
- if is_using_sequence():
- sync_model_param(model, ParallelMode.SEQUENCE_DP)
- elif MOE_CONTEXT.is_initialized:
- sync_moe_model_param(model)
- elif is_using_ddp():
- sync_model_param(model, ParallelMode.DATA)
- else:
- logger.warning(
- "The parameters of models is not automatically synchronized.\n"
- "Please make sure that all parameters are the same in data parallel group.",
- ranks=[0])
-
- # check amp and zero
- fp16_cfg = gpc.config.get('fp16', None)
-
- if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
- raise ConfigException(
- "It is not allowed to set fp16 and zero configuration in your config file at the same time")
-
- # clip grad norm
- clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
-
- # initialize amp
- amp_mode = None
- if fp16_cfg is not None and fp16_cfg.mode is not None:
- cfg_ = fp16_cfg.copy()
- amp_mode = cfg_.pop('mode')
- if is_using_pp():
- assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently'
- if amp_mode == AMP_TYPE.NAIVE:
- cfg_['clip_grad_norm'] = clip_grad_norm
- model, optimizer, criterion = convert_to_amp(model=model,
- optimizer=optimizer,
- criterion=criterion,
- mode=amp_mode,
- amp_config=cfg_)
-
- # get torch ddp config
- torch_ddp_cfg = gpc.config.get('torch_ddp', dict())
-
- # gradient handler
- gradient_handler_cfg = gpc.config.get('gradient_handler', None)
- if gradient_handler_cfg is None:
- # if gradient handler is not specified in the configuration file,
- # check in the following order
- # 1. if optimizer is ZERO, then use zero grad handler
- # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
- # 3. if using pipeline and dp size larger than 1, use data parallel grad handler
- if isinstance(optimizer, ShardedOptimizerV2):
- gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
- if verbose:
- logger.info(
- "Training with zero is detected, ZeROGradientHandler is automatically "
- "added even though not specified in the configuration",
- ranks=[0])
- elif is_using_ddp() and MOE_CONTEXT.is_initialized:
- gradient_handler_cfg = [dict(type='MoeGradientHandler')]
- if verbose:
- logger.info(
- "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
- "added even though not specified in the configuration",
- ranks=[0])
- elif is_using_sequence():
- model = DDP(model,
- process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
- device_ids=[torch.cuda.current_device()],
- **torch_ddp_cfg)
- if verbose:
- logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism',
- ranks=[0])
- elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
- model = DDP(model,
- process_group=gpc.get_group(ParallelMode.DATA),
- device_ids=[torch.cuda.current_device()],
- **torch_ddp_cfg)
- if verbose:
- logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
- elif is_using_ddp():
- gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
- if verbose:
- logger.info(
- "Data parallel training is detected when using pipeline parallel, "
- "DataParallelGradientHandler is automatically "
- "added even though not specified in the configuration",
- ranks=[0])
- # add pipeline parallel gradient handler, if pipeline shared module is detected
- for param in model.parameters():
- if getattr(param, 'pipeline_shared_module_pg', None) is not None:
- if gradient_handler_cfg is None:
- gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')]
- else:
- gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler'))
- if verbose:
- logger.info(
- "pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically "
- "added even though not specified in the configuration",
- ranks=[0])
- break
- else:
- if not isinstance(gradient_handler_cfg, list):
- raise ConfigException(
- f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}"
- )
-
- # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time
- # to avoid duplicated buffer synchronization
- if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
- model.module.sync_buffer = False
-
- # initialize schedule for engine
- if is_using_pp():
- tensor_shape = get_tensor_shape()
- use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks')
- if gpc.is_initialized(ParallelMode.PARALLEL_1D):
- scatter_gather = True
- else:
- scatter_gather = False
- if use_interleaved:
- if isinstance(model, nn.Sequential):
- model = nn.ModuleList([model])
- schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
- gpc.config.model.num_chunks,
- tensor_shape=tensor_shape,
- scatter_gather_tensors=scatter_gather)
- else:
- schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
- tensor_shape=tensor_shape,
- scatter_gather_tensors=scatter_gather)
- else:
- schedule = NonPipelineSchedule()
-
- if gradient_handler_cfg is None:
- gradient_handlers = None
- if verbose and not isinstance(model, DDP):
- logger.warning(
- "No PyTorch DDP or gradient handler is set up, please make sure you do not need "
- "to all-reduce the gradients after a training step.",
- ranks=[0])
- else:
- gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]
-
- # check if optimizer is ColossalaiOptimizer
- if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)):
- optimizer = ColossalaiOptimizer(optim=optimizer)
-
- # gradient accumulation
- grad_accum_size = gpc.config.get('gradient_accumulation', None)
- if grad_accum_size is not None:
- optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient(
- model=model,
- optimizer=optimizer,
- dataloader=train_dataloader,
- accumulate_size=grad_accum_size,
- gradient_handlers=gradient_handlers,
- lr_scheduler=lr_scheduler)
- engine = Engine(model=model,
- optimizer=optimizer,
- criterion=criterion,
- gradient_handlers=gradient_handlers,
- clip_grad_norm=clip_grad_norm,
- ophook_list=ophooks,
- schedule=schedule)
-
- return engine, train_dataloader, test_dataloader, lr_scheduler
+ launch(
+ config=config,
+ local_rank=local_rank,
+ rank=rank,
+ world_size=world_size,
+ host=host,
+ port=port,
+ backend=backend,
+ seed=seed,
+ verbose=verbose,
+ )
diff --git a/colossalai/interface/__init__.py b/colossalai/interface/__init__.py
index 8c658e375146ae4f6f31f62a9e913ed16fbb0714..98b21c9c02c170dfce5076acf9ea0f99e03cb6c1 100644
--- a/colossalai/interface/__init__.py
+++ b/colossalai/interface/__init__.py
@@ -1,4 +1,4 @@
-from .model import ModelWrapper
+from .model import AMPModelMixin, ModelWrapper
from .optimizer import OptimizerWrapper
-__all__ = ['OptimizerWrapper', 'ModelWrapper']
+__all__ = ["OptimizerWrapper", "ModelWrapper", "AMPModelMixin"]
diff --git a/colossalai/interface/model.py b/colossalai/interface/model.py
index a067d7671ce7eaaa174aced664ec16461cd03034..58df09b853eeaa1bcba79c1116a820822785d084 100644
--- a/colossalai/interface/model.py
+++ b/colossalai/interface/model.py
@@ -23,3 +23,12 @@ class ModelWrapper(nn.Module):
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
+
+
+class AMPModelMixin:
+ """This mixin class defines the interface for AMP training."""
+
+ def update_master_params(self):
+ """
+ Update the master parameters for AMP training.
+ """
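+        # The default is a no-op; AMP model wrappers are expected to override this to
+        # refresh the fp32 master copy after the working (fp16/bf16) weights change.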
diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py
index dd9acab17584a452ad430094ab7fed4b9e272efb..95d11087bececd9b5bce5f50edb7059a22e13b7b 100644
--- a/colossalai/interface/optimizer.py
+++ b/colossalai/interface/optimizer.py
@@ -1,5 +1,6 @@
from typing import Union
+import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
@@ -21,7 +22,7 @@ class OptimizerWrapper:
params = []
for group in self.param_groups:
- params += group['params']
+ params += group["params"]
return params
@property
@@ -53,6 +54,9 @@ class OptimizerWrapper:
"""
loss.backward(*args, **kwargs)
+ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
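+        # Continue backprop from an externally supplied gradient (e.g. one received
+        # from a later pipeline stage) instead of from a scalar loss.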
+ torch.autograd.backward(tensor, grad)
+
def state_dict(self):
"""
Returns the optimizer state.
@@ -78,12 +82,14 @@ class OptimizerWrapper:
"""
nn.utils.clip_grad_value_(self.parameters, clip_value, *args, **kwargs)
- def clip_grad_by_norm(self,
- max_norm: Union[float, int],
- norm_type: Union[float, int] = 2.0,
- error_if_nonfinite: bool = False,
- *args,
- **kwargs) -> Tensor:
+ def clip_grad_by_norm(
+ self,
+ max_norm: Union[float, int],
+ norm_type: Union[float, int] = 2.0,
+ error_if_nonfinite: bool = False,
+ *args,
+ **kwargs,
+ ) -> Tensor:
"""
Clips gradient norm of an iterable of parameters.
@@ -109,7 +115,8 @@ class OptimizerWrapper:
loss (Tensor): The loss to be scaled.
"""
raise NotImplementedError(
- "The method scale_loss is only available for optimizers with mixed precision training")
+ "The method scale_loss is only available for optimizers with mixed precision training"
+ )
def unscale_grad(self):
"""
@@ -118,4 +125,11 @@ class OptimizerWrapper:
Note: Only available for optimizers with mixed precision training.
"""
raise NotImplementedError(
- "The method unscale_grad is only available for optimizers with mixed precision training")
+ "The method unscale_grad is only available for optimizers with mixed precision training"
+ )
+
+ def unwrap(self):
+ """
+ Unwrap the optimizer for checkpoint saving/loading.
+ """
+ return self.optim
diff --git a/colossalai/interface/pretrained.py b/colossalai/interface/pretrained.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f6bc10cd1323b38ae86f304645a41026333f6f3
--- /dev/null
+++ b/colossalai/interface/pretrained.py
@@ -0,0 +1,16 @@
+from typing import Optional
+
+from torch.nn import Module
+
+__all__ = [
+ "get_pretrained_path",
+ "set_pretrained_path",
+]
+
+
+def get_pretrained_path(model: Module) -> Optional[str]:
+ return getattr(model, "_pretrained", None)
+
+
+def set_pretrained_path(model: Module, path: str) -> None:
+ setattr(model, "_pretrained", path)
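+
+
+# Usage sketch:
+#   set_pretrained_path(model, "/path/to/checkpoint")  # hypothetical path
+#   get_pretrained_path(model)  # -> "/path/to/checkpoint"; None if never set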
diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py
index 1d5a6ce495bdec0fe02475ed0bb3b3b67bb86b3c..f8a974b5fb26dbb55d1952d1339c18bff52c3cb5 100644
--- a/colossalai/kernel/cuda_native/__init__.py
+++ b/colossalai/kernel/cuda_native/__init__.py
@@ -1,5 +1,13 @@
from .layer_norm import MixedFusedLayerNorm as LayerNorm
+from .mha.mha import ColoAttention
from .multihead_attention import MultiHeadAttention
-from .scaled_softmax import FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
+from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
-__all__ = ['LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax']
+__all__ = [
+ "LayerNorm",
+ "MultiHeadAttention",
+ "FusedScaleMaskSoftmax",
+ "ScaledUpperTriangMaskedSoftmax",
+ "ColoAttention",
+ "AttnMaskType",
+]
diff --git a/colossalai/kernel/cuda_native/csrc/compat.h b/colossalai/kernel/cuda_native/csrc/compat.h
index 00066dc95475296168c799904dc595ed435d2b0a..a62beef91a8a55bb75a17ae0194381048340f688 100644
--- a/colossalai/kernel/cuda_native/csrc/compat.h
+++ b/colossalai/kernel/cuda_native/csrc/compat.h
@@ -7,4 +7,4 @@
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
-#endif
\ No newline at end of file
+#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu
new file mode 100644
index 0000000000000000000000000000000000000000..2b1b366b1c02203805990d468feddeaa24a703ef
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu
@@ -0,0 +1,63 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#include "column_remap.cuh"
+#include "util.cuh"
+
+const int SHUF_BLOCKSIZE_X = 256;
+const int SHUF_BLOCKSIZE_Y = 16;
+
+__global__ void column_remap_kernel
+(
+ const half* __restrict__ x,
+ half* __restrict__ x_new,
+ const int x_width,
+ const int x_height,
+ const uint32_t* x_map
+)
+{
+ int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
+ int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
+ if (x_column >= x_width) return;
+ //if (x_row >= x_height) return;
+
+ int x_stride = x_width;
+ int x_idx = x_row * x_stride + x_column;
+
+ int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
+ int x_idx_end = x_row_end * x_stride + x_column;
+
+ int s_column = x_map[x_column];
+ int s_idx = x_row * x_stride + s_column;
+
+ while (x_idx < x_idx_end)
+ {
+ x_new[x_idx] = x[s_idx];
+ x_idx += x_stride;
+ s_idx += x_stride;
+ }
+}
+
+// Remap columns in x to correspond to sequential group index before matmul
+//
+// perform x -> seq_x such that seq_x @ seq_w == x @ w
+
+void column_remap_cuda
+(
+ const half* x,
+ half* x_new,
+ const int x_height,
+ const int x_width,
+ const uint32_t* x_map
+)
+{
+ dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);
+
+ dim3 blocks
+ (
+ (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
+ (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
+ 1
+ );
+
+    column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
+}
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..0364e38c4779717d560a04822c777f9d572dd2a7
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh
@@ -0,0 +1,19 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _column_remap_cuh
+#define _column_remap_cuh
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+
+void column_remap_cuda
+(
+ const half* x,
+ half* x_new,
+ const int x_height,
+ const int x_width,
+ const uint32_t* x_map
+);
+
+#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..c5258813e147554e033eaf9a80c27dd694a50961
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh
@@ -0,0 +1,58 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _cuda_compat_cuh
+#define _cuda_compat_cuh
+
+// atomicAdd for half types, to support CC < 7.x
+
+__device__ __forceinline__ void atomicAdd_half(half* address, half val)
+{
+ unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
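+    // ((size_t)address & 2) selects which 16-bit half of the 4-byte-aligned word holds
+    // the target value; the CAS loop below rewrites only that half of the word.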
+ unsigned int old = *address_as_ui;
+ unsigned int assumed;
+
+ do
+ {
+ assumed = old;
+ __half_raw hsum;
+ hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
+ half tmpres = __hadd(hsum, val);
+ hsum = __half_raw(tmpres);
+ old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
+ old = atomicCAS(address_as_ui, assumed, old);
+ }
+ while (assumed != old);
+}
+
+// atomicAdd for half2 types
+
+__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
+{
+ unsigned int* address_as_ui = (unsigned int*)address;
+ unsigned int old = *address_as_ui;
+ unsigned int assumed;
+ do
+ {
+ assumed = old;
+ half2 old_val = *((half2*)&old);
+ half2 new_val = __hadd2(old_val, val);
+ old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
+ }
+ while (assumed != old);
+}
+
+//
+
+#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
+#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
+
+__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
+
+#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
+__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
+#endif
+
+#endif
+#endif
+
+#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu
new file mode 100644
index 0000000000000000000000000000000000000000..4416027c8387a17726f061519f49cd181843ac1d
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu
@@ -0,0 +1,75 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#define _cuda_buffers_cu
+#include "cuda_buffers.cuh"
+
+CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
+// __constant__ half2 q4_table[16][256];
+// half2 q4_table_host[16][256];
+// bool q4_table_init = false;
+
+CudaBuffers::CudaBuffers
+(
+ int _device,
+ int _temp_state_size,
+ half* _temp_state,
+ half* _temp_dq
+) :
+ device(_device),
+ temp_state_size(_temp_state_size),
+ temp_state(_temp_state),
+ temp_dq(_temp_dq)
+{
+ cudaSetDevice(_device);
+
+ cudaStreamCreate(&alt_stream_1);
+ cudaStreamCreate(&alt_stream_2);
+ cudaStreamCreate(&alt_stream_3);
+ cudaEventCreate(&alt_stream_1_done);
+ cudaEventCreate(&alt_stream_2_done);
+ cudaEventCreate(&alt_stream_3_done);
+}
+
+CudaBuffers::~CudaBuffers()
+{
+ cudaStreamDestroy(alt_stream_1);
+ cudaStreamDestroy(alt_stream_2);
+ cudaStreamDestroy(alt_stream_3);
+ cudaEventDestroy(alt_stream_1_done);
+ cudaEventDestroy(alt_stream_2_done);
+ cudaEventDestroy(alt_stream_3_done);
+}
+
+CudaBuffers* get_buffers(const int device_index)
+{
+ return g_buffers[device_index];
+}
+
+void prepare_buffers_cuda
+(
+ int _device,
+ int _temp_state_size,
+ half* _temp_state,
+ half* _temp_dq
+)
+{
+ CudaBuffers* buffers = new CudaBuffers
+ (
+ _device,
+ _temp_state_size,
+ _temp_state,
+ _temp_dq
+ );
+
+ g_buffers[_device] = buffers;
+}
+
+void cleanup_buffers_cuda()
+{
+ for (int i = 0; i < CUDA_MAX_DEVICES; i++)
+ {
+ if (!g_buffers[i]) continue;
+ delete g_buffers[i];
+ g_buffers[i] = NULL;
+ }
+}
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..0bf2057c665cbfad46461d437ef5fe44b02f8f2e
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh
@@ -0,0 +1,55 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _cuda_buffers_cuh
+#define _cuda_buffers_cuh
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+#include <cstdio>
+
+const int CUDA_MAX_DEVICES = 16;
+
+// #ifndef _cuda_buffers_cu
+// extern __constant__ half2 q4_table[16][256];
+// #endif
+
+class CudaBuffers
+{
+public:
+ int device;
+
+ half* temp_state; // [max_hidden_rows * intermediate_size]
+ int temp_state_size;
+ half* temp_dq; // size of largest quant tensor * 8
+
+ cudaStream_t alt_stream_1;
+ cudaStream_t alt_stream_2;
+ cudaStream_t alt_stream_3;
+ cudaEvent_t alt_stream_1_done;
+ cudaEvent_t alt_stream_2_done;
+ cudaEvent_t alt_stream_3_done;
+
+ CudaBuffers
+ (
+ int _device,
+ int _temp_state_size,
+ half* _temp_state,
+ half* _temp_dq
+ );
+ ~CudaBuffers();
+};
+
+CudaBuffers* get_buffers(const int device_index);
+
+void prepare_buffers_cuda
+(
+ int _device,
+ int _temp_state_size,
+ half* _temp_state,
+ half* _temp_dq
+);
+
+void cleanup_buffers_cuda();
+
+#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..5cd2e8553ef6bee59162956dfed4a9a26227a4ce
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh
@@ -0,0 +1,49 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _hip_compat_cuh
+#define _hip_compat_cuh
+
+// Workaround for a bug in hipamd, backported from upstream.
+__device__ __forceinline__ __half __compat_hrcp(__half x) {
+ return __half_raw{
+ static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
+}
+
+__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
+ return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
+ static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
+}
+
+#define hrcp __compat_hrcp
+#define h2rcp __compat_h2rcp
+
+// Workaround for hipify_python using rocblas instead of hipblas.
+__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
+ hipblasOperation_t transA,
+ hipblasOperation_t transB,
+ int m,
+ int n,
+ int k,
+ const half* alpha,
+ const half* AP,
+ int lda,
+ const half* BP,
+ int ldb,
+ const half* beta,
+ half* CP,
+ int ldc) {
+    return hipblasHgemm(handle, transA, transB, m, n, k,
+                        reinterpret_cast<const hipblasHalf*>(alpha),
+                        reinterpret_cast<const hipblasHalf*>(AP), lda,
+                        reinterpret_cast<const hipblasHalf*>(BP), ldb,
+                        reinterpret_cast<const hipblasHalf*>(beta),
+                        reinterpret_cast<hipblasHalf*>(CP), ldc);
+}
+
+#define rocblas_handle hipblasHandle_t
+#define rocblas_operation_none HIPBLAS_OP_N
+#define rocblas_get_stream hipblasGetStream
+#define rocblas_set_stream hipblasSetStream
+#define rocblas_hgemm __compat_hipblasHgemm
+
+#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..bcc0e43901de85c3e361c01765ffab5b15de4da3
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp
@@ -0,0 +1,254 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+#include <cstdio>
+#include "util.cuh"
+#include "tuning.h"
+#include "cuda_buffers.cuh"
+#include "q4_matrix.cuh"
+#include "q4_matmul.cuh"
+#include "column_remap.cuh"
+
+// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
+// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
+// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.
+
+void check_cuda(cudaError_t ret)
+{
+ switch (ret)
+ {
+ case cudaSuccess:
+ break;
+
+ case cudaUnspecified:
+ printf(" **** Unspecified error\n");
+ TORCH_CHECK(false, "CUDA error");
+ break;
+
+ default:
+            printf(" **** CUDA error\n");
+            printf(" **** %s\n", cudaGetErrorString(ret));
+            TORCH_CHECK(false, "CUDA error");
+ break;
+ }
+}
+
+// Some decluttering macros
+
+#define STRINGIFY_(__x) #__x
+#define STRINGIFY(__x) STRINGIFY_(__x)
+#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
+#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
+#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
+#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
+#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
+#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
+
+#define TORCH_CHECK_DEVICE_INDEX(__index) \
+do { \
+ TORCH_CHECK(__index >= 0, "no device index"); \
+ TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
+} while(0)
+
+#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
+do { \
+ TORCH_CHECK_DTYPE(__w, kInt); \
+ TORCH_CHECK_DTYPE(__w_scales, kHalf); \
+ TORCH_CHECK_DTYPE(__w_zeros, kInt); \
+ TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
+ TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
+ TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
+ TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
+} while(0)
+
+int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
+{
+ int groupsize = w.size(0) * 8 / w_zeros.size(0);
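+    // qweight packs eight 4-bit rows per int32, so w.size(0) * 8 is the unpacked height;
+    // there is one row of zeros/scales per quantization group.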
+    TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]");
+ return groupsize;
+}
+
+
+// Tuning parameters
+
+ExLlamaTuning tuningParams;
+
+void set_tuning_params
+(
+ int matmul_recons_thd,
+ bool matmul_fused_remap,
+ bool matmul_no_half2
+)
+{
+ tuningParams.matmul_recons_thd = matmul_recons_thd;
+ tuningParams.matmul_fused_remap = matmul_fused_remap;
+ tuningParams.matmul_no_half2 = matmul_no_half2;
+}
+
+
+// Release all unmanaged objects allocated by the extension
+
+void cleanup()
+{
+ cleanup_buffers_cuda();
+ g_q4_free_matrices();
+}
+
+
+// Prepare buffers for forward pass
+
+void prepare_buffers
+(
+ torch::Device device,
+ torch::Tensor temp_state,
+ torch::Tensor temp_dq
+)
+{
+ int device_index = device.index();
+ TORCH_CHECK_DEVICE_INDEX(device_index);
+ const at::cuda::OptionalCUDAGuard device_guard(device);
+
+ prepare_buffers_cuda
+ (
+ device_index,
+ // buffer size used for sanity checks
+ temp_state.numel(),
+ (half*) temp_state.data_ptr(),
+ (half*) temp_dq.data_ptr()
+ );
+}
+
+
+// Create Q4Matrix, return handle
+
+uintptr_t make_q4
+(
+ torch::Tensor qweight,
+ torch::Tensor qzeros,
+ torch::Tensor scales,
+ torch::Tensor g_idx,
+ int device
+)
+{
+ TORCH_CHECK_DTYPE(qweight, kInt);
+ TORCH_CHECK_DTYPE(qzeros, kInt);
+ TORCH_CHECK_DTYPE(scales, kHalf);
+ TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
+ TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
+ TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
+ TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
+
+ int width = qweight.size(1);
+ int height = qweight.size(0) * 8;
+ int groups = qzeros.size(0);
+
+ Q4Matrix* m = new Q4Matrix
+ (
+ height,
+ width,
+ groups,
+
+ (uint32_t*) qweight.data_ptr(),
+ (uint32_t*) qzeros.data_ptr(),
+ (half*) scales.data_ptr(),
+ g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
+
+ device
+ );
+
+ g_q4_keep_matrix(m);
+ return reinterpret_cast<uintptr_t> (m);
+}
+
+
+// Matmul half @ quant -> half
+
+void q4_matmul
+(
+ torch::Tensor x,
+ uintptr_t w,
+ torch::Tensor out
+)
+{
+ Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
+
+ TORCH_CHECK_DTYPE(x, kHalf);
+ TORCH_CHECK_DTYPE(out, kHalf);
+ TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
+ TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
+
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+ int x_height = x.size(0);
+
+ if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
+ {
+ q4_matmul_cuda
+ (
+ &tuningParams,
+ (half*) x.data_ptr(),
+ x_height,
+ wm,
+ (half*) out.data_ptr()
+ );
+ }
+ else
+ {
+ q4_matmul_recons_cuda
+ (
+ &tuningParams,
+ (half*) x.data_ptr(),
+ x_height,
+ wm,
+ (half*) out.data_ptr(),
+ at::cuda::getCurrentCUDABlasHandle()
+ );
+ }
+}
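+
+// Example (illustrative): with matmul_recons_thd = 8, a single-row decode
+// input takes the quantized q4_matmul_cuda path, while a batch of 8 or more
+// rows reconstructs the fp16 weight and dispatches to cuBLAS instead.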
+
+
+// Remap columns in half tensor
+
+void column_remap
+(
+ torch::Tensor x,
+ torch::Tensor x_new,
+ torch::Tensor x_map
+)
+{
+ TORCH_CHECK_DTYPE(x, kHalf);
+ TORCH_CHECK_DTYPE(x_new, kHalf);
+ TORCH_CHECK_DTYPE(x_map, kInt);
+ TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);
+
+ int height = x.size(0);
+ int width = x.size(1);
+
+ TORCH_CHECK_BUFFER_SIZE(x_new, height * width);
+
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+ column_remap_cuda
+ (
+ (half*) x.data_ptr(),
+ (half*) x_new.data_ptr(),
+ height,
+ width,
+ (uint32_t*) x_map.data_ptr()
+ );
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
+ m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
+ m.def("cleanup", &cleanup, "cleanup");
+ m.def("make_q4", &make_q4, "make_q4");
+ m.def("q4_matmul", &q4_matmul, "q4_matmul");
+}
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..2fd5ab0b36cd0dd67c9b6081740bbed12711ee09
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh
@@ -0,0 +1,294 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _matrix_cuh
+#define _matrix_cuh
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
+class MatrixView_half
+{
+public:
+ const half* data;
+ const int height;
+ const int width;
+
+ __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
+ : data(data), height(height), width(width)
+ { }
+
+ __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
+ __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
+ __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
+ __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
+};
+
+class MatrixView_half_rw
+{
+public:
+ half* data;
+ const int height;
+ const int width;
+
+ __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
+ : data(data), height(height), width(width)
+ { }
+
+ __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
+ __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
+ __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
+ __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
+ __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
+};
+
+class MatrixView_q4_row
+{
+public:
+ const uint32_t* data;
+ const int height;
+ const int width;
+
+ __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
+ : data(data), height(height), width(width)
+ { }
+
+ __device__ __forceinline__ int item(int row, int column) const
+ {
+ int shift = (column & 0x07) * 4;
+ return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
+ }
+};
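+
+// Example (illustrative): item(r, 10) loads data[r * width / 8 + 1] and
+// shifts right by (10 & 0x07) * 4 = 8 bits, extracting the third 4-bit
+// value packed into that word.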
+
+class MatrixView_q4_column
+{
+public:
+ const uint32_t* data;
+ const int height;
+ const int width;
+
+ __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
+ : data(data), height(height), width(width)
+ { }
+
+ __device__ __forceinline__ int item(int row, int column) const
+ {
+ int shift = (row & 0x07) * 4;
+ return (data[row / 8 * width + column] >> shift) & 0x0f;
+ }
+
+ __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
+ __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
+};
+
+// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
+
+// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale
+
+__device__ __forceinline__ half2 dot_product_8
+(
+ const half2 acc,
+ MatrixView_half& h_,
+ const int h_row,
+ const int h_column, // divisible by 8
+ MatrixView_q4_column& v_,
+ const int v_row, // divisible by 8
+ const int v_column,
+ const half2 v_scale_2,
+ const uint32_t v_zero, // + 1 (!!)
+ const int count
+)
+{
+ const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
+ const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
+ half2 result = acc;
+
+ for (int i = 0; i < count; i++)
+ {
+ uint32_t v_read = *v_ptr; v_ptr += v_.width;
+
+ half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
+ half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
+ half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
+ half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
+ half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
+ half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
+ half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
+ half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
+
+ half2 v_01 = __halves2half2(v_0, v_1);
+ half2 v_23 = __halves2half2(v_2, v_3);
+ half2 v_45 = __halves2half2(v_4, v_5);
+ half2 v_67 = __halves2half2(v_6, v_7);
+
+// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently)
+// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff];
+// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
+// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ];
+
+ half2 tmp = __hmul2(*h_ptr++, v_01);
+ tmp = __hfma2(*h_ptr++, v_23, tmp);
+ tmp = __hfma2(*h_ptr++, v_45, tmp);
+ tmp = __hfma2(*h_ptr++, v_67, tmp);
+ result = __hfma2(v_scale_2, tmp, result);
+ }
+
+ return result;
+}
+
+__device__ __forceinline__ half dot_product_8_h
+(
+ const half acc,
+ MatrixView_half& h_,
+ const int h_row,
+ const int h_column, // divisible by 8
+ MatrixView_q4_column& v_,
+ const int v_row, // divisible by 8
+ const int v_column,
+ const half v_scale,
+ const uint32_t v_zero, // + 1 (!!)
+ const int count
+)
+{
+ const half* h_ptr = h_.item_ptr(h_row, h_column);
+ const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
+ half result = acc;
+
+ for (int i = 0; i < count; i++)
+ {
+ uint32_t v_read = *v_ptr; v_ptr += v_.width;
+
+ half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
+ half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
+ half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
+ half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
+ half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
+ half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
+ half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
+ half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
+
+ half tmp = __hmul(*h_ptr++, v_0);
+ tmp = __hfma(*h_ptr++, v_1, tmp);
+ tmp = __hfma(*h_ptr++, v_2, tmp);
+ tmp = __hfma(*h_ptr++, v_3, tmp);
+ tmp = __hfma(*h_ptr++, v_4, tmp);
+ tmp = __hfma(*h_ptr++, v_5, tmp);
+ tmp = __hfma(*h_ptr++, v_6, tmp);
+ tmp = __hfma(*h_ptr++, v_7, tmp);
+ result = __hfma(v_scale, tmp, result);
+ }
+
+ return result;
+}
+
+// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
+
+__device__ __forceinline__ half2 dot_product_8_x_map
+(
+ const half2 acc,
+ MatrixView_half& h_,
+ const int h_row,
+ const int h_column, // divisible by 8
+ MatrixView_q4_column& v_,
+ const int v_row, // divisible by 8
+ const int v_column,
+ const half2 v_scale_2,
+ const uint32_t v_zero, // + 1 (!!)
+ const int count,
+ const uint32_t* x_map
+)
+{
+ const half* h_ptr = h_.item_ptr(h_row, 0);
+ const uint32_t* x_map_ptr = x_map + h_column;
+ const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
+ half2 result = acc;
+
+ for (int i = 0; i < count; i++)
+ {
+ uint32_t v_read = *v_ptr; v_ptr += v_.width;
+
+ half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
+ half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
+ half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
+ half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
+ half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
+ half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
+ half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
+ half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
+
+ half2 v_01 = __halves2half2(v_0, v_1);
+ half2 v_23 = __halves2half2(v_2, v_3);
+ half2 v_45 = __halves2half2(v_4, v_5);
+ half2 v_67 = __halves2half2(v_6, v_7);
+
+ half h_0 = h_ptr[*x_map_ptr++];
+ half h_1 = h_ptr[*x_map_ptr++];
+ half h_2 = h_ptr[*x_map_ptr++];
+ half h_3 = h_ptr[*x_map_ptr++];
+ half h_4 = h_ptr[*x_map_ptr++];
+ half h_5 = h_ptr[*x_map_ptr++];
+ half h_6 = h_ptr[*x_map_ptr++];
+ half h_7 = h_ptr[*x_map_ptr++];
+
+ half2 h_01 = __halves2half2(h_0, h_1);
+ half2 h_23 = __halves2half2(h_2, h_3);
+ half2 h_45 = __halves2half2(h_4, h_5);
+ half2 h_67 = __halves2half2(h_6, h_7);
+
+ half2 tmp = __hmul2(h_01, v_01);
+ tmp = __hfma2(h_23, v_23, tmp);
+ tmp = __hfma2(h_45, v_45, tmp);
+ tmp = __hfma2(h_67, v_67, tmp);
+ result = __hfma2(v_scale_2, tmp, result);
+ }
+
+ return result;
+}
+
+__device__ __forceinline__ half dot_product_8_x_map_h
+(
+ const half acc,
+ MatrixView_half& h_,
+ const int h_row,
+ const int h_column, // divisible by 8
+ MatrixView_q4_column& v_,
+ const int v_row, // divisible by 8
+ const int v_column,
+ const half v_scale,
+ const uint32_t v_zero, // + 1 (!!)
+ const int count,
+ const uint32_t* x_map
+)
+{
+ const half* h_ptr = h_.item_ptr(h_row, 0);
+ const uint32_t* x_map_ptr = x_map + h_column;
+ const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
+ half result = acc;
+
+ for (int i = 0; i < count; i++)
+ {
+ uint32_t v_read = *v_ptr; v_ptr += v_.width;
+
+ half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
+ half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
+ half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
+ half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
+ half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
+ half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
+ half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
+ half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
+
+ half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
+ tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
+ result = __hfma(v_scale, tmp, result);
+ }
+
+ return result;
+}
+
+#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu
new file mode 100644
index 0000000000000000000000000000000000000000..f47daeb0e8771a08a5dae1886fa571c9063fbb30
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu
@@ -0,0 +1,260 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#include "q4_matmul.cuh"
+#include "column_remap.cuh"
+#include "util.cuh"
+#include "matrix.cuh"
+#include "cu_compat.cuh"
+#include "cuda_buffers.cuh"
+#if defined(USE_ROCM)
+#include "hip_compat.cuh"
+#endif
+
+const int THREADS_X = 32; // Block size and thread count along columns in w and out
+const int THREADS_Y = 1; // Block size and thread count along rows in x and out
+
+typedef void (*fp_q4_matmul_kernel)
+(
+ const half*,
+ const uint32_t*,
+ half*,
+ const half*,
+ const uint32_t*,
+ const int,
+ const int,
+ const int,
+ const int,
+ const int,
+ const uint32_t*,
+ bool
+);
+
+template<bool use_half2, bool use_groupsize, bool use_x_map>
+__global__ void q4_matmul_kernel
+(
+ const half* __restrict__ x,
+ const uint32_t* __restrict__ w,
+ half* __restrict__ out,
+ const half* __restrict__ w_scales,
+ const uint32_t* __restrict__ w_zeros,
+ const int height,
+ const int dim,
+ const int width,
+ const int groupsize,
+ const int block_size_z,
+ const uint32_t* __restrict__ x_map,
+ bool no_zero
+)
+{
+ // Start of block
+
+ int x_column = block_size_z * blockIdx.z;
+ int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
+
+ int w_column = THREADS_X * blockIdx.x + threadIdx.x;
+ int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
+
+ int iterations = (x_column_end - x_column) / 8;
+
+ // Views
+
+ MatrixView_half x_(x, height, dim);
+ MatrixView_half w_scales_(w_scales, dim / groupsize, width);
+ MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);
+ MatrixView_q4_column w_(w, dim, width);
+ MatrixView_half_rw out_(out, height, width);
+
+ // Zero output
+
+ if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
+ {
+ *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
+ __syncthreads();
+ }
+
+ // Loop over part of x row (and w column)
+
+ half2 acc = {};
+ half acc_h = {};
+
+ if constexpr (use_groupsize)
+ {
+ // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this
+ // could be slightly faster
+
+ for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
+ {
+ if constexpr (use_half2)
+ {
+ half2 w_scale = w_scales_.item_half2half2(group, w_column);
+ uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+
+ if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
+ else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
+ }
+ else
+ {
+ half w_scale = w_scales_.item(group, w_column);
+ uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+
+ if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
+ else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
+ }
+ }
+ }
+ else
+ {
+ // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
+
+ for (int k = x_column; k < x_column + iterations * 8; k += 8)
+ {
+ if constexpr (use_half2)
+ {
+ int group = k / groupsize;
+ half2 w_scale = w_scales_.item_half2half2(group, w_column);
+ uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+
+ if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
+ else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
+ }
+ else
+ {
+ int group = k / groupsize;
+ half w_scale = w_scales_.item(group, w_column);
+ uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+
+ if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
+ else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
+ }
+ }
+ }
+
+ // Add to block result
+
+ if constexpr (use_half2)
+ {
+ half result = __hadd(__low2half(acc), __high2half(acc));
+ atomicAdd(out_.item_ptr(x_row, w_column), result);
+ }
+ else
+ {
+ atomicAdd(out_.item_ptr(x_row, w_column), acc_h);
+ }
+}
+
+fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map)
+{
+ // <bool use_half2, bool use_groupsize, bool use_x_map>
+ if (tuningParams->matmul_no_half2) {
+ if (block_size_z % groupsize == 0) {
+ if (x_map) return q4_matmul_kernel<false, true,  true >;
+ else return q4_matmul_kernel<false, true,  false>;
+ } else {
+ if (x_map) return q4_matmul_kernel<false, false, true >;
+ else return q4_matmul_kernel<false, false, false>;
+ }
+ } else {
+ if (block_size_z % groupsize == 0)
+ {
+ if (x_map) return q4_matmul_kernel<true,  true,  true >;
+ else return q4_matmul_kernel<true,  true,  false>;
+ } else {
+ if (x_map) return q4_matmul_kernel<true,  false, true >;
+ else return q4_matmul_kernel<true,  false, false>;
+ }
+ }
+};
+
+// Compute y = x @ w
+
+void q4_matmul_cuda
+(
+ ExLlamaTuning* tuningParams,
+ const half* x,
+ const int x_height,
+ const Q4Matrix* w,
+ half* out,
+ bool no_zero,
+ cudaStream_t alt_stream
+)
+{
+ int height = x_height;
+ int dim = w->height;
+ int width = w->width;
+
+ cudaSetDevice(w->device);
+
+ uint32_t* x_map = w->cuda_x_map;
+ const half* x_mapped = x;
+ if (x_map && !tuningParams->matmul_fused_remap && !alt_stream)
+ {
+ CudaBuffers* buffers = get_buffers(w->device);
+ column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
+ x_mapped = buffers->temp_state;
+ x_map = NULL;
+ }
+
+ int block_size_z;
+ if (w->width == 4096) block_size_z = 384; // 7B
+ else if (w->width == 11008) block_size_z = 256;
+ else if (w->width == 5120) block_size_z = 384; // 13B
+ else if (w->width == 13824) block_size_z = 256;
+ else if (w->width == 6656) block_size_z = 256; // 33B
+ else if (w->width == 17920) block_size_z = 128;
+ else block_size_z = 256;
+
+ //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half));
+
+ dim3 threads(THREADS_X, THREADS_Y, 1);
+
+ dim3 blocks
+ (
+ (width + threads.x - 1) / threads.x,
+ (height + threads.y - 1) / threads.y,
+ (dim + block_size_z - 1) / block_size_z
+ );
+
+ fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
+
+ kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
+}
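+
+// Example (illustrative): for width = 4096, dim = 4096 and one input row,
+// block_size_z = 384 yields a grid of (4096/32, 1, ceil(4096/384)) =
+// (128, 1, 11) blocks of 32x1 threads, each adding its partial dot product
+// into out with atomicAdd.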
+
+void q4_matmul_recons_cuda
+(
+ ExLlamaTuning* tuningParams,
+ const half* x,
+ const int x_height,
+ Q4Matrix* w,
+ half* out,
+ const cublasHandle_t handle,
+ bool no_zero
+)
+{
+ int height = x_height;
+ int dim = w->height;
+ int width = w->width;
+
+ cudaSetDevice(w->device);
+ CudaBuffers* buffers = get_buffers(w->device);
+
+ const half* x_mapped = x;
+ if (w->cuda_x_map)
+ {
+ TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small");
+ column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
+ x_mapped = buffers->temp_state;
+ }
+
+ w->reconstruct(buffers->temp_dq);
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
+ const float alpha = 1.0f;
+ const float beta = no_zero ? 1.0f : 0.0f;
+ cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
+ x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
+#else
+ const half alpha = __float2half(1.0f);
+ const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
+ cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
+#endif
+}
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..09f3e1a633628fe1584c65b1db61ab1d3570f3da
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh
@@ -0,0 +1,43 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _q4_matmul_cuh
+#define _q4_matmul_cuh
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+#include <cstdio>
+#include <ATen/cuda/CUDAContext.h>
+
+#include "q4_matrix.cuh"
+#include "tuning.h"
+
+// Workaround for hipify_python using rocblas instead of hipblas.
+#if defined(USE_ROCM)
+#include <hipblas/hipblas.h>
+#define rocblas_handle hipblasHandle_t
+#endif
+
+void q4_matmul_cuda
+(
+ ExLlamaTuning* tuningParams,
+ const half* x,
+ const int x_height,
+ const Q4Matrix* w,
+ half* out,
+ bool no_zero = false,
+ cudaStream_t alt_stream = NULL
+);
+
+void q4_matmul_recons_cuda
+(
+ ExLlamaTuning* tuningParams,
+ const half* x,
+ const int x_height,
+ Q4Matrix* w,
+ half* out,
+ const cublasHandle_t handle,
+ bool no_zero = false
+);
+
+#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9c61143f565e0b4b75fc68ebbfec93a692389931
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu
@@ -0,0 +1,225 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#include "q4_matrix.cuh"
+#include <vector>
+#include "util.cuh"
+#include "matrix.cuh"
+
+using namespace std;
+
+const int UNSHUF_BLOCKSIZE_X = 64;
+
+const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column
+const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows
+
+vector<Q4Matrix*> g_q4_matrices;
+
+void g_q4_keep_matrix(Q4Matrix* m)
+{
+ g_q4_matrices.push_back(m);
+}
+
+void g_q4_free_matrices()
+{
+ for (const auto& m : g_q4_matrices) delete m;
+ g_q4_matrices.clear();
+}
+
+Q4Matrix::Q4Matrix
+(
+ const int _height,
+ const int _width,
+ const int _groups,
+
+ uint32_t* _qweight,
+ uint32_t* _qzeros,
+ half* _scales,
+ uint32_t* _g_idx,
+
+ const int _device
+) :
+ height(_height),
+ width(_width),
+ groups(_groups),
+ device(_device)
+{
+ cudaSetDevice(device);
+
+ cuda_qweight = _qweight;
+ cuda_qzeros = _qzeros;
+ cuda_scales = _scales;
+
+ groupsize = height / groups;
+
+ if (_g_idx) make_sequential(_g_idx);
+}
+
+Q4Matrix::~Q4Matrix()
+{
+}
+
+// Make sequential
+
+__global__ void make_sequential_kernel
+(
+ const uint32_t* __restrict__ w,
+ uint32_t* __restrict__ w_new,
+ const uint32_t* __restrict__ x_map,
+ const int w_height,
+ const int w_width
+)
+{
+ const uint64_t* w2 = (uint64_t*) w;
+ uint64_t* w_new2 = (uint64_t*) w_new;
+ int w2_stride = w_width >> 1;
+
+ int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
+ if (w2_column >= w2_stride) return;
+
+ int w_new2_row = blockIdx.y;
+
+ int x_map_idx = w_new2_row << 3;
+
+ uint64_t dst = 0;
+
+ #pragma unroll
+ for (int i = 0; i < 8; i++)
+ {
+ int source_row = x_map[x_map_idx++];
+
+ int w2_row = source_row >> 3;
+ int w2_subrow = source_row & 0x07;
+ int w2_row_shift = w2_subrow << 2;
+ int wnew2_row_shift = i << 2;
+
+ uint64_t src = w2[w2_row * w2_stride + w2_column];
+ src >>= w2_row_shift;
+ src &= 0x0000000f0000000f;
+ src <<= wnew2_row_shift;
+ dst |= src;
+ }
+
+ w_new2[w_new2_row * w2_stride + w2_column] = dst;
+}
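+
+// Example (illustrative): each output word packs eight source rows; for
+// destination row i the kernel gathers rows x_map[8*i .. 8*i+7], extracts
+// one 4-bit value from each and ORs it into nibble slot 0..7 of dst,
+// handling two packed columns at a time as a uint64_t.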
+
+void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
+{
+ uint32_t* cuda_new_qweight = NULL;
+ cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
+ cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch
+
+ uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
+ uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
+ uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
+
+ // Group histogram
+
+ for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
+
+ // Group map
+
+ for (int i = 0, acc = 0; i < groups; i++)
+ {
+ short tmp = cpu_g_idx_map[i];
+ cpu_g_idx_map[i] = acc;
+ acc += tmp;
+ }
+
+ // X map (inverse)
+
+ for (int row = 0; row < height; row++)
+ {
+ uint32_t target_group = cpu_g_idx[row];
+ uint32_t target_row = cpu_g_idx_map[target_group];
+ cpu_g_idx_map[target_group]++;
+ cpu_x_map_inv[row] = target_row;
+ }
+
+ // X map
+
+ for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
+
+ // Move to CUDA
+
+ cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice);
+
+ // Rearrange rows in w
+
+ dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
+ dim3 blocks
+ (
+ (width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2),
+ height / 8,
+ 1
+ );
+
+ make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
+
+ // Replace qweights
+
+ cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
+
+ // Cleanup
+
+ cudaDeviceSynchronize();
+ cudaFree(cuda_new_qweight);
+ free(cpu_g_idx_map);
+ free(cpu_x_map);
+ free(cpu_x_map_inv);
+}
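+
+// Example (illustrative): with groups = 2 and cpu_g_idx = {1, 0, 1, 0},
+// the histogram is {2, 2} and the exclusive prefix sum {0, 2}; the
+// resulting cpu_x_map = {1, 3, 0, 2} places both group-0 rows ahead of
+// both group-1 rows.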
+
+__global__ void reconstruct_kernel
+(
+ const uint32_t* __restrict__ w,
+ half* __restrict__ out, // (y)
+ const half* __restrict__ w_scales,
+ const uint32_t* __restrict__ w_zeros,
+ const int height,
+ const int width,
+ const int groupsize
+)
+{
+ // Start of block
+
+ int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
+ int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;
+ if (column >= width) return;
+
+ // Views
+
+ MatrixView_q4_column w_(w, height, width);
+ MatrixView_half_rw out_(out, height, width);
+ MatrixView_half w_scales_(w_scales, height / groupsize, width);
+ MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width);
+
+ // Groupsize version
+
+ int group = row / groupsize;
+
+ half w_scale = w_scales_.item(group, column);
+ uint32_t w_zero = w_zeros_.item(group, column) + 1;
+
+ uint32_t w_read = w_.item_uint32_t(row, column);
+ half* out_ptr = out_.item_ptr(row, column);
+
+ #pragma unroll
+ for (int s = 0; s < 32; s += 4)
+ {
+ half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
+ *out_ptr = w_item; out_ptr += out_.width;
+ }
+}
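+
+// Example (illustrative): dequantization computes (q - w_zero) * scale with
+// w_zero = stored_zero + 1 (the "+ 1 (!!)" convention noted in matrix.cuh);
+// a stored zero of 7 with scale 0.5 maps q = 12 to (12 - 8) * 0.5 = 2.0.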
+
+void Q4Matrix::reconstruct(half* out)
+{
+ dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1);
+
+ dim3 blocks
+ (
+ (width + threads.x - 1) / threads.x,
+ (height / 8 + threads.y - 1) / threads.y,
+ 1
+ );
+
+ reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
+}
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..50cb72a41518593f6be4dded4f03c61772ef2ef9
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh
@@ -0,0 +1,53 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _q4_matrix_cuh
+#define _q4_matrix_cuh
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+
+class Q4Matrix
+{
+public:
+
+ int device;
+
+ int height;
+ int width;
+ int groups;
+ int groupsize;
+
+ uint32_t* cuda_qweight = NULL;
+ uint32_t* cuda_qzeros = NULL;
+ half* cuda_scales = NULL;
+ uint32_t* cuda_x_map = NULL;
+
+ Q4Matrix
+ (
+ const int _height,
+ const int _width,
+ const int _groups,
+
+ uint32_t* _qweight,
+ uint32_t* _qzeros,
+ half* _scales,
+ uint32_t* _g_idx,
+
+ const int _device
+ );
+
+ ~Q4Matrix();
+
+ void reconstruct(half* out);
+
+private:
+
+ void make_sequential(const uint32_t* cpu_g_idx);
+
+};
+
+void g_q4_keep_matrix(Q4Matrix* m);
+void g_q4_free_matrices();
+
+#endif
\ No newline at end of file
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/tuning.h b/colossalai/kernel/cuda_native/csrc/gptq/tuning.h
new file mode 100644
index 0000000000000000000000000000000000000000..e413b8a96c11afd7793128d27145325d26c63715
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/tuning.h
@@ -0,0 +1,12 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _tuning_h
+#define _tuning_h
+
+struct ExLlamaTuning {
+ int matmul_recons_thd;
+ bool matmul_fused_remap;
+ bool matmul_no_half2;
+};
+
+#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/util.cuh b/colossalai/kernel/cuda_native/csrc/gptq/util.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..7b397573214b2b1f1af50a82320f61aabee5c8f1
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/gptq/util.cuh
@@ -0,0 +1,33 @@
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+
+#ifndef _util_cuh
+#define _util_cuh
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+#include <cstdio>
+
+#if defined(USE_ROCM)
+#define cudaUnspecified hipErrorUnknown
+#else
+#define cudaUnspecified cudaErrorApiFailureBase
+#endif
+
+// React to failure on return code != cudaSuccess
+
+#define _cuda_check(fn) \
+do { \
+ {_cuda_err = fn;} \
+ if (_cuda_err != cudaSuccess) goto _cuda_fail; \
+} while(false)
+
+// React to failure on return code == 0
+
+#define _alloc_check(fn) \
+do { \
+ if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
+ else _cuda_err = cudaSuccess; \
+} while(false)
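+
+// Usage sketch (illustrative): both macros assume the caller declares a local
+// cudaError_t _cuda_err and provides a _cuda_fail label, e.g.
+//
+//   cudaError_t _cuda_err = cudaSuccess;
+//   _cuda_check(cudaMalloc(&buf, size));
+//   _alloc_check(host_ptr = malloc(size));
+//   return true;
+// _cuda_fail:
+//   printf(" **** CUDA error: %s\n", cudaGetErrorString(_cuda_err));
+//   return false;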
+
+#endif
diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu
index 26efa2ad6f31632a4e7ceddd06745b067759bb43..9a6a8ebc39839228630cb12967a7b2a8048b2b6b 100644
--- a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu
+++ b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu
@@ -1,7 +1,6 @@
#include
#include
-
#include "cuda_util.h"
/* GPU function guard */
diff --git a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
index a39a6dae0f7fb6968e6ee65fde8db4bbc5d61ab0..ce0b017f12e1e25d262cf377320966ca5044eb04 100644
--- a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
+++ b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
@@ -1,1002 +1,1002 @@
-#include <chrono>
-#include <ctime>
-
-#include "kernels.h"
-
-#include <cooperative_groups.h>
-
-
-namespace cg = cooperative_groups;
-
-curandStatePhilox4_32_10_t *curandstate;
-
-/**
- * @brief element-wise activation function on device, like Relu, Gelu
- *
- * @tparam enum class ActivationType, kRelu, kGelu
- * @tparam input type
- * @param any shape of float and __half2
- * @return same shape and type with input
- */
-template <ActivationType act_type, typename T>
-__forceinline__ __device__ T activation_kernel(T x);
-
-template <>
-__device__ float activation_kernel<ActivationType::kGelu, float>(float x) {
- float cdf =
- 0.5f *
- (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
- return x * cdf;
-}
-
-template <>
-__device__ __half2
-activation_kernel<ActivationType::kGelu, __half2>(__half2 val) {
- __half2 val_pow3 = __hmul2(val, __hmul2(val, val));
- float2 tmp_pow = __half22float2(val_pow3);
- float2 tmp = __half22float2(val);
-
- tmp.x =
- 0.5f *
- (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x))));
- tmp.y =
- 0.5f *
- (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y))));
- return __hmul2(val, __float22half2_rn(tmp));
-}
-
-template <>
-__device__ float activation_kernel<ActivationType::kRelu, float>(float x) {
- return fmaxf(x, 0);
-}
-
-template <>
-__device__ __half2
-activation_kernel<ActivationType::kRelu, __half2>(__half2 x) {
- return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)),
- fmaxf(0.f, __half2float(x.y)));
-}
-
-/**
- * @brief element-wise activation backward function on device
- *
- * @tparam enum class ActivationType
- * @tparam input type
- * @param any shape of float and __half2
- * @return same shape of input
- */
-template <ActivationType act_type, typename T>
-__forceinline__ __device__ T activation_bwd_kernel(T grad, T x);
-
-template <>
-__device__ float activation_bwd_kernel<ActivationType::kGelu, float>(float grad,
- float x) {
- const float sqrt_param = 0.79788456080286535587989211986876f;
- const float mul_param = 0.044715;
-
- float x2mul = x * x * mul_param;
- float tan_h = tanhf(sqrt_param * (x + x * x2mul));
- float dg1 = 0.5f * (1.0f + tan_h);
- float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
- float dg3 = dg2 * 3 * x2mul;
- return grad * (dg1 + dg2 + dg3);
-}
-
-template <>
-__device__ __half activation_bwd_kernel<ActivationType::kGelu, __half>(
- __half grad, __half x_half) {
- float x = __half2float(x_half);
- const float sqrt_param = 0.79788456080286535587989211986876f;
- const float mul_param = 0.044715;
-
- float x2mul = x * x * mul_param;
- float tan_h = tanhf(sqrt_param * (x + x * x2mul));
- float dg1 = 0.5f * (1.0f + tan_h);
- float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
- float dg3 = dg2 * 3 * x2mul;
- return grad * __float2half(dg1 + dg2 + dg3);
-}
-
-template <>
-__device__ float activation_bwd_kernel<ActivationType::kRelu, float>(float grad,
- float x) {
- return x > 0.f ? grad : 0.f;
-}
-
-template <>
-__device__ __half
-activation_bwd_kernel<ActivationType::kRelu, __half>(__half grad, __half x) {
- const __half half_zero = __float2half(0.f);
- return x > half_zero ? grad : half_zero;
-}
-
-template <>
-__device__ __half2 activation_bwd_kernel<ActivationType::kRelu, __half2>(
- __half2 grad2, __half2 x_half2) {
- const __half half_zero = __float2half(0.f);
- return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero,
- x_half2.y > half_zero ? grad2.y : half_zero);
-}
-
-/**
- * @brief init curand states in global memory
- *
- * @thread grid_dim * block_dim to support any size of states
- * @param state persistant curand states
- * @param seed seed to init states
- * @return void
- */
-__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state,
- int seed) {
- /* Each thread gets same seed, a different sequence
- number, no offset */
- int id = threadIdx.x + blockIdx.x * blockDim.x;
- curand_init(seed, id, 0, &state[id]);
-}
-
-void launch_curand_init(int total_count, int dim, cudaStream_t stream) {
- cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t));
- int grid_dim = total_count >> 9;
- curand_init_kernel<<<grid_dim + 1, 512, 0, stream>>>(
- curandstate, std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count());
-}
-
-/**
- * @brief element-wise dropout, store dropped position in mask, it's not
- * in-place
- *
- * @thread
- * gridDim.x = total_count / 1024
- * blockDim.x = 1024
- *
- * @param total_count total elements
- * @param ratio drop ratio
- * @param out any size of float and __half
- * @param in same with out
- * @param mask uint8 type, same size with out
- * @param seed seed to curand
- * @return void
- */
-__global__ void ls_dropout_kernel(const int total_count, const float ratio,
- float *__restrict__ out,
- const float *__restrict__ in,
- uint8_t *__restrict__ mask, const int seed) {
- const float scale = 1.f / (1.f - ratio);
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 4 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
- uint8_t m[4];
-
- float4 *out4 = reinterpret_cast<float4 *>(out);
- const float4 *data4 = reinterpret_cast<const float4 *>(in);
- uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
- float4 rand = curand_uniform4(&state);
-
- m[0] = (uint8_t)(rand.x > ratio);
- m[1] = (uint8_t)(rand.y > ratio);
- m[2] = (uint8_t)(rand.z > ratio);
- m[3] = (uint8_t)(rand.w > ratio);
-
- uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
- mask4[i] = m4[0];
-
- float4 input4 = data4[i];
- float4 res4;
- res4.x = input4.x * scale * m[0];
- res4.y = input4.y * scale * m[1];
- res4.z = input4.z * scale * m[2];
- res4.w = input4.w * scale * m[3];
- out4[i] = res4;
-}
-
-__global__ void ls_dropout_kernel(const int total_count, const float ratio,
- __half *__restrict__ out,
- const __half *__restrict__ in,
- uint8_t *__restrict__ mask, const int seed) {
- const float scale = 1.f / (1.f - ratio);
-
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 8 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
-
- const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
- float4 *outs_float4 = reinterpret_cast<float4 *>(out);
- uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);
-
- uint8_t m[8];
- float4 rand = curand_uniform4(&state);
- m[0] = (uint8_t)(rand.x > ratio);
- m[1] = (uint8_t)(rand.y > ratio);
- m[2] = (uint8_t)(rand.z > ratio);
- m[3] = (uint8_t)(rand.w > ratio);
- rand = curand_uniform4(&state);
- m[4] = (uint8_t)(rand.x > ratio);
- m[5] = (uint8_t)(rand.y > ratio);
- m[6] = (uint8_t)(rand.z > ratio);
- m[7] = (uint8_t)(rand.w > ratio);
- uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
- mask8[i] = *m8;
-
- float4 val_float4 = vals_float4[i];
- float4 out_float4;
- __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
- __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
- __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]);
- __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]);
- __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]);
- __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]);
- out_half2[0] = __hmul2(val_half2[0], scale_mask_1);
- out_half2[1] = __hmul2(val_half2[1], scale_mask_2);
- out_half2[2] = __hmul2(val_half2[2], scale_mask_3);
- out_half2[3] = __hmul2(val_half2[3], scale_mask_4);
- outs_float4[i] = out_float4;
-}
-
-/**
- * @brief element-wise dropout backward with dropout mask, it's
- * not in-place
- *
- * @thread
- * gridDim.x = total_count / 1024
- * blockDim.x = 1024
- *
- * @param total_count total elements
- * @param ratio drop ratio
- * @param in any size of float and __half
- * @param mask uint8 type, same size with in
- * @return void
- */
-__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
- float *out, const float *in,
- const uint8_t *__restrict__ mask) {
- const float scale = 1.f / (1.f - ratio);
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 4 >= total_count) return;
-
- uint8_t m[4];
-
- float4 *out4 = reinterpret_cast<float4 *>(out);
- const float4 *in4 = reinterpret_cast<const float4 *>(in);
- const uint32_t *mask4 = reinterpret_cast<const uint32_t *>(mask);
-
- uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
- m4[0] = mask4[i];
-
- float4 input4 = in4[i];
- float4 res4;
- res4.x = input4.x * scale * static_cast<float>(m[0]);
- res4.y = input4.y * scale * static_cast<float>(m[1]);
- res4.z = input4.z * scale * static_cast<float>(m[2]);
- res4.w = input4.w * scale * static_cast<float>(m[3]);
- out4[i] = res4;
-}
-
-__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
- __half *out, const __half *in,
- const uint8_t *__restrict__ mask) {
- const __half scale = 1.f / (1.f - ratio);
-
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 8 >= total_count) return;
-
- float4 *out4 = reinterpret_cast<float4 *>(out);
- const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
- const uint64_t *mask8 = reinterpret_cast<const uint64_t *>(mask);
-
- uint8_t m[8];
- uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
- m8[0] = mask8[i];
-
- float4 val_float4 = vals_float4[i];
- float4 out_float4;
- __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
- __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
- __half2 scale_mask_1 =
- __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1]));
- __half2 scale_mask_2 =
- __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3]));
- __half2 scale_mask_3 =
- __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5]));
- __half2 scale_mask_4 =
- __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7]));
- out_half2[0] = __hmul2(val_half2[0], scale_mask_1);
- out_half2[1] = __hmul2(val_half2[1], scale_mask_2);
- out_half2[2] = __hmul2(val_half2[2], scale_mask_3);
- out_half2[3] = __hmul2(val_half2[3], scale_mask_4);
- out4[i] = out_float4;
-}
-
-template <>
-void launch_ls_dropout<float>(float *out, const float *vals, uint8_t *mask,
- int total_count, float ratio, cudaStream_t stream,
- bool backward) {
- int grid_dim = total_count >> 12;
- if (!backward) {
- ls_dropout_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
- total_count, ratio, out, vals, mask,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count());
- } else {
- ls_dropout_bwd_kernel<<<grid_dim + 1, 1024, 0, stream>>>(total_count, ratio,
- out, vals, mask);
- }
-}
-
-template <>
-void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask,
- int total_count, float ratio,
- cudaStream_t stream, bool backward) {
- int grid_dim = total_count >> 13;
- if (!backward) {
- ls_dropout_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
- total_count, ratio, out, vals, mask,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count());
- } else {
- ls_dropout_bwd_kernel<<<grid_dim + 1, 1024, 0, stream>>>(total_count, ratio,
- out, vals, mask);
- }
-}
-
-/**
- * @brief fused bias, dropout, and residual at the end of Attention and FFN,
- * store dropped position in mask, it's not in-place
- *
- * @thread
- * gridDim.x = total_count / 1024
- * blockDim.x = 1024
- *
- * @param total_count total elements
- * @param ratio drop ratio
- * @param out [batch_size, seq_len, hidden_size], float and __half
- * @param in [batch_size, seq_len, hidden_size], float and __half
- * @param mask [batch_size, seq_len, hidden_size], uint8 type
- * @param bias [hidden_size], ffn bias
- * @param residual [batch_size, seq_len, hidden_size], float and __half
- * @param seed seed to curand
- * @param hidden_size hidden size
- * @return void
- */
-__global__ void ls_dropout_res_bias_kernel(
- const int total_count, const float ratio, float *__restrict__ out,
- const float *__restrict__ in, uint8_t *__restrict__ mask,
- const float *__restrict__ bias, const float *__restrict__ residual,
- const int seed, const int hidden_size) {
- const float scale = 1.f / (1.f - ratio);
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 4 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
- uint8_t m[4];
-
- float4 *out4 = reinterpret_cast<float4 *>(out);
- const float4 *data4 = reinterpret_cast<const float4 *>(in);
- const float4 *residual4 = reinterpret_cast<const float4 *>(residual);
- const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
- uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
- float4 rand = curand_uniform4(&state);
-
- m[0] = static_cast<uint8_t>(rand.x > ratio);
- m[1] = static_cast<uint8_t>(rand.y > ratio);
- m[2] = static_cast<uint8_t>(rand.z > ratio);
- m[3] = static_cast<uint8_t>(rand.w > ratio);
-
- int bias_i = i % (hidden_size >> 2);
- uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
- mask4[i] = m4[0];
- const float4 input4 = data4[i];
- const float4 b4 = __ldg(&bias4[bias_i]);
- const float4 res4 = residual4[i];
- float4 output4;
-
- output4.x = (input4.x + b4.x) * scale * m[0] + res4.x;
- output4.y = (input4.y + b4.y) * scale * m[1] + res4.y;
- output4.z = (input4.z + b4.z) * scale * m[2] + res4.z;
- output4.w = (input4.w + b4.w) * scale * m[3] + res4.w;
-
- out4[i] = output4;
-}
-
-__global__ void ls_dropout_res_bias_kernel(
- const int total_count, const float ratio, __half *__restrict__ out,
- const __half *__restrict__ in, uint8_t *__restrict__ mask,
- const __half *__restrict__ bias, const __half *__restrict__ residual,
- const int seed, const int hidden_size) {
- const __half scale = 1. / (1. - ratio);
-
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 8 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
-
- const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
- float4 *outs_float4 = reinterpret_cast<float4 *>(out);
- const float4 *residual4 = reinterpret_cast<const float4 *>(residual);
- const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
- uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);
-
- uint8_t m[8];
- float4 rand = curand_uniform4(&state);
- m[0] = static_cast<uint8_t>(rand.x > ratio);
- m[1] = static_cast<uint8_t>(rand.y > ratio);
- m[2] = static_cast<uint8_t>(rand.z > ratio);
- m[3] = static_cast<uint8_t>(rand.w > ratio);
- rand = curand_uniform4(&state);
- m[4] = static_cast<uint8_t>(rand.x > ratio);
- m[5] = static_cast<uint8_t>(rand.y > ratio);
- m[6] = static_cast<uint8_t>(rand.z > ratio);
- m[7] = static_cast<uint8_t>(rand.w > ratio);
- uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
- mask8[i] = m8[0];
-
- int bias_i = i % (hidden_size >> 3);
- float4 val_float4 = vals_float4[i];
- const float4 b4 = __ldg(&bias4[bias_i]);
- const float4 res4 = residual4[i];
- float4 out_float4;
-
- __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
- __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
- const __half2 *b_half2 = reinterpret_cast<const __half2 *>(&b4);
- const __half2 *res_half2 = reinterpret_cast<const __half2 *>(&res4);
- __half2 scale_mask_1 =
- __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1]));
- __half2 scale_mask_2 =
- __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3]));
- __half2 scale_mask_3 =
- __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5]));
- __half2 scale_mask_4 =
- __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7]));
- out_half2[0] =
- __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]);
- out_half2[1] =
- __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]);
- out_half2[2] =
- __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]);
- out_half2[3] =
- __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]);
- outs_float4[i] = out_float4;
-}
-
-template <>
-void launch_ls_dropout_res_bias<float>(float *out, const float *vals,
- uint8_t *mask, const float *bias,
- const float *residual, int total_count,
- int dim, float ratio,
- cudaStream_t stream) {
- int grid_dim = total_count >> 12;
- ls_dropout_res_bias_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
- total_count, ratio, out, vals, mask, bias, residual,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-template <>
-void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals,
- uint8_t *mask, const __half *bias,
- const __half *residual, int total_count,
- int dim, float ratio,
- cudaStream_t stream) {
- int grid_dim = total_count >> 13;
- ls_dropout_res_bias_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
- total_count, ratio, out, vals, mask, bias, residual,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-/**
- * @brief fused bias and dropout backward at the end of Attention and FFN
- *
- * @thread
- * gridDim.x = hidden_size / 8
- * blockDim.x = 8
- * blockDim.y = 1024 / 8 = 128
- *
- * @param row_size batch_size * seq_len
- * @param ratio dropout ratio
- * @param in_grad [batch_size, seq_len, hidden_size], input grad
- * @param bias_grad [hidden_size], bias grad
- * @param out_grad [batch_size, seq_len, hidden_size], output grad
- * @param mask [batch_size, seq_len, hidden_size], dropout mask
- * @param hidden_size
- * @return void
- */
-__global__ void ls_dropout_bias_bwd_kernel(
- const int row_size, const float ratio, float *__restrict__ in_grad,
- float *__restrict__ bias_grad, const float *__restrict__ out_grad,
- const uint8_t *__restrict__ mask, const int hidden_size) {
- const float scale = 1.f / (1.f - ratio);
- // every block generate 8 bias result
- __shared__ float tile[8][129];
-
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
-
- int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8);
- int stride = hidden_size * 128;
- float local_sum = 0;
-
- int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
- for (int r = threadIdx.y; r < row_size; r += 128) {
- float val = out_grad[idx];
- val *= scale * static_cast<float>(mask[idx]);
- local_sum += val;
- in_grad[idx] = val;
- idx += stride;
- }
-
- tile[threadIdx.x][threadIdx.y] = local_sum;
- __syncthreads();
-
- float sum = 0;
- int tid = threadIdx.y * blockDim.x + threadIdx.x;
- int x = tid >> 7;
- int y = tid & (127);
- if (y < 32) {
-#pragma unroll
- for (int i = 0; i < 4; i++) {
- sum += tile[x][y + i * 32];
- }
- }
- __syncthreads();
-
- for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i);
-
- if (y == 0) tile[0][x] = sum;
- __syncthreads();
-
- if (threadIdx.x < 8) {
- int pos = flat_2dim(blockIdx.x, threadIdx.x, 8);
- bias_grad[pos] = tile[0][threadIdx.x];
- }
-}
-
-__global__ void ls_dropout_bias_bwd_kernel(
- const int row_size, const float ratio, __half *__restrict__ in_grad,
- __half *__restrict__ bias_grad, const __half *__restrict__ out_grad,
- const uint8_t *__restrict__ mask, const int hidden_size) {
- const __half2 scale = __float2half2_rn(1.f / (1.f - ratio));
- __shared__ __half2 tile[8][129];
-
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
-
- __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad);
- const __half2 *out_grad2 = reinterpret_cast<const __half2 *>(out_grad);
- __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad);
-
- int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8);
- int stride = hidden_size * 128;
- __half2 local_sum = __float2half2_rn(0.f);
-
- int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
- for (int r = threadIdx.y; r < row_size; r += 128) {
- __half2 val = out_grad2[idx];
- __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]);
- val *= scale * m2;
- local_sum += val;
- in_grad2[idx] = val;
- idx += stride;
- }
-
- tile[threadIdx.x][threadIdx.y] = local_sum;
- __syncthreads();
-
- __half2 sum = __float2half2_rn(0.f);
- int tid = threadIdx.y * blockDim.x + threadIdx.x;
- int x = tid >> 7;
- int y = tid & (127);
- if (y < 32) {
-#pragma unroll
- for (int i = 0; i < 4; i++) {
- sum += tile[x][y + i * 32];
- }
- }
- __syncthreads();
-
- for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
-
- if (y == 0) tile[0][x] = sum;
- __syncthreads();
-
- if (threadIdx.x < 8) {
- int pos = flat_2dim(blockIdx.x, threadIdx.x, 8);
- bias_grad2[pos] = tile[0][threadIdx.x];
- }
-}
-
-template <typename T>
-void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad,
- const uint8_t *mask, int row_size, int dim,
- float ratio, cudaStream_t stream) {
- dim3 grid_dim((dim - 1) / 8 + 1);
- dim3 block_dim(8, 128);
- ls_dropout_bias_bwd_kernel<<<grid_dim, block_dim, 0, stream>>>(
- row_size, ratio, in_grad, bias_grad, out_grad, mask, dim);
-}
-
-template <>
-void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad,
- const __half *out_grad, const uint8_t *mask,
- int row_size, int dim, float ratio,
- cudaStream_t stream) {
- dim >>= 1;
- dim3 grid_dim((dim - 1) / 8 + 1);
- dim3 block_dim(8, 128);
- ls_dropout_bias_bwd_kernel<<<grid_dim, block_dim, 0, stream>>>(
- row_size, ratio, in_grad, bias_grad, out_grad, mask, dim);
-}
-
-template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad,
- const float *out_grad,
- const uint8_t *mask, int row_size,
- int dim, float ratio,
- cudaStream_t stream);
-
-/**
- * @brief fused bias, activation, and dropout at the end of first ffn
- *
- * @thread
- * gridDim.x = hidden_size / 8
- * blockDim.x = 8
- * blockDim.y = 1024 / 8 = 128
- *
- * @tparam act_type activation function, like kRelu, kGelu
- * @param total_count total elements
- * @param ratio drop ratio
- * @param out [batch_size, seq_len, hidden_size], float and __half
- * @param in [batch_size, seq_len, hidden_size], float and __half
- * @param mask [batch_size, seq_len, hidden_size], uint8 type
- * @param bias [hidden_size], ffn bias
- * @param seed seed to curand
- * @param hidden_size
- * @return void
- */
-template <ActivationType act_type>
-__global__ void ls_dropout_act_bias_kernel(
- const int total_count, const float ratio, float *__restrict__ out,
- const float *__restrict__ in, uint8_t *__restrict__ mask,
- const float *__restrict__ bias, const int seed, const int hidden_size) {
- const float scale = 1.f / (1.f - ratio);
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 4 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
- uint8_t m[4];
-
- float4 *out4 = reinterpret_cast<float4 *>(out);
- const float4 *data4 = reinterpret_cast<const float4 *>(in);
- const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
- uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
- float4 rand = curand_uniform4(&state);
-
- m[0] = (uint8_t)(rand.x > ratio);
- m[1] = (uint8_t)(rand.y > ratio);
- m[2] = (uint8_t)(rand.z > ratio);
- m[3] = (uint8_t)(rand.w > ratio);
-
- int bias_i = i % (hidden_size >> 2);
- uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
- mask4[i] = m4[0];
- const float4 input4 = data4[i];
- const float4 b4 = __ldg(&bias4[bias_i]);
- float4 output4;
-
- output4.x =
- activation_kernel<act_type, float>(input4.x + b4.x) * scale * m[0];
- output4.y =
- activation_kernel<act_type, float>(input4.y + b4.y) * scale * m[1];
- output4.z =
- activation_kernel<act_type, float>(input4.z + b4.z) * scale * m[2];
- output4.w =
- activation_kernel<act_type, float>(input4.w + b4.w) * scale * m[3];
-
- out4[i] = output4;
-}
-
-template <ActivationType act_type>
-__global__ void ls_dropout_act_bias_kernel(
- const int total_count, const float ratio, __half *__restrict__ out,
- const __half *__restrict__ in, uint8_t *__restrict__ mask,
- const __half *__restrict__ bias, const int seed, const int hidden_size) {
- const float scale = 1.f / (1.f - ratio);
-
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 8 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
-
- const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
- float4 *outs_float4 = reinterpret_cast<float4 *>(out);
- const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
- uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);
-
- uint8_t m[8];
- float4 rand = curand_uniform4(&state);
- m[0] = (uint8_t)(rand.x > ratio);
- m[1] = (uint8_t)(rand.y > ratio);
- m[2] = (uint8_t)(rand.z > ratio);
- m[3] = (uint8_t)(rand.w > ratio);
- rand = curand_uniform4(&state);
- m[4] = (uint8_t)(rand.x > ratio);
- m[5] = (uint8_t)(rand.y > ratio);
- m[6] = (uint8_t)(rand.z > ratio);
- m[7] = (uint8_t)(rand.w > ratio);
- uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
- mask8[i] = *m8;
-
- int bias_i = i % (hidden_size >> 3);
- float4 val_float4 = vals_float4[i];
- const float4 b4 = __ldg(&bias4[bias_i]);
- float4 out_float4;
-
- __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
- __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
- const __half2 *b_half2 = reinterpret_cast<const __half2 *>(&b4);
-
- __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]);
- __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]);
- __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]);
- __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]);
- out_half2[0] = __hmul2(
- activation_kernel<act_type, __half2>(__hadd2(val_half2[0], b_half2[0])),
- scale_mask_1);
- out_half2[1] = __hmul2(
- activation_kernel<act_type, __half2>(__hadd2(val_half2[1], b_half2[1])),
- scale_mask_2);
- out_half2[2] = __hmul2(
- activation_kernel<act_type, __half2>(__hadd2(val_half2[2], b_half2[2])),
- scale_mask_3);
- out_half2[3] = __hmul2(
- activation_kernel<act_type, __half2>(__hadd2(val_half2[3], b_half2[3])),
- scale_mask_4);
- outs_float4[i] = out_float4;
-}
-
-template <>
-void launch_ls_dropout_act_bias<ActivationType::kGelu, float>(
- float *out, const float *vals, uint8_t *mask, const float *bias,
- int total_count, int dim, float ratio, cudaStream_t stream) {
- int grid_dim = total_count >> 10;
- ls_dropout_act_bias_kernel<ActivationType::kGelu>
- <<<grid_dim + 1, 256, 0, stream>>>(
- total_count, ratio, out, vals, mask, bias,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-template <>
-void launch_ls_dropout_act_bias<ActivationType::kGelu, __half>(
- __half *out, const __half *vals, uint8_t *mask, const __half *bias,
- int total_count, int dim, float ratio, cudaStream_t stream) {
- int grid_dim = total_count >> 11;
- ls_dropout_act_bias_kernel<ActivationType::kGelu>
- <<<grid_dim + 1, 256, 0, stream>>>(
- total_count, ratio, out, vals, mask, bias,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-template <>
-void launch_ls_dropout_act_bias<ActivationType::kRelu, float>(
- float *out, const float *vals, uint8_t *mask, const float *bias,
- int total_count, int dim, float ratio, cudaStream_t stream) {
- int grid_dim = total_count >> 10;
- ls_dropout_act_bias_kernel<ActivationType::kRelu>
- <<<grid_dim + 1, 256, 0, stream>>>(
- total_count, ratio, out, vals, mask, bias,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-template <>
-void launch_ls_dropout_act_bias<ActivationType::kRelu, __half>(
- __half *out, const __half *vals, uint8_t *mask, const __half *bias,
- int total_count, int dim, float ratio, cudaStream_t stream) {
- int grid_dim = total_count >> 11;
- ls_dropout_act_bias_kernel<ActivationType::kRelu>
- <<<grid_dim + 1, 256, 0, stream>>>(
- total_count, ratio, out, vals, mask, bias,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-/**
- * @brief fused bias, activation, and dropout backward
- *
- * @thread
- * gridDim.x = total_count / 1024
- * blockDim.x = 1024
- *
- * @tparam act_type kRelu
- * @param row_size batch_size * seq_len
- * @param ratio dropout ratio
- * @param in_grad [batch_size, seq_len, hidden_size], input grad
- * @param bias_grad [hidden_size], bias grad
- * @param out_grad [batch_size, seq_len, hidden_size], output grad
- * @param mask [batch_size, seq_len, hidden_size], dropout mask
- * @param hidden_size
- * @return void
- */
-template <ActivationType act_type, typename T>
-__global__ void ls_dropout_act_bias_bwd_kernel(
- const int row_size, const float ratio, T *in_grad,
- T *__restrict__ bias_grad, const T *__restrict__ input,
- const T *__restrict__ bias, const T *out_grad,
- const uint8_t *__restrict__ mask, const int hidden_size) {
- const float scale = 1.f / (1.f - ratio);
- __shared__ float tile[WARP_SIZE][WARP_SIZE + 1];
-
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile