diff --git a/.circleci/build_docs/build_docs.sh b/.circleci/build_docs/build_docs.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2864a72974f0af8dfafa8f7790b91848d78367de
--- /dev/null
+++ b/.circleci/build_docs/build_docs.sh
@@ -0,0 +1,13 @@
+#!/usr/bin/env bash
+
+set -ex
+# shellcheck disable=SC1091
+source ./packaging/pkg_helpers.bash
+export NO_CUDA_PACKAGE=1
+setup_env 0.8.0
+setup_wheel_python
+
+pushd docs
+pip install -r requirements.txt
+make html
+popd
diff --git a/.circleci/build_docs/commit_docs.sh b/.circleci/build_docs/commit_docs.sh
new file mode 100644
index 0000000000000000000000000000000000000000..59374dce37ab33aefc5ec5acfe659dc4f89ef551
--- /dev/null
+++ b/.circleci/build_docs/commit_docs.sh
@@ -0,0 +1,35 @@
+#!/usr/bin/env bash
+
+set -ex
+
+
+if [ "$2" == "" ]; then
+  echo call as "$0" "<src>" "<target>"
+  echo where src is the root of the built documentation git checkout and
+  echo branch should be "main" or "1.7" or so
+  exit 1
+fi
+
+src=$1
+target=$2
+
+echo "committing docs from ${src} to ${target}"
+
+pushd "${src}"
+git checkout gh-pages
+mkdir -p ./"${target}"
+rm -rf ./"${target}"/*
+cp -r "${src}/docs/build/html/"* ./"$target"
+if [ "${target}" == "main" ]; then
+  mkdir -p ./_static
+  rm -rf ./_static/*
+  cp -r "${src}/docs/build/html/_static/"* ./_static
+  git add --all ./_static || true
+fi
+git add --all ./"${target}" || true
+git config user.email "soumith+bot@pytorch.org"
+git config user.name "pytorchbot"
+# If there aren't changes, don't make a commit; push is no-op
+git commit -m "auto-generating sphinx docs" || true
+git remote add https https://github.com/pytorch/audio.git
+git push -u https gh-pages
diff --git a/.circleci/build_docs/install_wheels.sh b/.circleci/build_docs/install_wheels.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3e151540926d175a271a6c2faa712583bbabadc9
--- /dev/null
+++ b/.circleci/build_docs/install_wheels.sh
@@ -0,0 +1,12 @@
+#!/usr/bin/env bash
+
+set -ex
+
+# shellcheck disable=SC1091
+source ./packaging/pkg_helpers.bash
+export NO_CUDA_PACKAGE=1
+setup_env 0.8.0
+setup_wheel_python
+setup_pip_pytorch_version
+# pytorch is already installed
+pip install --no-deps ~/workspace/torchaudio*
diff --git a/.circleci/config.yml b/.circleci/config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..55c75d5420e487af4da760925b52dbf2bf7f8370
--- /dev/null
+++ b/.circleci/config.yml
@@ -0,0 +1,3589 @@
+version: 2.1
+
+# How to test the Linux jobs:
+# - Install CircleCI local CLI: https://circleci.com/docs/2.0/local-cli/
+# - circleci config process .circleci/config.yml > gen.yml && circleci local execute -c gen.yml --job binary_linux_wheel_py3.8
+# - Replace binary_linux_wheel_py3.8 with the name of the job you want to test.
+#   Job names are 'name:' key.
+ +executors: + windows-cpu: + machine: + resource_class: windows.xlarge + image: windows-server-2019-vs2019:stable + shell: bash.exe + + windows-gpu: + machine: + resource_class: windows.gpu.nvidia.medium + image: windows-server-2019-nvidia:stable + shell: bash.exe + +commands: + generate_cache_key: + description: "Generates a cache key file that changes daily" + steps: + - run: + name: Generate cache key + command: echo "$(date +"%Y-%m-%d")" > .cachekey + designate_upload_channel: + description: "inserts the correct upload channel into ${BASH_ENV}" + steps: + - run: + name: adding UPLOAD_CHANNEL to BASH_ENV + command: | + # Hardcoded for release branch + echo "export UPLOAD_CHANNEL=test" >> ${BASH_ENV} + install_build_tools_macos: + description: "installs tools required to build torchaudio" + steps: + - run: + name: Install build tools + command: HOMEBREW_NO_AUTO_UPDATE=1 brew install pkg-config wget + # Disable brew auto update which is very slow + load_conda_channel_flags: + description: "Determines whether we need extra conda channels" + steps: + - run: + name: Adding CONDA_CHANNEL_FLAGS to BASH_ENV + command: | + CONDA_CHANNEL_FLAGS="" + # formerly used to add conda-forge flags for Python 3.9, reserving the mechanism for future python upgrades + windows_install_cuda: + description: "Install desired CUDA version on Windows runners" + steps: + - run: + name: Install CUDA + command: | + packaging/windows/internal/cuda_install.bat + +binary_common: &binary_common + parameters: + # Edit these defaults to do a release + build_version: + description: "version number of release binary; by default, build a nightly" + type: string + default: "0.10.0" + pytorch_version: + description: "PyTorch version to build against; by default, use a nightly" + type: string + default: "1.10.0" + # Don't edit these + python_version: + description: "Python version to build against (e.g., 3.8)" + type: string + cuda_version: + description: "CUDA version to build against (e.g., cpu, cu101)" + type: string + default: "cpu" + wheel_docker_image: + description: "Wheel only: what docker image to use" + type: string + default: "pytorch/manylinux-cuda102" + conda_docker_image: + description: "Conda only: what docker image to use" + type: string + default: "pytorch/conda-builder:cuda102" + environment: &environment + PYTHON_VERSION: << parameters.python_version >> + BUILD_VERSION: << parameters.build_version >> + PYTORCH_VERSION: << parameters.pytorch_version >> + CU_VERSION: << parameters.cuda_version >> + +smoke_test_common: &smoke_test_common + <<: *binary_common + docker: + - image: pytorch/torchaudio_unittest_base:smoke_test-20211019 + resource_class: large + +jobs: + circleci_consistency: + docker: + - image: cimg/python:3.8 + steps: + - checkout + - run: + command: | + pip install --user --progress-bar off jinja2 pyyaml + python .circleci/regenerate.py + git diff --exit-code || (echo ".circleci/config.yml not in sync with config.yml.in! 
Run .circleci/regenerate.py to update config"; exit 1) + + download_third_parties_nix: + docker: + - image: "pytorch/torchaudio_unittest_base:manylinux" + resource_class: small + steps: + - checkout + - generate_cache_key + - restore_cache: + + keys: + - tp-nix-v2-{{ checksum ".cachekey" }} + + - run: + command: | + mkdir -p third_party/sox/archives/ + wget --no-clobber --directory-prefix=third_party/sox/archives/ $(awk '/URL /{print $2}' third_party/sox/CMakeLists.txt) + - save_cache: + + key: tp-nix-v2-{{ checksum ".cachekey" }} + + paths: + - third_party/sox/archives + - persist_to_workspace: + root: third_party + paths: + - sox/archives + + binary_linux_wheel: + <<: *binary_common + docker: + - image: << parameters.wheel_docker_image >> + resource_class: 2xlarge+ + steps: + - checkout + - designate_upload_channel + - attach_workspace: + at: third_party + - run: packaging/build_wheel.sh + - store_artifacts: + path: dist + - persist_to_workspace: + root: dist + paths: + - "*" + + binary_linux_conda: + <<: *binary_common + docker: + - image: "<< parameters.conda_docker_image >>" + resource_class: 2xlarge+ + steps: + - checkout + - load_conda_channel_flags + - attach_workspace: + at: third_party + - run: packaging/build_conda.sh + - store_artifacts: + path: /opt/conda/conda-bld/linux-64 + - persist_to_workspace: + root: /opt/conda + paths: + - "conda-bld/*" + + binary_macos_wheel: + <<: *binary_common + macos: + xcode: "12.0" + steps: + - checkout + - install_build_tools_macos + - designate_upload_channel + - load_conda_channel_flags + - attach_workspace: + at: third_party + - run: + # Cannot easily deduplicate this as source'ing activate + # will set environment variables which we need to propagate + # to build_wheel.sh + command: | + curl -o conda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + sh conda.sh -b + source $HOME/miniconda3/bin/activate + packaging/build_wheel.sh + - store_artifacts: + path: dist + - persist_to_workspace: + root: dist + paths: + - "*" + + binary_macos_conda: + <<: *binary_common + macos: + xcode: "12.0" + steps: + - checkout + - install_build_tools_macos + - load_conda_channel_flags + - attach_workspace: + at: third_party + - run: + command: | + curl -o conda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + sh conda.sh -b + source $HOME/miniconda3/bin/activate + conda install -yq conda-build + packaging/build_conda.sh + - store_artifacts: + path: /Users/distiller/miniconda3/conda-bld/osx-64 + - persist_to_workspace: + root: /Users/distiller/miniconda3 + paths: + - "conda-bld/*" + + binary_windows_wheel: + <<: *binary_common + executor: + name: windows-cpu + steps: + - checkout + - designate_upload_channel + - load_conda_channel_flags + - windows_install_cuda + - run: + name: Build wheel packages + command: | + set -ex + eval "$('/C/tools/miniconda3/Scripts/conda.exe' 'shell.bash' 'hook')" + conda activate base + bash packaging/build_wheel.sh + - store_artifacts: + path: dist + - persist_to_workspace: + root: dist + paths: + - "*" + + binary_windows_conda: + <<: *binary_common + executor: + name: windows-cpu + steps: + - checkout + - load_conda_channel_flags + - windows_install_cuda + - run: + name: Build conda packages + no_output_timeout: 20m + command: | + set -ex + eval "$('/C/tools/miniconda3/Scripts/conda.exe' 'shell.bash' 'hook')" + conda activate base + conda install -yq conda-build "conda-package-handling!=1.5.0" + # cudatoolkit >= 11 isn't available for windows in the nvidia channel + if [[ 
"${CU_VERSION}" =~ cu11.* ]]; then + export CONDA_CHANNEL_FLAGS="-c conda-forge" + fi + bash packaging/build_conda.sh + - store_artifacts: + path: C:/tools/miniconda3/conda-bld/win-64 + - persist_to_workspace: + root: C:/tools/miniconda3 + paths: + - "conda-bld/*" + + # Requires org-member context + binary_conda_upload: + docker: + - image: continuumio/miniconda + steps: + - attach_workspace: + at: ~/workspace + - designate_upload_channel + - run: + command: | + # Prevent credential from leaking + conda install -yq anaconda-client + set -x + anaconda -t "${CONDA_PYTORCHBOT_TOKEN}" upload ~/workspace/conda-bld/*/*.tar.bz2 -u "pytorch-${UPLOAD_CHANNEL}" --label main --no-progress --force + + # Requires org-member context + binary_wheel_upload: + parameters: + subfolder: + description: "What whl subfolder to upload to, e.g., blank or cu100/ (trailing slash is important)" + type: string + docker: + - image: cimg/python:3.8 + steps: + - attach_workspace: + at: ~/workspace + - checkout + - designate_upload_channel + - run: + command: | + pip install --user awscli + export PATH="$HOME/.local/bin:$PATH" + # Prevent credential from leaking + set +x + export AWS_ACCESS_KEY_ID="${PYTORCH_BINARY_AWS_ACCESS_KEY_ID}" + export AWS_SECRET_ACCESS_KEY="${PYTORCH_BINARY_AWS_SECRET_ACCESS_KEY}" + set -x + for pkg in ~/workspace/*.whl; do + aws s3 cp "$pkg" "s3://pytorch/whl/${UPLOAD_CHANNEL}/<< parameters.subfolder >>" --acl public-read + done + + smoke_test_linux_conda: + <<: *smoke_test_common + steps: + - attach_workspace: + at: ~/workspace + - designate_upload_channel + - load_conda_channel_flags + - run: + name: install binaries + command: | + set -x + source /usr/local/etc/profile.d/conda.sh && conda activate python${PYTHON_VERSION} + conda install -v -y -c pytorch-${UPLOAD_CHANNEL} pytorch cpuonly + conda install -v -y -c file://$HOME/workspace/conda-bld torchaudio + - run: + name: smoke test + command: | + source /usr/local/etc/profile.d/conda.sh && conda activate python${PYTHON_VERSION} + python -c "import torchaudio" + + smoke_test_linux_conda_gpu: + <<: *smoke_test_common + steps: + - attach_workspace: + at: ~/workspace + - designate_upload_channel + - load_conda_channel_flags + - run: + name: install binaries + command: | + set -x + source /usr/local/etc/profile.d/conda.sh && conda activate python${PYTHON_VERSION} + conda install -v -y -c pytorch-${UPLOAD_CHANNEL} pytorch cudatoolkit=${CU_VERSION:2:2}.${CU_VERSION:4} -c conda-forge + conda install -v -y -c file://$HOME/workspace/conda-bld torchaudio + - run: + name: smoke test + command: | + source /usr/local/etc/profile.d/conda.sh && conda activate python${PYTHON_VERSION} + python -c "import torchaudio" + + smoke_test_linux_pip: + <<: *smoke_test_common + steps: + - attach_workspace: + at: ~/workspace + - designate_upload_channel + - load_conda_channel_flags + - run: + name: install binaries + command: | + set -x + source /usr/local/etc/profile.d/conda.sh && conda activate python${PYTHON_VERSION} + pip install $(ls ~/workspace/torchaudio*.whl) -f "https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/${CU_VERSION}/torch_${UPLOAD_CHANNEL}.html" + - run: + name: smoke test + command: | + source /usr/local/etc/profile.d/conda.sh && conda activate python${PYTHON_VERSION} + python -c "import torchaudio" + + smoke_test_windows_conda: + <<: *binary_common + executor: + name: windows-cpu + steps: + - attach_workspace: + at: ~/workspace + - designate_upload_channel + - load_conda_channel_flags + - run: + name: install binaries + command: | + set -x + eval 
"$('/C/tools/miniconda3/Scripts/conda.exe' 'shell.bash' 'hook')" + conda env remove -n python${PYTHON_VERSION} || true + conda create -yn python${PYTHON_VERSION} python=${PYTHON_VERSION} + conda activate python${PYTHON_VERSION} + conda install -v -y -c pytorch-${UPLOAD_CHANNEL} pytorch cpuonly + conda install -v -y $(ls ~/workspace/torchaudio*.tar.bz2) + - run: + name: smoke test + command: | + eval "$('/C/tools/miniconda3/Scripts/conda.exe' 'shell.bash' 'hook')" + conda activate python${PYTHON_VERSION} + python -c "import torchaudio" + + smoke_test_windows_pip: + <<: *binary_common + executor: + name: windows-cpu + steps: + - attach_workspace: + at: ~/workspace + - designate_upload_channel + - load_conda_channel_flags + - run: + name: install binaries + command: | + set -x + eval "$('/C/tools/miniconda3/Scripts/conda.exe' 'shell.bash' 'hook')" + conda env remove -n python${PYTHON_VERSION} || true + conda create -yn python${PYTHON_VERSION} python=${PYTHON_VERSION} + conda activate python${PYTHON_VERSION} + pip install $(ls ~/workspace/torchaudio*.whl) -f "https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/torch_${UPLOAD_CHANNEL}.html" + - run: + name: smoke test + command: | + eval "$('/C/tools/miniconda3/Scripts/conda.exe' 'shell.bash' 'hook')" + conda activate python${PYTHON_VERSION} + python -c "import torchaudio" + + smoke_test_docker_image_build: + machine: + image: ubuntu-1604:201903-01 + resource_class: large + environment: + image_name: torchaudio/smoke_test + steps: + - checkout + - run: + name: build_docker image + no_output_timeout: "1h" + command: | + cd .circleci/smoke_test/docker && docker build . -t ${image_name}:${CIRCLE_WORKFLOW_ID} + - run: + name: upload docker image + no_output_timeout: "1h" + command: | + set +x + export AWS_ACCESS_KEY_ID=${ECR_AWS_ACCESS_KEY} + export AWS_SECRET_ACCESS_KEY=${ECR_AWS_SECRET_ACCESS_KEY} + eval $(aws ecr get-login --region us-east-1 --no-include-email) + set -x + docker tag ${image_name}:${CIRCLE_WORKFLOW_ID} 308535385114.dkr.ecr.us-east-1.amazonaws.com/${image_name}:${CIRCLE_WORKFLOW_ID} + docker tag ${image_name}:${CIRCLE_WORKFLOW_ID} 308535385114.dkr.ecr.us-east-1.amazonaws.com/${image_name}:latest + docker push 308535385114.dkr.ecr.us-east-1.amazonaws.com/${image_name}:${CIRCLE_WORKFLOW_ID} + docker push 308535385114.dkr.ecr.us-east-1.amazonaws.com/${image_name}:latest + + unittest_linux_cpu: + <<: *binary_common + docker: + - image: pytorch/torchaudio_unittest_base:manylinux-20210121 + resource_class: 2xlarge+ + steps: + - checkout + - attach_workspace: + at: third_party + - designate_upload_channel + - load_conda_channel_flags + - run: + name: Setup + command: .circleci/unittest/linux/scripts/setup_env.sh + - run: + name: Install torchaudio + command: .circleci/unittest/linux/scripts/install.sh + - run: + name: Run tests + command: .circleci/unittest/linux/scripts/run_test.sh + - store_test_results: + path: test-results + - store_artifacts: + path: test/htmlcov + unittest_linux_gpu: + <<: *binary_common + machine: + image: ubuntu-1604-cuda-10.1:201909-23 + resource_class: gpu.small + environment: + <<: *environment + image_name: pytorch/torchaudio_unittest_base:manylinux-cuda10.2-cudnn8-20210623 + steps: + - checkout + - attach_workspace: + at: third_party + - designate_upload_channel + - load_conda_channel_flags + - run: + name: Pull Docker image + command: docker pull --quiet "${image_name}" + - run: + name: Setup + command: docker run -t --gpus all -e PYTHON_VERSION -v $PWD:$PWD -w $PWD "${image_name}" 
.circleci/unittest/linux/scripts/setup_env.sh + - run: + name: Install torchaudio + command: docker run -t --gpus all -e UPLOAD_CHANNEL -e CONDA_CHANNEL_FLAGS -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/install.sh + - run: + name: Run tests + command: docker run -t --gpus all -v $PWD:$PWD -w $PWD -e "TORCHAUDIO_TEST_FORCE_CUDA=1" "${image_name}" .circleci/unittest/linux/scripts/run_test.sh + - store_test_results: + path: test-results + - store_artifacts: + path: test/htmlcov + + unittest_windows_cpu: + <<: *binary_common + executor: + name: windows-cpu + steps: + - checkout + - designate_upload_channel + - load_conda_channel_flags + - run: + name: Setup + command: .circleci/unittest/windows/scripts/setup_env.sh + - run: + name: Install torchaudio + command: .circleci/unittest/windows/scripts/install.sh + - run: + name: Run tests + command: .circleci/unittest/windows/scripts/run_test.sh + - store_test_results: + path: test-results + - store_artifacts: + path: test/htmlcov + + unittest_windows_gpu: + <<: *binary_common + executor: + name: windows-gpu + environment: + <<: *environment + CUDA_VERSION: "10.2" + TORCHAUDIO_TEST_FORCE_CUDA: 1 + steps: + - checkout + - designate_upload_channel + - load_conda_channel_flags + - run: + name: Setup + command: .circleci/unittest/windows/scripts/setup_env.sh + - run: + name: Install CUDA + command: packaging/windows/internal/cuda_install.bat + - run: + name: Update CUDA driver + command: packaging/windows/internal/driver_update.bat + - run: + name: Install torchaudio + command: .circleci/unittest/windows/scripts/install.sh + - run: + name: Run tests + command: .circleci/unittest/windows/scripts/run_test.sh + - store_test_results: + path: test-results + - store_artifacts: + path: test/htmlcov + + unittest_macos_cpu: + <<: *binary_common + macos: + xcode: "12.0" + resource_class: large + steps: + - checkout + - install_build_tools_macos + - load_conda_channel_flags + - attach_workspace: + at: third_party + - designate_upload_channel + - run: + name: Setup + command: .circleci/unittest/linux/scripts/setup_env.sh + - run: + name: Install torchaudio + command: .circleci/unittest/linux/scripts/install.sh + - run: + name: Run tests + command: .circleci/unittest/linux/scripts/run_test.sh + - store_test_results: + path: test-results + - store_artifacts: + path: test/htmlcov + + stylecheck: + <<: *binary_common + docker: + - image: "pytorch/torchaudio_unittest_base:manylinux" + resource_class: medium + steps: + - checkout + - designate_upload_channel + - load_conda_channel_flags + - run: + name: Setup + command: .circleci/unittest/linux/scripts/setup_env.sh + - run: + name: Run style check + command: .circleci/unittest/linux/scripts/run_style_checks.sh + + build_docs: + <<: *binary_common + docker: + - image: "pytorch/manylinux-cuda100" + resource_class: 2xlarge+ + steps: + - attach_workspace: + at: ~/workspace + - checkout + - designate_upload_channel + - load_conda_channel_flags + - run: + name: Install pytorch-audio + command: .circleci/build_docs/install_wheels.sh + - run: + name: Build docs + command: .circleci/build_docs/build_docs.sh + - persist_to_workspace: + root: ./ + paths: + - "*" + - store_artifacts: + path: ./docs/build/html + destination: docs + + upload_docs: + <<: *binary_common + docker: + - image: "pytorch/manylinux-cuda100" + resource_class: 2xlarge+ + steps: + - attach_workspace: + at: ~/workspace + - run: + name: Generate netrc + command: | + # set credentials for https pushing + # requires the org-member context + 
cat > ~/.netrc \< gen.yml && circleci local execute -c gen.yml --job binary_linux_wheel_py3.8 +# - Replace binary_linux_wheel_py3.8 with the name of the job you want to test. +# Job names are 'name:' key. + +executors: + windows-cpu: + machine: + resource_class: windows.xlarge + image: windows-server-2019-vs2019:stable + shell: bash.exe + + windows-gpu: + machine: + resource_class: windows.gpu.nvidia.medium + image: windows-server-2019-nvidia:stable + shell: bash.exe + +commands: + generate_cache_key: + description: "Generates a cache key file that changes daily" + steps: + - run: + name: Generate cache key + command: echo "$(date +"%Y-%m-%d")" > .cachekey + designate_upload_channel: + description: "inserts the correct upload channel into ${BASH_ENV}" + steps: + - run: + name: adding UPLOAD_CHANNEL to BASH_ENV + command: | + # Hardcoded for release branch + echo "export UPLOAD_CHANNEL=test" >> ${BASH_ENV} + install_build_tools_macos: + description: "installs tools required to build torchaudio" + steps: + - run: + name: Install build tools + command: HOMEBREW_NO_AUTO_UPDATE=1 brew install pkg-config wget + # Disable brew auto update which is very slow + load_conda_channel_flags: + description: "Determines whether we need extra conda channels" + steps: + - run: + name: Adding CONDA_CHANNEL_FLAGS to BASH_ENV + command: | + CONDA_CHANNEL_FLAGS="" + # formerly used to add conda-forge flags for Python 3.9, reserving the mechanism for future python upgrades + windows_install_cuda: + description: "Install desired CUDA version on Windows runners" + steps: + - run: + name: Install CUDA + command: | + packaging/windows/internal/cuda_install.bat + +binary_common: &binary_common + parameters: + # Edit these defaults to do a release + build_version: + description: "version number of release binary; by default, build a nightly" + type: string + default: "0.10.0" + pytorch_version: + description: "PyTorch version to build against; by default, use a nightly" + type: string + default: "1.10.0" + # Don't edit these + python_version: + description: "Python version to build against (e.g., 3.8)" + type: string + cuda_version: + description: "CUDA version to build against (e.g., cpu, cu101)" + type: string + default: "cpu" + wheel_docker_image: + description: "Wheel only: what docker image to use" + type: string + default: "pytorch/manylinux-cuda102" + conda_docker_image: + description: "Conda only: what docker image to use" + type: string + default: "pytorch/conda-builder:cuda102" + environment: &environment + PYTHON_VERSION: << parameters.python_version >> + BUILD_VERSION: << parameters.build_version >> + PYTORCH_VERSION: << parameters.pytorch_version >> + CU_VERSION: << parameters.cuda_version >> + +smoke_test_common: &smoke_test_common + <<: *binary_common + docker: + - image: pytorch/torchaudio_unittest_base:smoke_test-20211019 + resource_class: large + +jobs: + circleci_consistency: + docker: + - image: cimg/python:3.8 + steps: + - checkout + - run: + command: | + pip install --user --progress-bar off jinja2 pyyaml + python .circleci/regenerate.py + git diff --exit-code || (echo ".circleci/config.yml not in sync with config.yml.in! 
Run .circleci/regenerate.py to update config"; exit 1) + + download_third_parties_nix: + docker: + - image: "pytorch/torchaudio_unittest_base:manylinux" + resource_class: small + steps: + - checkout + - generate_cache_key + - restore_cache: + {% raw %} + keys: + - tp-nix-v2-{{ checksum ".cachekey" }} + {% endraw %} + - run: + command: | + mkdir -p third_party/sox/archives/ + wget --no-clobber --directory-prefix=third_party/sox/archives/ $(awk '/URL /{print $2}' third_party/sox/CMakeLists.txt) + - save_cache: + {% raw %} + key: tp-nix-v2-{{ checksum ".cachekey" }} + {% endraw %} + paths: + - third_party/sox/archives + - persist_to_workspace: + root: third_party + paths: + - sox/archives + + binary_linux_wheel: + <<: *binary_common + docker: + - image: << parameters.wheel_docker_image >> + resource_class: 2xlarge+ + steps: + - checkout + - designate_upload_channel + - attach_workspace: + at: third_party + - run: packaging/build_wheel.sh + - store_artifacts: + path: dist + - persist_to_workspace: + root: dist + paths: + - "*" + + binary_linux_conda: + <<: *binary_common + docker: + - image: "<< parameters.conda_docker_image >>" + resource_class: 2xlarge+ + steps: + - checkout + - load_conda_channel_flags + - attach_workspace: + at: third_party + - run: packaging/build_conda.sh + - store_artifacts: + path: /opt/conda/conda-bld/linux-64 + - persist_to_workspace: + root: /opt/conda + paths: + - "conda-bld/*" + + binary_macos_wheel: + <<: *binary_common + macos: + xcode: "12.0" + steps: + - checkout + - install_build_tools_macos + - designate_upload_channel + - load_conda_channel_flags + - attach_workspace: + at: third_party + - run: + # Cannot easily deduplicate this as source'ing activate + # will set environment variables which we need to propagate + # to build_wheel.sh + command: | + curl -o conda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + sh conda.sh -b + source $HOME/miniconda3/bin/activate + packaging/build_wheel.sh + - store_artifacts: + path: dist + - persist_to_workspace: + root: dist + paths: + - "*" + + binary_macos_conda: + <<: *binary_common + macos: + xcode: "12.0" + steps: + - checkout + - install_build_tools_macos + - load_conda_channel_flags + - attach_workspace: + at: third_party + - run: + command: | + curl -o conda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + sh conda.sh -b + source $HOME/miniconda3/bin/activate + conda install -yq conda-build + packaging/build_conda.sh + - store_artifacts: + path: /Users/distiller/miniconda3/conda-bld/osx-64 + - persist_to_workspace: + root: /Users/distiller/miniconda3 + paths: + - "conda-bld/*" + + binary_windows_wheel: + <<: *binary_common + executor: + name: windows-cpu + steps: + - checkout + - designate_upload_channel + - load_conda_channel_flags + - windows_install_cuda + - run: + name: Build wheel packages + command: | + set -ex + eval "$('/C/tools/miniconda3/Scripts/conda.exe' 'shell.bash' 'hook')" + conda activate base + bash packaging/build_wheel.sh + - store_artifacts: + path: dist + - persist_to_workspace: + root: dist + paths: + - "*" + + binary_windows_conda: + <<: *binary_common + executor: + name: windows-cpu + steps: + - checkout + - load_conda_channel_flags + - windows_install_cuda + - run: + name: Build conda packages + no_output_timeout: 20m + command: | + set -ex + eval "$('/C/tools/miniconda3/Scripts/conda.exe' 'shell.bash' 'hook')" + conda activate base + conda install -yq conda-build "conda-package-handling!=1.5.0" + # cudatoolkit >= 11 isn't available 
for windows in the nvidia channel + if [[ "${CU_VERSION}" =~ cu11.* ]]; then + export CONDA_CHANNEL_FLAGS="-c conda-forge" + fi + bash packaging/build_conda.sh + - store_artifacts: + path: C:/tools/miniconda3/conda-bld/win-64 + - persist_to_workspace: + root: C:/tools/miniconda3 + paths: + - "conda-bld/*" + + # Requires org-member context + binary_conda_upload: + docker: + - image: continuumio/miniconda + steps: + - attach_workspace: + at: ~/workspace + - designate_upload_channel + - run: + command: | + # Prevent credential from leaking + conda install -yq anaconda-client + set -x + anaconda -t "${CONDA_PYTORCHBOT_TOKEN}" upload ~/workspace/conda-bld/*/*.tar.bz2 -u "pytorch-${UPLOAD_CHANNEL}" --label main --no-progress --force + + # Requires org-member context + binary_wheel_upload: + parameters: + subfolder: + description: "What whl subfolder to upload to, e.g., blank or cu100/ (trailing slash is important)" + type: string + docker: + - image: cimg/python:3.8 + steps: + - attach_workspace: + at: ~/workspace + - checkout + - designate_upload_channel + - run: + command: | + pip install --user awscli + export PATH="$HOME/.local/bin:$PATH" + # Prevent credential from leaking + set +x + export AWS_ACCESS_KEY_ID="${PYTORCH_BINARY_AWS_ACCESS_KEY_ID}" + export AWS_SECRET_ACCESS_KEY="${PYTORCH_BINARY_AWS_SECRET_ACCESS_KEY}" + set -x + for pkg in ~/workspace/*.whl; do + aws s3 cp "$pkg" "s3://pytorch/whl/${UPLOAD_CHANNEL}/<< parameters.subfolder >>" --acl public-read + done + + smoke_test_linux_conda: + <<: *smoke_test_common + steps: + - attach_workspace: + at: ~/workspace + - designate_upload_channel + - load_conda_channel_flags + - run: + name: install binaries + command: | + set -x + source /usr/local/etc/profile.d/conda.sh && conda activate python${PYTHON_VERSION} + conda install -v -y -c pytorch-${UPLOAD_CHANNEL} pytorch cpuonly + conda install -v -y -c file://$HOME/workspace/conda-bld torchaudio + - run: + name: smoke test + command: | + source /usr/local/etc/profile.d/conda.sh && conda activate python${PYTHON_VERSION} + python -c "import torchaudio" + + smoke_test_linux_conda_gpu: + <<: *smoke_test_common + steps: + - attach_workspace: + at: ~/workspace + - designate_upload_channel + - load_conda_channel_flags + - run: + name: install binaries + command: | + set -x + source /usr/local/etc/profile.d/conda.sh && conda activate python${PYTHON_VERSION} + conda install -v -y -c pytorch-${UPLOAD_CHANNEL} pytorch cudatoolkit=${CU_VERSION:2:2}.${CU_VERSION:4} -c conda-forge + conda install -v -y -c file://$HOME/workspace/conda-bld torchaudio + - run: + name: smoke test + command: | + source /usr/local/etc/profile.d/conda.sh && conda activate python${PYTHON_VERSION} + python -c "import torchaudio" + + smoke_test_linux_pip: + <<: *smoke_test_common + steps: + - attach_workspace: + at: ~/workspace + - designate_upload_channel + - load_conda_channel_flags + - run: + name: install binaries + command: | + set -x + source /usr/local/etc/profile.d/conda.sh && conda activate python${PYTHON_VERSION} + pip install $(ls ~/workspace/torchaudio*.whl) -f "https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/${CU_VERSION}/torch_${UPLOAD_CHANNEL}.html" + - run: + name: smoke test + command: | + source /usr/local/etc/profile.d/conda.sh && conda activate python${PYTHON_VERSION} + python -c "import torchaudio" + + smoke_test_windows_conda: + <<: *binary_common + executor: + name: windows-cpu + steps: + - attach_workspace: + at: ~/workspace + - designate_upload_channel + - load_conda_channel_flags + - run: + name: 
install binaries + command: | + set -x + eval "$('/C/tools/miniconda3/Scripts/conda.exe' 'shell.bash' 'hook')" + conda env remove -n python${PYTHON_VERSION} || true + conda create -yn python${PYTHON_VERSION} python=${PYTHON_VERSION} + conda activate python${PYTHON_VERSION} + conda install -v -y -c pytorch-${UPLOAD_CHANNEL} pytorch cpuonly + conda install -v -y $(ls ~/workspace/torchaudio*.tar.bz2) + - run: + name: smoke test + command: | + eval "$('/C/tools/miniconda3/Scripts/conda.exe' 'shell.bash' 'hook')" + conda activate python${PYTHON_VERSION} + python -c "import torchaudio" + + smoke_test_windows_pip: + <<: *binary_common + executor: + name: windows-cpu + steps: + - attach_workspace: + at: ~/workspace + - designate_upload_channel + - load_conda_channel_flags + - run: + name: install binaries + command: | + set -x + eval "$('/C/tools/miniconda3/Scripts/conda.exe' 'shell.bash' 'hook')" + conda env remove -n python${PYTHON_VERSION} || true + conda create -yn python${PYTHON_VERSION} python=${PYTHON_VERSION} + conda activate python${PYTHON_VERSION} + pip install $(ls ~/workspace/torchaudio*.whl) -f "https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/torch_${UPLOAD_CHANNEL}.html" + - run: + name: smoke test + command: | + eval "$('/C/tools/miniconda3/Scripts/conda.exe' 'shell.bash' 'hook')" + conda activate python${PYTHON_VERSION} + python -c "import torchaudio" + + smoke_test_docker_image_build: + machine: + image: ubuntu-1604:201903-01 + resource_class: large + environment: + image_name: torchaudio/smoke_test + steps: + - checkout + - run: + name: build_docker image + no_output_timeout: "1h" + command: | + cd .circleci/smoke_test/docker && docker build . -t ${image_name}:${CIRCLE_WORKFLOW_ID} + - run: + name: upload docker image + no_output_timeout: "1h" + command: | + set +x + export AWS_ACCESS_KEY_ID=${ECR_AWS_ACCESS_KEY} + export AWS_SECRET_ACCESS_KEY=${ECR_AWS_SECRET_ACCESS_KEY} + eval $(aws ecr get-login --region us-east-1 --no-include-email) + set -x + docker tag ${image_name}:${CIRCLE_WORKFLOW_ID} 308535385114.dkr.ecr.us-east-1.amazonaws.com/${image_name}:${CIRCLE_WORKFLOW_ID} + docker tag ${image_name}:${CIRCLE_WORKFLOW_ID} 308535385114.dkr.ecr.us-east-1.amazonaws.com/${image_name}:latest + docker push 308535385114.dkr.ecr.us-east-1.amazonaws.com/${image_name}:${CIRCLE_WORKFLOW_ID} + docker push 308535385114.dkr.ecr.us-east-1.amazonaws.com/${image_name}:latest + + unittest_linux_cpu: + <<: *binary_common + docker: + - image: pytorch/torchaudio_unittest_base:manylinux-20210121 + resource_class: 2xlarge+ + steps: + - checkout + - attach_workspace: + at: third_party + - designate_upload_channel + - load_conda_channel_flags + - run: + name: Setup + command: .circleci/unittest/linux/scripts/setup_env.sh + - run: + name: Install torchaudio + command: .circleci/unittest/linux/scripts/install.sh + - run: + name: Run tests + command: .circleci/unittest/linux/scripts/run_test.sh + - store_test_results: + path: test-results + - store_artifacts: + path: test/htmlcov + unittest_linux_gpu: + <<: *binary_common + machine: + image: ubuntu-1604-cuda-10.1:201909-23 + resource_class: gpu.small + environment: + <<: *environment + image_name: pytorch/torchaudio_unittest_base:manylinux-cuda10.2-cudnn8-20210623 + steps: + - checkout + - attach_workspace: + at: third_party + - designate_upload_channel + - load_conda_channel_flags + - run: + name: Pull Docker image + command: docker pull --quiet "${image_name}" + - run: + name: Setup + command: docker run -t --gpus all -e PYTHON_VERSION -v $PWD:$PWD -w 
$PWD "${image_name}" .circleci/unittest/linux/scripts/setup_env.sh + - run: + name: Install torchaudio + command: docker run -t --gpus all -e UPLOAD_CHANNEL -e CONDA_CHANNEL_FLAGS -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/install.sh + - run: + name: Run tests + command: docker run -t --gpus all -v $PWD:$PWD -w $PWD -e "TORCHAUDIO_TEST_FORCE_CUDA=1" "${image_name}" .circleci/unittest/linux/scripts/run_test.sh + - store_test_results: + path: test-results + - store_artifacts: + path: test/htmlcov + + unittest_windows_cpu: + <<: *binary_common + executor: + name: windows-cpu + steps: + - checkout + - designate_upload_channel + - load_conda_channel_flags + - run: + name: Setup + command: .circleci/unittest/windows/scripts/setup_env.sh + - run: + name: Install torchaudio + command: .circleci/unittest/windows/scripts/install.sh + - run: + name: Run tests + command: .circleci/unittest/windows/scripts/run_test.sh + - store_test_results: + path: test-results + - store_artifacts: + path: test/htmlcov + + unittest_windows_gpu: + <<: *binary_common + executor: + name: windows-gpu + environment: + <<: *environment + CUDA_VERSION: "10.2" + TORCHAUDIO_TEST_FORCE_CUDA: 1 + steps: + - checkout + - designate_upload_channel + - load_conda_channel_flags + - run: + name: Setup + command: .circleci/unittest/windows/scripts/setup_env.sh + - run: + name: Install CUDA + command: packaging/windows/internal/cuda_install.bat + - run: + name: Update CUDA driver + command: packaging/windows/internal/driver_update.bat + - run: + name: Install torchaudio + command: .circleci/unittest/windows/scripts/install.sh + - run: + name: Run tests + command: .circleci/unittest/windows/scripts/run_test.sh + - store_test_results: + path: test-results + - store_artifacts: + path: test/htmlcov + + unittest_macos_cpu: + <<: *binary_common + macos: + xcode: "12.0" + resource_class: large + steps: + - checkout + - install_build_tools_macos + - load_conda_channel_flags + - attach_workspace: + at: third_party + - designate_upload_channel + - run: + name: Setup + command: .circleci/unittest/linux/scripts/setup_env.sh + - run: + name: Install torchaudio + command: .circleci/unittest/linux/scripts/install.sh + - run: + name: Run tests + command: .circleci/unittest/linux/scripts/run_test.sh + - store_test_results: + path: test-results + - store_artifacts: + path: test/htmlcov + + stylecheck: + <<: *binary_common + docker: + - image: "pytorch/torchaudio_unittest_base:manylinux" + resource_class: medium + steps: + - checkout + - designate_upload_channel + - load_conda_channel_flags + - run: + name: Setup + command: .circleci/unittest/linux/scripts/setup_env.sh + - run: + name: Run style check + command: .circleci/unittest/linux/scripts/run_style_checks.sh + + build_docs: + <<: *binary_common + docker: + - image: "pytorch/manylinux-cuda100" + resource_class: 2xlarge+ + steps: + - attach_workspace: + at: ~/workspace + - checkout + - designate_upload_channel + - load_conda_channel_flags + - run: + name: Install pytorch-audio + command: .circleci/build_docs/install_wheels.sh + - run: + name: Build docs + command: .circleci/build_docs/build_docs.sh + - persist_to_workspace: + root: ./ + paths: + - "*" + - store_artifacts: + path: ./docs/build/html + destination: docs + + upload_docs: + <<: *binary_common + docker: + - image: "pytorch/manylinux-cuda100" + resource_class: 2xlarge+ + steps: + - attach_workspace: + at: ~/workspace + - run: + name: Generate netrc + command: | + # set credentials for https pushing + # requires the 
org-member context + cat > ~/.netrc \<> ~/.bashrc +RUN source /usr/local/etc/profile.d/conda.sh && conda activate python3.6 && conda install -y -c conda-forge sox && conda install -y numpy +RUN source /usr/local/etc/profile.d/conda.sh && conda activate python3.7 && conda install -y -c conda-forge sox && conda install -y numpy +RUN source /usr/local/etc/profile.d/conda.sh && conda activate python3.8 && conda install -y -c conda-forge sox && conda install -y numpy +CMD [ "/bin/bash"] diff --git a/.circleci/unittest/linux/README.md b/.circleci/unittest/linux/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0a4b0e0e6335ba8751a82d952206ca4c0193ec9f --- /dev/null +++ b/.circleci/unittest/linux/README.md @@ -0,0 +1,6 @@ +This directory contains; + + - docker + Docker image definition and scripts to build and update Docker image for unittest. + - scripts + Scripts used by CircleCI to run unit tests. diff --git a/.circleci/unittest/linux/docker/.dockerignore b/.circleci/unittest/linux/docker/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..1398d409f807725a0ea6d0cd6d530ff8a4cc8dbf --- /dev/null +++ b/.circleci/unittest/linux/docker/.dockerignore @@ -0,0 +1,2 @@ +* +!scripts diff --git a/.circleci/unittest/linux/docker/.gitignore b/.circleci/unittest/linux/docker/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..7e977058ddd63a72636f7b84cdd6705628ae515c --- /dev/null +++ b/.circleci/unittest/linux/docker/.gitignore @@ -0,0 +1,2 @@ +scripts/build_third_parties.sh +Dockerfile.tmp diff --git a/.circleci/unittest/linux/docker/Dockerfile b/.circleci/unittest/linux/docker/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..c47a89634864e09015401effb269092227612c8b --- /dev/null +++ b/.circleci/unittest/linux/docker/Dockerfile @@ -0,0 +1,56 @@ +FROM ubuntu:18.04 as builder + +RUN apt update -q + +################################################################################ +# Build Kaldi +################################################################################ +RUN apt install -q -y \ + autoconf \ + automake \ + bzip2 \ + g++ \ + gfortran \ + git \ + libatlas-base-dev \ + libtool \ + make \ + python2.7 \ + python3 \ + sox \ + subversion \ + unzip \ + wget \ + zlib1g-dev + +# KALDI uses MKL as a default math library, but we are going to copy featbin binaries and dependent +# shared libraries to the final image, so we use ATLAS, which is easy to reinstall in the final image. 
+RUN git clone --depth 1 https://github.com/kaldi-asr/kaldi.git /opt/kaldi && \ + cd /opt/kaldi/tools && \ + make -j $(nproc) && \ + cd /opt/kaldi/src && \ + ./configure --shared --mathlib=ATLAS --use-cuda=no && \ + make featbin -j $(nproc) + +# Copy featbins and dependent libraries +ADD ./scripts /scripts +RUN bash /scripts/copy_kaldi_executables.sh /opt/kaldi /kaldi + +################################################################################ +# Build the final image +################################################################################ +FROM BASE_IMAGE +RUN apt update && apt install -y \ + g++ \ + gfortran \ + git \ + libatlas3-base \ + libsndfile1 \ + wget \ + curl \ + make \ + file \ + pkg-config \ + && rm -rf /var/lib/apt/lists/* +COPY --from=builder /kaldi /kaldi +ENV PATH="${PATH}:/kaldi/bin" LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/kaldi/lib" diff --git a/.circleci/unittest/linux/docker/build_and_push.sh b/.circleci/unittest/linux/docker/build_and_push.sh new file mode 100644 index 0000000000000000000000000000000000000000..e7ced13ad3f718b71f047c217f3b7f71dd97f109 --- /dev/null +++ b/.circleci/unittest/linux/docker/build_and_push.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash + +set -euo pipefail + +if [ $# -ne 1 ]; then + printf "Usage %s \n\n" "$0" + exit 1 +fi + +datestr="$(date "+%Y%m%d")" +if [ "$1" = "cpu" ]; then + base_image="ubuntu:18.04" + image="pytorch/torchaudio_unittest_base:manylinux-${datestr}" +else + base_image="nvidia/cuda:$1-devel-ubuntu18.04" + docker pull "${base_image}" + image="pytorch/torchaudio_unittest_base:manylinux-cuda$1-${datestr}" +fi + +cd "$( dirname "${BASH_SOURCE[0]}" )" + +# docker build also accepts reading from STDIN +# but in that case, no context (other files) can be passed, so we write out Dockerfile +sed "s|BASE_IMAGE|${base_image}|g" Dockerfile > Dockerfile.tmp +docker build -t "${image}" -f Dockerfile.tmp . +docker push "${image}" diff --git a/.circleci/unittest/linux/docker/scripts/copy_kaldi_executables.sh b/.circleci/unittest/linux/docker/scripts/copy_kaldi_executables.sh new file mode 100644 index 0000000000000000000000000000000000000000..b0cf207143864aea078c25ef96a4b24a9f93ae10 --- /dev/null +++ b/.circleci/unittest/linux/docker/scripts/copy_kaldi_executables.sh @@ -0,0 +1,58 @@ +#!/usr/bin/env bash + +list_executables() { + # List up executables in the given directory + find "$1" -type f -executable +} + +list_kaldi_libraries() { + # List up shared libraries used by executables found in the given directory ($1) + # that reside in Kaldi directory ($2) + while read file; do + ldd "${file}" | grep -o "${2}.* "; + done < <(list_executables "$1") | sort -u +} + +set -euo pipefail + +kaldi_root="$(realpath "$1")" +target_dir="$(realpath "$2")" + +bin_dir="${target_dir}/bin" +lib_dir="${target_dir}/lib" + +mkdir -p "${bin_dir}" "${lib_dir}" + +# 1. Copy featbins +printf "Copying executables to %s\n" "${bin_dir}" +while read file; do + printf " %s\n" "${file}" + cp "${file}" "${bin_dir}" +done < <(list_executables "${kaldi_root}/src/featbin") + +# 2. Copy dependent libraries from Kaldi +printf "Copying libraries to %s\n" "${lib_dir}" +while read file; do + printf " %s\n" "$file" + # If it is not symlink, just copy to the target directory + if [ ! -L "${file}" ]; then + cp "${file}" "${lib_dir}" + continue + fi + + # If it is symlink, + # 1. Copy the actual library to the target directory. + library="$(realpath "${file}")" + cp "${library}" "${lib_dir}" + # 2. 
then if the name of the symlink is different from the actual library name,
+  #    create the symlink in the target directory.
+  lib_name="$(basename "${library}")"
+  link_name="$(basename "${file}")"
+  if [ "${lib_name}" != "${link_name}" ]; then
+    printf " Linking %s -> %s\n" "${lib_name}" "${link_name}"
+    (
+      cd "${lib_dir}"
+      ln -sf "${lib_name}" "${link_name}"
+    )
+  fi
+done < <(list_kaldi_libraries "${bin_dir}" "${kaldi_root}")
diff --git a/.circleci/unittest/linux/scripts/install.sh b/.circleci/unittest/linux/scripts/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b7dc9da0442a3080418162e64956e189cd8ffa57
--- /dev/null
+++ b/.circleci/unittest/linux/scripts/install.sh
@@ -0,0 +1,72 @@
+#!/usr/bin/env bash
+
+unset PYTORCH_VERSION
+# For unittest, nightly PyTorch is installed in the following section,
+# so there is no need to set PYTORCH_VERSION.
+# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config.
+
+set -e
+
+root_dir="$(git rev-parse --show-toplevel)"
+conda_dir="${root_dir}/conda"
+env_dir="${root_dir}/env"
+
+cd "${root_dir}"
+
+case "$(uname -s)" in
+    Darwin*) os=MacOSX;;
+    *) os=Linux
+esac
+
+# 0. Activate conda env
+eval "$("${conda_dir}/bin/conda" shell.bash hook)"
+conda activate "${env_dir}"
+
+# 1. Install PyTorch
+if [ -z "${CUDA_VERSION:-}" ] ; then
+    if [ "${os}" == MacOSX ] ; then
+        cudatoolkit=''
+    else
+        cudatoolkit="cpuonly"
+    fi
+else
+    version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
+    cudatoolkit="cudatoolkit=${version}"
+fi
+printf "Installing PyTorch with %s\n" "${cudatoolkit}"
+(
+    if [ "${os}" == MacOSX ] ; then
+        # TODO: this can be removed as soon as the linking issue is resolved
+        # see https://github.com/pytorch/pytorch/issues/62424 for details
+        MKL_CONSTRAINT='mkl==2021.2.0'
+    else
+        MKL_CONSTRAINT=''
+    fi
+    set -x
+    conda install ${CONDA_CHANNEL_FLAGS:-} -y -c "pytorch-${UPLOAD_CHANNEL}" $MKL_CONSTRAINT "pytorch-${UPLOAD_CHANNEL}::pytorch" ${cudatoolkit}
+)
+
+# 2. Install torchaudio
+printf "* Installing torchaudio\n"
+git submodule update --init --recursive
+python setup.py install
+
+# 3. Install Test tools
+printf "* Installing test tools\n"
+NUMBA_DEV_CHANNEL=""
+if [[ "$(python --version)" = *3.9* ]]; then
+    # Numba isn't available for Python 3.9 except on the numba dev channel and building from source fails
+    # See https://github.com/librosa/librosa/issues/1270#issuecomment-759065048
+    NUMBA_DEV_CHANNEL="-c numba/label/dev"
+fi
+# Note: installing librosa via pip fails because it will try to compile numba.
+(
+    set -x
+    conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20'
+    pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers expecttest unidecode inflect
+)
+# Install fairseq
+git clone https://github.com/pytorch/fairseq
+cd fairseq
+git checkout e47a4c8
+pip install .
diff --git a/.circleci/unittest/linux/scripts/run_clang_format.py b/.circleci/unittest/linux/scripts/run_clang_format.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd2913bd70eef1f0efefdc2d226fb1c21b6ef4f8
--- /dev/null
+++ b/.circleci/unittest/linux/scripts/run_clang_format.py
@@ -0,0 +1,340 @@
+#!/usr/bin/env python
+"""A wrapper script around clang-format, suitable for linting multiple files
+and to use for continuous integration.
+
+This is an alternative API for the clang-format command line.
+It runs over multiple files and directories in parallel.
+A diff output is produced and a sensible exit code is returned. + +""" + +import argparse +import codecs +import difflib +import fnmatch +import io +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback + +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = 'c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu' + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. + dnames[:] = [ + x for x in dnames + if + not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [ + x for x in fpaths if not fnmatch.fnmatch(x, pattern) + ] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile='{}\t(original)'.format(file), + tofile='{}\t(reformatted)'.format(file), + n=3)) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super(DiffError, self).__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super(UnexpectedError, self).__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError('{}: {}: {}'.format(file, e.__class__.__name__, + e), e) + + +def run_clang_format_diff(args, file): + try: + with io.open(file, 'r', encoding='utf-8') as f: + original = f.readlines() + except IOError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. + # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. + # > Each translation completely replaces the format string + # > for the diagnostic. 
+ # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding='utf-8') + except OSError as exc: + raise DiffError( + "Command '{}' failed to start: {}".format( + subprocess.list2cmdline(invocation), exc + ) + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return '\x1b[1m\x1b[31m' + s + '\x1b[0m' + + +def colorize(diff_lines): + def bold(s): + return '\x1b[1m' + s + '\x1b[0m' + + def cyan(s): + return '\x1b[36m' + s + '\x1b[0m' + + def green(s): + return '\x1b[32m' + s + '\x1b[0m' + + def red(s): + return '\x1b[31m' + s + '\x1b[0m' + + for line in diff_lines: + if line[:4] in ['--- ', '+++ ']: + yield bold(line) + elif line.startswith('@@ '): + yield cyan(line) + elif line.startswith('+'): + yield green(line) + elif line.startswith('-'): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = 'error:' + if use_colors: + error_text = bold_red(error_text) + print("{}: {} {}".format(prog, error_text, message), file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + '--clang-format-executable', + metavar='EXECUTABLE', + help='path to the clang-format executable', + default='clang-format') + parser.add_argument( + '--extensions', + help='comma separated list of file extensions (default: {})'.format( + DEFAULT_EXTENSIONS), + default=DEFAULT_EXTENSIONS) + parser.add_argument( + '-r', + '--recursive', + action='store_true', + help='run recursively over directories') + parser.add_argument('files', metavar='file', nargs='+') + parser.add_argument( + '-q', + '--quiet', + action='store_true') + parser.add_argument( + '-j', + metavar='N', + type=int, + default=0, + help='run N clang-format jobs in parallel' + ' (default number of cpus + 1)') + parser.add_argument( + '--color', + default='auto', + choices=['auto', 'always', 'never'], + help='show colored diff (default: auto)') + parser.add_argument( + '-e', + '--exclude', + metavar='PATTERN', + action='append', + default=[], + help='exclude paths matching the given glob-like pattern(s)' + ' from recursive search') + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + colored_stdout = False + colored_stderr = False + if args.color == 'always': + colored_stdout = True + colored_stderr = True + elif args.color == 'auto': + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, str("--version")] + try: + subprocess.check_call(version_invocation, stdout=DEVNULL) + except 
subprocess.CalledProcessError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + return ExitStatus.TROUBLE + except OSError as e: + print_trouble( + parser.prog, + "Command '{}' failed to start: {}".format( + subprocess.list2cmdline(version_invocation), e + ), + use_colors=colored_stderr, + ) + return ExitStatus.TROUBLE + + retcode = ExitStatus.SUCCESS + files = list_files( + args.files, + recursive=args.recursive, + exclude=args.exclude, + extensions=args.extensions.split(',')) + + if not files: + return + + njobs = args.j + if njobs == 0: + njobs = multiprocessing.cpu_count() + 1 + njobs = min(len(files), njobs) + + if njobs == 1: + # execute directly instead of in a pool, + # less overhead, simpler stacktraces + it = (run_clang_format_diff_wrapper(args, file) for file in files) + pool = None + else: + pool = multiprocessing.Pool(njobs) + it = pool.imap_unordered( + partial(run_clang_format_diff_wrapper, args), files) + while True: + try: + outs, errs = next(it) + except StopIteration: + break + except DiffError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + retcode = ExitStatus.TROUBLE + sys.stderr.writelines(e.errs) + except UnexpectedError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + sys.stderr.write(e.formatted_traceback) + retcode = ExitStatus.TROUBLE + # stop at the first unexpected error, + # something could be very wrong, + # don't process all files unnecessarily + if pool: + pool.terminate() + break + else: + sys.stderr.writelines(errs) + if outs == []: + continue + if not args.quiet: + print_diff(outs, use_color=colored_stdout) + if retcode == ExitStatus.SUCCESS: + retcode = ExitStatus.DIFF + return retcode + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/.circleci/unittest/linux/scripts/run_style_checks.sh b/.circleci/unittest/linux/scripts/run_style_checks.sh new file mode 100644 index 0000000000000000000000000000000000000000..b1ef0f1e79bf8e5fdf5bcde7988df4edad93651e --- /dev/null +++ b/.circleci/unittest/linux/scripts/run_style_checks.sh @@ -0,0 +1,49 @@ +#!/usr/bin/env bash + +set -eux + +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +eval "$("${conda_dir}/bin/conda" shell.bash hook)" +conda activate "${env_dir}" + +# 1. Install tools +conda install flake8 +printf "Installed flake8: " +flake8 --version + +clangformat_path="${root_dir}/clang-format" +curl https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/clang-format-linux64 -o "${clangformat_path}" +chmod +x "${clangformat_path}" +printf "Installed clang-fortmat" +"${clangformat_path}" --version + +# 2. Run style checks +# We want to run all the style checks even if one of them fail. + +set +e + +exit_status=0 + +printf "\x1b[34mRunning flake8:\x1b[0m\n" +flake8 torchaudio test build_tools/setup_helpers docs/source/conf.py examples +status=$? +exit_status="$((exit_status+status))" +if [ "${status}" -ne 0 ]; then + printf "\x1b[31mflake8 failed. Check the format of Python files.\x1b[0m\n" +fi + +printf "\x1b[34mRunning clang-format:\x1b[0m\n" +"${this_dir}"/run_clang_format.py \ + -r torchaudio/csrc third_party/kaldi/src \ + --clang-format-executable "${clangformat_path}" \ + && git diff --exit-code +status=$? +exit_status="$((exit_status+status))" +if [ "${status}" -ne 0 ]; then + printf "\x1b[31mC++ files are not formatted. 
Please use clang-format to format CPP files.\x1b[0m\n" +fi +exit $exit_status diff --git a/.circleci/unittest/linux/scripts/run_test.sh b/.circleci/unittest/linux/scripts/run_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..4e8d24748b0532e043dcd26589bec8b2e93683fd --- /dev/null +++ b/.circleci/unittest/linux/scripts/run_test.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +python -m torch.utils.collect_env + +export TORCHAUDIO_TEST_FAIL_IF_NO_EXTENSION=1 +export PATH="${PWD}/third_party/install/bin/:${PATH}" + +declare -a args=( + '-v' + '--cov=torchaudio' + "--junitxml=${PWD}/test-results/junit.xml" + '--durations' '20' +) + +cd test +pytest "${args[@]}" torchaudio_unittest +coverage html diff --git a/.circleci/unittest/linux/scripts/setup_env.sh b/.circleci/unittest/linux/scripts/setup_env.sh new file mode 100644 index 0000000000000000000000000000000000000000..085dc61a5bcd624c5b29f18ad1ba99a7b2af87de --- /dev/null +++ b/.circleci/unittest/linux/scripts/setup_env.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchaudio here, otherwise they also get cached. + +set -ex + +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget --quiet -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" + eval "$("${conda_dir}/bin/conda" shell.bash hook)" + conda update --quiet -y conda + printf "* Updating the base Python version to %s\n" "${PYTHON_VERSION}" + conda install --quiet -y python="${PYTHON_VERSION}" +else + eval "$("${conda_dir}/bin/conda" shell.bash hook)" +fi + + +# 2. Create test environment at ./env +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment with PYTHON_VERSION=%s\n" "${PYTHON_VERSION}\n" + conda create --prefix "${env_dir}" -y python="${PYTHON_VERSION}" +fi +conda activate "${env_dir}" + +# 3. Install minimal build tools +pip --quiet install cmake ninja diff --git a/.circleci/unittest/windows/README.md b/.circleci/unittest/windows/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2c06af62bdc8516d4d73ac5fa5f259cfc2753410 --- /dev/null +++ b/.circleci/unittest/windows/README.md @@ -0,0 +1,4 @@ +This directory contains; + + - scripts + Scripts used by CircleCI to run unit tests. 
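For reference, the pytest invocation that the Linux `run_test.sh` above assembles (verbose output, coverage for `torchaudio`, a JUnit XML report, and the 20 slowest durations) can also be expressed directly from Python. This is only an illustrative sketch that mirrors the script's flags; it is not an additional supported entry point, and the report path is the same assumption the script makes.

```python
# Sketch mirroring the pytest arguments assembled in run_test.sh.
# The script expands ${PWD} before `cd test`, so the JUnit report path below
# is written relative to the test/ directory instead (same location, one level up).
import pytest

args = [
    "-v",
    "--cov=torchaudio",
    "--junitxml=../test-results/junit.xml",
    "--durations", "20",
    "torchaudio_unittest",
]
raise SystemExit(pytest.main(args))
```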
diff --git a/.circleci/unittest/windows/scripts/environment.yml b/.circleci/unittest/windows/scripts/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..16225c9e3773e03b24ba5e1bb52e17283506bef7 --- /dev/null +++ b/.circleci/unittest/windows/scripts/environment.yml @@ -0,0 +1,16 @@ +channels: + - defaults +dependencies: + - flake8 + - pytest + - pytest-cov + - codecov + - scipy >= 1.4.1 + - pip + - pip: + - kaldi-io + - PySoundFile + - future + - parameterized + - dataclasses + - expecttest diff --git a/.circleci/unittest/windows/scripts/install.sh b/.circleci/unittest/windows/scripts/install.sh new file mode 100644 index 0000000000000000000000000000000000000000..f99f46aa8bac505455243fac89e330976fadd900 --- /dev/null +++ b/.circleci/unittest/windows/scripts/install.sh @@ -0,0 +1,66 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. + +set -ex + +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +cd "${root_dir}" + +# 0. Activate conda env +eval "$("${conda_dir}/Scripts/conda.exe" 'shell.bash' 'hook')" +conda activate "${env_dir}" + +source "$this_dir/set_cuda_envs.sh" + +# 1. Install PyTorch +if [ -z "${CUDA_VERSION:-}" ] ; then + cudatoolkit="cpuonly" +else + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" + cudatoolkit="cudatoolkit=${version}" +fi +printf "Installing PyTorch with %s\n" "${cudatoolkit}" +conda install -y -c "pytorch-${UPLOAD_CHANNEL}" -c conda-forge "pytorch-${UPLOAD_CHANNEL}"::pytorch "${cudatoolkit}" pytest + +torch_cuda=$(python -c "import torch; print(torch.cuda.is_available())") +echo torch.cuda.is_available is $torch_cuda + +if [ ! -z "${CUDA_VERSION:-}" ] ; then + if [ "$torch_cuda" == "False" ]; then + echo "torch with cuda installed but torch.cuda.is_available() is False" + exit 1 + fi +fi + +# 2. Install torchaudio +printf "* Installing torchaudio\n" +git submodule update --init --recursive +"$root_dir/packaging/vc_env_helper.bat" python setup.py install + +# 3. Install Test tools +printf "* Installing test tools\n" +NUMBA_DEV_CHANNEL="" +if [[ "$(python --version)" = *3.9* ]]; then + # Numba isn't available for Python 3.9 except on the numba dev channel and building from source fails + # See https://github.com/librosa/librosa/issues/1270#issuecomment-759065048 + NUMBA_DEV_CHANNEL="-c numba/label/dev" +fi +# Note: installing librosa via pip fail because it will try to compile numba. +( + set -x + conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20' + pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers expecttest unidecode inflect +) +# Install fairseq +git clone https://github.com/pytorch/fairseq +cd fairseq +git checkout e47a4c8 +pip install . 
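The inline `python -c` check in the Windows `install.sh` above guards against a build where a CUDA toolchain was requested but `torch` cannot actually see a GPU. Spelled out as a standalone sketch (same logic as the script, nothing more):

```python
# Standalone sketch of the CUDA sanity check install.sh performs inline:
# if CUDA_VERSION is set but torch reports no CUDA device, fail the job early.
import os
import sys

import torch

cuda_requested = bool(os.environ.get("CUDA_VERSION"))
cuda_available = torch.cuda.is_available()
print(f"torch.cuda.is_available() is {cuda_available}")

if cuda_requested and not cuda_available:
    print("torch with cuda installed but torch.cuda.is_available() is False",
          file=sys.stderr)
    sys.exit(1)
```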
diff --git a/.circleci/unittest/windows/scripts/install_conda.bat b/.circleci/unittest/windows/scripts/install_conda.bat new file mode 100644 index 0000000000000000000000000000000000000000..6052ad08b106accec140ef3f0e27cb4fe893377a --- /dev/null +++ b/.circleci/unittest/windows/scripts/install_conda.bat @@ -0,0 +1 @@ +start /wait "" "%miniconda_exe%" /S /InstallationType=JustMe /RegisterPython=0 /AddToPath=0 /D=%tmp_conda% diff --git a/.circleci/unittest/windows/scripts/run_test.sh b/.circleci/unittest/windows/scripts/run_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..f5ec80e04327b49852e4a5d498fa0dd1165e85fc --- /dev/null +++ b/.circleci/unittest/windows/scripts/run_test.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash + +set -ex + +eval "$(./conda/Scripts/conda.exe 'shell.bash' 'hook')" +conda activate ./env + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +source "$this_dir/set_cuda_envs.sh" + +python -m torch.utils.collect_env +cd test +pytest --cov=torchaudio --junitxml=../test-results/junit.xml -v --durations 20 torchaudio_unittest +coverage html diff --git a/.circleci/unittest/windows/scripts/set_cuda_envs.sh b/.circleci/unittest/windows/scripts/set_cuda_envs.sh new file mode 100644 index 0000000000000000000000000000000000000000..37b53d020da0ea836fcfe0fce539008c05483358 --- /dev/null +++ b/.circleci/unittest/windows/scripts/set_cuda_envs.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash +set -ex + +echo CU_VERSION is "${CU_VERSION}" +echo CUDA_VERSION is "${CUDA_VERSION}" + +# Currenly, CU_VERSION and CUDA_VERSION are not consistent. +# to understand this code, please checck out https://github.com/pytorch/vision/issues/4443 +version="cpu" +if [[ ! -z "${CUDA_VERSION}" ]] ; then + version="$CUDA_VERSION" +else + if [[ ${#CU_VERSION} -eq 5 ]]; then + version="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi +fi + +# Don't use if [[ "$version" == "cpu" ]]; then exit 0 fi. +# It would exit the shell. One result is cpu tests would not run if the shell exit. +# Unless there's an error, Don't exit. +if [[ "$version" != "cpu" ]]; then + # set cuda envs + export PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${version}/bin:/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${version}/libnvvp:$PATH" + export CUDA_PATH_V${version/./_}="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v${version}" + export CUDA_PATH="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v${version}" + + if [ ! -d "$CUDA_PATH" ] + then + echo "$CUDA_PATH" does not exist + exit 1 + fi + + # check cuda driver version + for path in '/c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe' /c/Windows/System32/nvidia-smi.exe; do + if [[ -x "$path" ]]; then + "$path" || echo "true"; + break + fi + done + + which nvcc + nvcc --version + env | grep CUDA +fi diff --git a/.circleci/unittest/windows/scripts/setup_env.sh b/.circleci/unittest/windows/scripts/setup_env.sh new file mode 100644 index 0000000000000000000000000000000000000000..5f092bfceb5e75c8b0578a4dd37188143fac3330 --- /dev/null +++ b/.circleci/unittest/windows/scripts/setup_env.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchaudio here, otherwise they also get cached. 
+ +set -e + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + export tmp_conda="$(echo $conda_dir | tr '/' '\\')" + export miniconda_exe="$(echo $root_dir | tr '/' '\\')\\miniconda.exe" + curl --silent --output miniconda.exe https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe -O + "$this_dir/install_conda.bat" + unset tmp_conda + unset miniconda_exe + eval "$("${conda_dir}/Scripts/conda.exe" 'shell.bash' 'hook')" + conda update --quiet -y conda + printf "* Updating the base Python version to %s\n" "${PYTHON_VERSION}" + conda install --quiet -y python="$PYTHON_VERSION" +else + eval "$("${conda_dir}/Scripts/conda.exe" 'shell.bash' 'hook')" +fi + +# 2. Create test environment at ./env +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment with PYTHON_VERSION=%s\n" "${PYTHON_VERSION}" + conda create --prefix "${env_dir}" -y python="${PYTHON_VERSION}" +fi +conda activate "${env_dir}" + +# 3. Install minimal build tools +pip --quiet install cmake ninja diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..73304266bd671de2c1bf7b78bff6b3e6f63ffa69 --- /dev/null +++ b/.clang-format @@ -0,0 +1,88 @@ +--- +AccessModifierOffset: -1 +AlignAfterOpenBracket: AlwaysBreak +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlinesLeft: true +AlignOperands: false +AlignTrailingComments: false +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Empty +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: true +BinPackArguments: false +BinPackParameters: false +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Attach +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: false +ColumnLimit: 80 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ForEachMacros: [ FOR_EACH_RANGE, FOR_EACH, ] +IncludeCategories: + - Regex: '^<.*\.h(pp)?>' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IndentCaseLabels: true +IndentWidth: 2 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: false +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 2000000 +PointerAlignment: Left 
+ReflowComments: true +SortIncludes: true +SpaceAfterCStyleCast: false +SpaceBeforeAssignmentOperators: true +SpaceBeforeParens: ControlStatements +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +TabWidth: 8 +UseTab: Never +... diff --git a/.clang-tidy b/.clang-tidy new file mode 100644 index 0000000000000000000000000000000000000000..e2d7eb387d32bea625ce2e13ed2c6b5e4102e4bb --- /dev/null +++ b/.clang-tidy @@ -0,0 +1,36 @@ +--- +# NOTE there must be no spaces before the '-' and check name. +# If you edit this list, please verify list of enabled check with +# clang-tidy --list-checks +InheritParentConfig: true +Checks: ' +bugprone-*, +-bugprone-forward-declaration-namespace, +-bugprone-macro-parentheses, +-clang-analyzer-*, +cppcoreguidelines-*, +-cppcoreguidelines-interfaces-global-init, +-cppcoreguidelines-owning-memory, +-cppcoreguidelines-pro-bounds-array-to-pointer-decay, +-cppcoreguidelines-pro-bounds-constant-array-index, +-cppcoreguidelines-pro-bounds-pointer-arithmetic, +-cppcoreguidelines-pro-type-cstyle-cast, +-cppcoreguidelines-pro-type-reinterpret-cast, +-cppcoreguidelines-pro-type-static-cast-downcast, +-cppcoreguidelines-pro-type-union-access, +-cppcoreguidelines-pro-type-vararg, +-cppcoreguidelines-special-member-functions, +-facebook-hte-RelativeInclude, +hicpp-exception-baseclass, +hicpp-avoid-goto, +modernize-*, +-modernize-return-braced-init-list, +-modernize-use-auto, +-modernize-use-default-member-init, +-modernize-use-using, +performance-unnecessary-value-param, +' +HeaderFilterRegex: 'torchaudio/.*' +AnalyzeTemporaryDtors: false +CheckOptions: +... diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000000000000000000000000000000000000..f344aec879bcdb270c83b44f595072508c303656 --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] +max-line-length = 120 +ignore = E305,E402,E721,E741,F405,W503,W504,F999 +exclude = build,docs/source,_ext,third_party diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..fb21d6183e8f8e2d6d87da82ba5e587712d9082f --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# To exclude autogenerated files from code reviews +.circleci/config.yml linguist-generated=true diff --git a/.github/ISSUE_TEMPLATE/bug-report.md b/.github/ISSUE_TEMPLATE/bug-report.md new file mode 100644 index 0000000000000000000000000000000000000000..fb50bd85040f16c4a7246bc154b9a55632bef78d --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report.md @@ -0,0 +1,53 @@ +--- +name: "\U0001F41B Bug Report" +about: Submit a bug report to help us improve Torchaudio + +--- + +## 🐛 Bug + + + +## To Reproduce + +Steps to reproduce the behavior: + +1. +1. +1. + + + +## Expected behavior + + + +## Environment + + - What commands did you used to install torchaudio (conda/pip/build from source)? + - If you are building from source, which commit is it? + - What does `torchaudio.__version__` print? (If applicable) + +Please copy and paste the output from our +[environment collection script](https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py) +(or fill out the checklist below manually). + +You can get the script and run it with: +``` +wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py +# For security purposes, please check the contents of collect_env.py before running it. 
+python collect_env.py +``` + + - PyTorch Version (e.g., 1.0): + - OS (e.g., Linux): + - How you installed PyTorch (`conda`, `pip`, source): + - Build command you used (if compiling from source): + - Python version: + - CUDA/cuDNN version: + - GPU models and configuration: + - Any other relevant information: + +## Additional context + + diff --git a/.github/ISSUE_TEMPLATE/documentation.md b/.github/ISSUE_TEMPLATE/documentation.md new file mode 100644 index 0000000000000000000000000000000000000000..ae745121c50c4b081dcd428e9db0b4f67c129b2f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/documentation.md @@ -0,0 +1,9 @@ +--- +name: "\U0001F4DA Documentation" +about: Report an issue related to https://pytorch.org/audio + +--- + +## 📚 Documentation + + diff --git a/.github/ISSUE_TEMPLATE/feature-request.md b/.github/ISSUE_TEMPLATE/feature-request.md new file mode 100644 index 0000000000000000000000000000000000000000..f3896e5368fe602705cc6bb885ee7514626770fe --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request.md @@ -0,0 +1,24 @@ +--- +name: "\U0001F680Feature Request" +about: Submit a proposal/request for a new Torchaudio feature + +--- + +## 🚀 Feature + + +## Motivation + + + +## Pitch + + + +## Alternatives + + + +## Additional context + + diff --git a/.github/ISSUE_TEMPLATE/questions-help-support.md b/.github/ISSUE_TEMPLATE/questions-help-support.md new file mode 100644 index 0000000000000000000000000000000000000000..77bfb55b9a468a6df6c9459c4735f4b35cba0c45 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/questions-help-support.md @@ -0,0 +1,13 @@ +--- +name: "❓Questions/Help/Support" +about: Do you need support? We have resources. + +--- + +## ❓ Questions and Help + +### Please note that this issue tracker is not a help form and this issue will be closed. + +We have a set of [listed resources available on the website](https://pytorch.org/resources). Our primary means of support is our discussion forum: + +- [Discussion Forum](https://discuss.pytorch.org/) diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml new file mode 100644 index 0000000000000000000000000000000000000000..d9daca0fb2ed66af6d510cd71117b3c1bf72850e --- /dev/null +++ b/.github/pytorch-probot.yml @@ -0,0 +1 @@ +tracking_issue: 736 diff --git a/.github/workflows/bandit.yml b/.github/workflows/bandit.yml new file mode 100644 index 0000000000000000000000000000000000000000..84200b438a9e8bd10e708d6a93373f30433af6ce --- /dev/null +++ b/.github/workflows/bandit.yml @@ -0,0 +1,23 @@ +# GitHub Actions Bandit Workflow + +name: Bandit + +on: + pull_request: + branches: [ main ] + + workflow_dispatch: + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + # Task will fail if any high-severity issues are found + # Ignoring submodules + - name: Run Bandit Security Analysis + run: | + python -m pip install bandit + python -m bandit -r . 
-x ./third_party -lll diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 0000000000000000000000000000000000000000..cf5358627db83b2fbfc4eacf7b363a81d480c9f4 --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,37 @@ +# GitHub Actions CodeQL Workflow + +name: CodeQL + +on: + pull_request: + branches: [ main ] + + workflow_dispatch: + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Initialize CodeQL + uses: github/codeql-action/init@v1 + with: + languages: python, cpp + + - name: Update submodules + run: git submodule update --init --recursive + + - name: Install Torch + run: | + python -m pip install cmake ninja + python -m pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + + - name: Build TorchAudio + run: USE_CUDA=0 python setup.py develop --user + + # If any code scanning alerts are found, they will be under Security -> CodeQL + # Link: https://github.com/pytorch/audio/security/code-scanning + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v1 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ae628691f53e207a38adee1780dfad2c43586010 --- /dev/null +++ b/.gitignore @@ -0,0 +1,127 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# temp files +~* +*.swp + +# C extensions / folders +*.so +_ext/ + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/src/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints/ + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# dotenv +.env + +# virtualenv +.venv +venv/ +ENV/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# PyCharm project settings +.idea + +# OSX dir files +.DS_Store + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# Generated Files +test/assets/sinewave.wav +torchaudio/version.py +gen.yml + +# Examples +examples/interactive_asr/data/*.txt +examples/interactive_asr/data/*.model +examples/interactive_asr/data/*.pt + +# third parties +third_party/install/ +third_party/sox/archives/ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..724846120cac06e735726f787ef3913665ae57d2 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "kaldi"] + path = third_party/kaldi/submodule + url = https://github.com/kaldi-asr/kaldi + ignore = dirty diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..5e2555a2532e8e40302c5c7c40a5567b3f0b399b --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,129 @@ +cmake_minimum_required(VERSION 3.18 FATAL_ERROR) + +# Most of the configurations are taken from PyTorch +# https://github.com/pytorch/pytorch/blob/0c9fb4aff0d60eaadb04e4d5d099fb1e1d5701a9/CMakeLists.txt + +# Use compiler ID "AppleClang" instead of "Clang" for XCode. +# Not setting this sometimes makes XCode C compiler gets detected as "Clang", +# even when the C++ one is detected as "AppleClang". +cmake_policy(SET CMP0010 NEW) +cmake_policy(SET CMP0025 NEW) + +# Suppress warning flags in default MSVC configuration. It's not +# mandatory that we do this (and we don't if cmake is old), but it's +# nice when it's possible, and it's possible on our Windows configs. +if(NOT CMAKE_VERSION VERSION_LESS 3.15.0) + cmake_policy(SET CMP0092 NEW) +endif() + +project(torchaudio) + + +# check and set CMAKE_CXX_STANDARD +string(FIND "${CMAKE_CXX_FLAGS}" "-std=c++" env_cxx_standard) +if(env_cxx_standard GREATER -1) + message( + WARNING "C++ standard version definition detected in environment variable." + "PyTorch requires -std=c++14. 
Please remove -std=c++ settings in your environment.") +endif() + +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_C_STANDARD 11) + +# https://developercommunity.visualstudio.com/t/VS-16100-isnt-compatible-with-CUDA-11/1433342 +if(MSVC) + if(USE_CUDA) + set(CMAKE_CXX_STANDARD 17) + endif() +endif() + + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +# Apple specific +if(APPLE) + # Get clang version on macOS + execute_process( COMMAND ${CMAKE_CXX_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string ) + string(REGEX REPLACE "Apple LLVM version ([0-9]+\\.[0-9]+).*" "\\1" CLANG_VERSION_STRING ${clang_full_version_string}) + message( STATUS "CLANG_VERSION_STRING: " ${CLANG_VERSION_STRING} ) + + # RPATH stuff + set(CMAKE_MACOSX_RPATH ON) + + set(CMAKE_SHARED_LIBRARY_SUFFIX ".so") +endif() + + +# Options +option(BUILD_SOX "Build libsox statically" ON) +option(BUILD_KALDI "Build kaldi statically" ON) +option(BUILD_RNNT "Enable RNN transducer" ON) +option(BUILD_TORCHAUDIO_PYTHON_EXTENSION "Build Python extension" OFF) +option(USE_CUDA "Enable CUDA support" OFF) +option(USE_ROCM "Enable ROCM support" OFF) + + +# check that USE_CUDA and USE_ROCM are not set at the same time +if(USE_CUDA AND USE_ROCM) + message(FATAL "CUDA and ROCm are mutually exclusive") +endif() + +if(USE_ROCM) + # Find the HIP package, set the HIP paths, load the HIP CMake. + include(cmake/LoadHIP.cmake) + if(NOT PYTORCH_FOUND_HIP) + set(USE_ROCM OFF) + endif() +endif() + +if(USE_CUDA) + enable_language(CUDA) +endif() + +find_package(Torch REQUIRED) + +# https://github.com/pytorch/pytorch/issues/54174 +function(CUDA_CONVERT_FLAGS EXISTING_TARGET) + get_property(old_flags TARGET ${EXISTING_TARGET} PROPERTY INTERFACE_COMPILE_OPTIONS) + if(NOT "${old_flags}" STREQUAL "") + string(REPLACE ";" "," CUDA_flags "${old_flags}") + set_property(TARGET ${EXISTING_TARGET} PROPERTY INTERFACE_COMPILE_OPTIONS + "$<$>:${old_flags}>$<$>:-Xcompiler=${CUDA_flags}>" + ) + endif() +endfunction() + +if(MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4819") + if(USE_CUDA) + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=/wd4819") + foreach(diag cc_clobber_ignored integer_sign_change useless_using_declaration + set_but_not_used field_without_dll_interface + base_class_has_different_dll_interface + dll_interface_conflict_none_assumed + dll_interface_conflict_dllexport_assumed + implicit_return_from_non_void_function + unsigned_compare_with_zero + declared_but_not_referenced + bad_friend_decl) + string(APPEND CMAKE_CUDA_FLAGS " -Xcudafe --diag_suppress=${diag}") + endforeach() + CUDA_CONVERT_FLAGS(torch_cpu) + if(TARGET torch_cuda) + CUDA_CONVERT_FLAGS(torch_cuda) + endif() + if(TARGET torch_cuda_cu) + CUDA_CONVERT_FLAGS(torch_cuda_cu) + endif() + if(TARGET torch_cuda_cpp) + CUDA_CONVERT_FLAGS(torch_cuda_cpp) + endif() + endif() +endif() + +# TORCH_CXX_FLAGS contains the same -D_GLIBCXX_USE_CXX11_ABI value as PyTorch +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall ${TORCH_CXX_FLAGS}") + +add_subdirectory(third_party) +add_subdirectory(torchaudio/csrc) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..b91e23b17c023f10a34c7973c6f8614eed61ad1f --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,76 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, 
regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..897e9907d285813760c002117635bd01370f3dd6 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,174 @@ +# Contributing to Torchaudio +We want to make contributing to this project as easy and transparent as possible. 
+ +## TL;DR + +Please let us know if you encounter a bug by filing an [issue](https://github.com/pytorch/audio/issues). + +We appreciate all contributions. If you are planning to contribute back +bug-fixes, please do so without any further discussion. + +If you plan to contribute new features, utility functions or extensions to the +core, please first open an issue and discuss the feature with us. Sending a PR +without discussion might end up resulting in a rejected PR, because we might be +taking the core in a different direction than you might be aware of. + +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the +safe disclosure of security bugs. In those cases, please go through the +process outlined on that page and do not file a public issue. + +Fixing bugs and implementing new features are not the only way you can +contribute. It also helps the project when you report problems you're facing, +and when you give a :+1: on issues that others reported and that are relevant +to you. + +You can also help by improving the documentation. This is no less important +than improving the library itself! If you find a typo in the documentation, +do not hesitate to submit a pull request. + +If you're not sure what you want to work on, you can pick an issue from the +[list of open issues labelled as "help +wanted"](https://github.com/pytorch/audio/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22). +Comment on the issue that you want to work on it and send a PR with your fix +(see below). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Development installation + +We recommend using a `conda` environment to contribute efficiently to +torchaudio. + +### Install PyTorch Nightly + +```bash +conda install pytorch -c pytorch-nightly +``` + +### Install Torchaudio + +```bash +# Install build-time dependencies +pip install cmake ninja pkgconfig +``` + +```bash +# Build torchaudio +git clone https://github.com/pytorch/audio.git +cd audio +git submodule update --init --recursive +python setup.py develop +# or, for OSX +# MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py develop +``` + +Some environmnet variables that change the build behavior +- `BUILD_SOX`: Deteremines whether build and bind libsox in non-Windows environments. (no effect in Windows as libsox integration is not available) Default value is 1 (build and bind). Use 0 for disabling it. +- `USE_CUDA`: Determines whether build the custom CUDA kernel. Default to the availability of CUDA-compatible GPUs. + +If you built sox, set the `PATH` variable so that the tests properly use the newly built `sox` binary: + +```bash +export PATH="/third_party/install/bin:${PATH}" +``` + +The following dependencies are also needed for testing: + +```bash +pip install typing pytest scipy numpy parameterized +``` + +Optional packages to install if you want to run related tests: + +- `librosa` +- `requests` +- `soundfile` +- `kaldi_io` +- `transformers` +- `fairseq` (it has to be newer than `0.10.2`, so you will need to install from + source. Commit `e6eddd80` is known to work.) 
+- `unidecode` (dependency for testing text preprocessing functions for examples/pipeline_tacotron2) +- `inflect` (dependency for testing text preprocessing functions for examples/pipeline_tacotron2) + +## Development Process + +If you plan to modify the code or documentation, please follow the steps below: + +1. Fork the repository and create your branch from `main`: `$ git checkout main && git checkout -b my_cool_feature` +2. If you have modified the code (new feature or bug-fix), [please add tests](test/torchaudio_unittest/). +3. If you have changed APIs, [update the documentation](#Documentation). + +For more details about pull requests, +please read [GitHub's guides](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/creating-a-pull-request). + +If you would like to contribute a new model, please see [here](#New-model). + +If you would like to contribute a new dataset, please see [here](#New-dataset). + +## Testing + +Please refer to our [testing guidelines](test/torchaudio_unittest/) for more +details. + +## Documentation + +Torchaudio uses [Google style](http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) +for formatting docstrings. Length of line inside docstrings block must be limited to 120 characters. + +To build the docs, first install the requirements: + +```bash +cd docs +pip install -r requirements.txt +``` + +Then: + +```bash +cd docs +make html +``` + +The built docs should now be available in `docs/build/html` + +## Conventions + +As a good software development practice, we try to stick to existing variable +names and shape (for tensors). +The following are some of the conventions that we follow. + +- We use an ellipsis "..." as a placeholder for the rest of the dimensions of a + tensor, e.g. optional batching and channel dimensions. If batching, the + "batch" dimension should come in the first diemension. +- Tensors are assumed to have "channel" dimension coming before the "time" + dimension. The bins in frequency domain (freq and mel) are assumed to come + before the "time" dimension but after the "channel" dimension. These + ordering makes the tensors consistent with PyTorch's dimensions. +- For size names, the prefix `n_` is used (e.g. "a tensor of size (`n_freq`, + `n_mels`)") whereas dimension names do not have this prefix (e.g. "a tensor of + dimension (channel, time)") + +Here are some of the examples of commonly used variables with thier names, +meanings, and shapes (or units): + +* `waveform`: a tensor of audio samples with dimensions (..., channel, time) +* `sample_rate`: the rate of audio dimensions (samples per second) +* `specgram`: a tensor of spectrogram with dimensions (..., channel, freq, time) +* `mel_specgram`: a mel spectrogram with dimensions (..., channel, mel, time) +* `hop_length`: the number of samples between the starts of consecutive frames +* `n_fft`: the number of Fourier bins +* `n_mels`, `n_mfcc`: the number of mel and MFCC bins +* `n_freq`: the number of bins in a linear spectrogram +* `f_min`: the lowest frequency of the lowest band in a spectrogram +* `f_max`: the highest frequency of the highest band in a spectrogram +* `win_length`: the length of the STFT window +* `window_fn`: for functions that creates windows e.g. `torch.hann_window` + +## License + +By contributing to Torchaudio, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. 
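As a quick, self-contained illustration of the naming and shape conventions listed in the Conventions section above, the following sketch builds a synthetic one-channel waveform and inspects the resulting spectrogram shapes; all parameter values are arbitrary examples, not recommendations.

```python
# Illustrative sketch of the shape conventions: (..., channel, time) for audio,
# (..., channel, freq, time) for spectrograms, (..., channel, mel, time) for
# mel spectrograms. The waveform is synthetic noise; values are arbitrary.
import torch
import torchaudio

sample_rate = 16000
waveform = torch.randn(1, sample_rate)                     # (channel, time)

n_fft, win_length, hop_length = 400, 400, 200
specgram = torchaudio.transforms.Spectrogram(
    n_fft=n_fft, win_length=win_length, hop_length=hop_length)(waveform)
print(specgram.shape)        # (channel, n_freq, time) with n_freq == n_fft // 2 + 1

mel_specgram = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate, n_fft=n_fft,
    hop_length=hop_length, n_mels=64)(waveform)
print(mel_specgram.shape)    # (channel, n_mels, time)
```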
diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..1bec23eaf1dd562ae3d3216420b1b1bbfbd39cbc --- /dev/null +++ b/LICENSE @@ -0,0 +1,25 @@ +BSD 2-Clause License + +Copyright (c) 2017 Facebook Inc. (Soumith Chintala), +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md index fa6bbe4a87f902da1722d2fcb3f489c9629437bf..c3072ec56ab93072e333142bfe7560cffde275b2 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,152 @@ -# Torchaudio +torchaudio: an audio library for PyTorch +======================================== +[![Build Status](https://circleci.com/gh/pytorch/audio.svg?style=svg)](https://app.circleci.com/pipelines/github/pytorch/audio) +[![Documentation](https://img.shields.io/badge/dynamic/json.svg?label=docs&url=https%3A%2F%2Fpypi.org%2Fpypi%2Ftorchaudio%2Fjson&query=%24.info.version&colorB=brightgreen&prefix=v)](https://pytorch.org/audio/) + +The aim of torchaudio is to apply [PyTorch](https://github.com/pytorch/pytorch) to +the audio domain. By supporting PyTorch, torchaudio follows the same philosophy +of providing strong GPU acceleration, having a focus on trainable features through +the autograd system, and having consistent style (tensor names and dimension names). +Therefore, it is primarily a machine learning library and not a general signal +processing library. The benefits of PyTorch can be seen in torchaudio through +having all the computations be through PyTorch operations which makes it easy +to use and feel like a natural extension. 
+ +- [Support audio I/O (Load files, Save files)](http://pytorch.org/audio/stable/) + - Load a variety of audio formats, such as `wav`, `mp3`, `ogg`, `flac`, `opus`, `sphere`, into a torch Tensor using SoX + - [Kaldi (ark/scp)](http://pytorch.org/audio/stable/kaldi_io.html) +- [Dataloaders for common audio datasets](http://pytorch.org/audio/stable/datasets.html) +- Common audio transforms + - [Spectrogram, AmplitudeToDB, MelScale, MelSpectrogram, MFCC, MuLawEncoding, MuLawDecoding, Resample](http://pytorch.org/audio/stable/transforms.html) +- Compliance interfaces: Run code using PyTorch that align with other libraries + - [Kaldi: spectrogram, fbank, mfcc](https://pytorch.org/audio/stable/compliance.kaldi.html) + +Dependencies +------------ +* PyTorch (See below for the compatible versions) +* [optional] vesis84/kaldi-io-for-python commit cb46cb1f44318a5d04d4941cf39084c5b021241e or above + +The following are the corresponding ``torchaudio`` versions and supported Python versions. + +| ``torch`` | ``torchaudio`` | ``python`` | +| ------------------------ | ------------------------ | ------------------------------- | +| ``master`` / ``nightly`` | ``main`` / ``nightly`` | ``>=3.6``, ``<=3.9`` | +| ``1.9.0`` | ``0.9.0`` | ``>=3.6``, ``<=3.9`` | +| ``1.8.0`` | ``0.8.0`` | ``>=3.6``, ``<=3.9`` | +| ``1.7.1`` | ``0.7.2`` | ``>=3.6``, ``<=3.9`` | +| ``1.7.0`` | ``0.7.0`` | ``>=3.6``, ``<=3.8`` | +| ``1.6.0`` | ``0.6.0`` | ``>=3.6``, ``<=3.8`` | +| ``1.5.0`` | ``0.5.0`` | ``>=3.5``, ``<=3.8`` | +| ``1.4.0`` | ``0.4.0`` | ``==2.7``, ``>=3.5``, ``<=3.8`` | + + +Installation +------------ + +### Binary Distributions + +To install the latest version using anaconda, run: + +``` +conda install -c pytorch torchaudio +``` + +To install the latest pip wheels, run: + +``` +pip install torchaudio -f https://download.pytorch.org/whl/torch_stable.html +``` + +(If you do not have torch already installed, this will default to installing +torch from PyPI. If you need a different torch configuration, preinstall torch +before running this command.) + +### Nightly build + +Note that nightly build is built on PyTorch's nightly build. Therefore, you need to install the latest PyTorch when you use nightly build of torchaudio. + +**pip** + +``` +pip install --pre torchaudio -f https://download.pytorch.org/whl/nightly/torch_nightly.html +``` + +**conda** + +``` +conda install -y -c pytorch-nightly torchaudio +``` + +### From Source + +On non-Windows platforms, the build process builds libsox and codecs that torchaudio need to link to. It will fetch and build libmad, lame, flac, vorbis, opus, and libsox before building extension. This process requires `cmake` and `pkg-config`. libsox-based features can be disabled with `BUILD_SOX=0`. +The build process also builds the RNN transducer loss. This functionality can be disabled by setting the environment variable `BUILD_RNNT=0`. + +```bash +# Linux +python setup.py install + +# OSX +MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py install + +# Windows +# We need to use the MSVC x64 toolset for compilation, with Visual Studio's vcvarsall.bat or directly with vcvars64.bat. +# These batch files are under Visual Studio's installation folder, under 'VC\Auxiliary\Build\'. 
+# More information available at: +# https://docs.microsoft.com/en-us/cpp/build/how-to-enable-a-64-bit-visual-cpp-toolset-on-the-command-line?view=msvc-160#use-vcvarsallbat-to-set-a-64-bit-hosted-build-architecture +call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Auxiliary\Build\vcvarsall.bat" x64 && set BUILD_SOX=0 && python setup.py install +# or +call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Auxiliary\Build\vcvars64.bat" && set BUILD_SOX=0 && python setup.py install +``` + +This is known to work on linux and unix distributions such as Ubuntu and CentOS 7 and macOS. +If you try this on a new system and find a solution to make it work, feel free to share it by opening an issue. + +Quick Usage +----------- + +```python +import torchaudio + +waveform, sample_rate = torchaudio.load('foo.wav') # load tensor from file +torchaudio.save('foo_save.wav', waveform, sample_rate) # save tensor to file +``` + +Backend Dispatch +---------------- + +By default in OSX and Linux, torchaudio uses SoX as a backend to load and save files. +The backend can be changed to [SoundFile](https://pysoundfile.readthedocs.io/en/latest/) +using the following. See [SoundFile](https://pysoundfile.readthedocs.io/en/latest/) +for installation instructions. + +```python +import torchaudio +torchaudio.set_audio_backend("soundfile") # switch backend + +waveform, sample_rate = torchaudio.load('foo.wav') # load tensor from file, as usual +torchaudio.save('foo_save.wav', waveform, sample_rate) # save tensor to file, as usual +``` + +**Note** +- SoundFile currently does not support mp3. +- "soundfile" backend is not supported by TorchScript. + +API Reference +------------- + +API Reference is located here: http://pytorch.org/audio/ + +Contributing Guidelines +----------------------- + +Please refer to [CONTRIBUTING.md](./CONTRIBUTING.md) + +Disclaimer on Datasets +---------------------- + +This is a utility library that downloads and prepares public datasets. We do not host or distribute these datasets, vouch for their quality or fairness, or claim that you have license to use the dataset. It is your responsibility to determine whether you have permission to use the dataset under the dataset's license. + +If you're a dataset owner and wish to update any part of it (description, citation, etc.), or do not want your dataset to be included in this library, please get in touch through a GitHub issue. Thanks for your contribution to the ML community! 
+>>>>>>> init v0.10.0 diff --git a/build_tools/__init__.py b/build_tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build_tools/convert_fairseq_models.py b/build_tools/convert_fairseq_models.py new file mode 100644 index 0000000000000000000000000000000000000000..8bdb7cf174a158f443d14c2b2789b1968de7a99e --- /dev/null +++ b/build_tools/convert_fairseq_models.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +"""Convert a Wav2Vec2/HuBERT model published by fairseq into torchaudio format + +Examples + +``` +python convert_fairseq_models.py \ + --input-file hubert_base_ls960.pt \ + --output-file hubert_fairseq_base_ls960.pth + +python convert_fairseq_models.py \ + --input-file hubert_large_ll60k.pt \ + --output-file hubert_fairseq_large_ll60k.pth + +python convert_fairseq_models.py \ + --input-file hubert_large_ll60k_finetune_ls960.pt \ + --output-file hubert_fairseq_large_ll60k_asr_ls960.pth + +python convert_fairseq_models.py \ + --input-file hubert_xtralarge_ll60k.pt \ + --output-file hubert_fairseq_xlarge_ll60k.pth + +python convert_fairseq_models.py \ + --input-file hubert_xtralarge_ll60k_finetune_ls960.pt \ + --output-file hubert_fairseq_xlarge_ll60k_asr_ls960.pth +""" + +import argparse + +# Note: Avoiding the import of torch and fairseq on global scope as they are slow + + +def _parse_args(): + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + '--input-file', required=True, + help='Input model file.' + ) + parser.add_argument( + '--output-file', required=False, + help='Output model file.' + ) + parser.add_argument( + '--dict-dir', + help=( + 'Directory where letter vocabulary file, `dict.ltr.txt`, is found. ' + 'Required when loading wav2vec2 model. 
' + 'https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt' + ) + ) + return parser.parse_args() + + +def _load_model(input_file, dict_dir): + import fairseq + + overrides = {} if dict_dir is None else {'data': dict_dir} + models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [input_file], arg_overrides=overrides, + ) + return models[0] + + +def _import_model(model): + from torchaudio.models.wav2vec2.utils import import_fairseq_model + + if model.__class__.__name__ in ['HubertCtc', 'Wav2VecCtc']: + model = model.w2v_encoder + model = import_fairseq_model(model) + return model + + +def _main(args): + import torch + model = _load_model(args.input_file, args.dict_dir) + model = _import_model(model) + torch.save(model.state_dict(), args.output_file) + + +if __name__ == '__main__': + _main(_parse_args()) diff --git a/build_tools/setup_helpers/__init__.py b/build_tools/setup_helpers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7afa3f31cd7a6a999cc0a1b939a1fe61c2647dbd --- /dev/null +++ b/build_tools/setup_helpers/__init__.py @@ -0,0 +1 @@ +from .extension import * # noqa diff --git a/build_tools/setup_helpers/extension.py b/build_tools/setup_helpers/extension.py new file mode 100644 index 0000000000000000000000000000000000000000..8f062c6ff8da407a925a8452137a19f843007625 --- /dev/null +++ b/build_tools/setup_helpers/extension.py @@ -0,0 +1,139 @@ +import os +import platform +import subprocess +from pathlib import Path +import distutils.sysconfig + +from setuptools import Extension +from setuptools.command.build_ext import build_ext +import torch + +__all__ = [ + 'get_ext_modules', + 'CMakeBuild', +] + +_THIS_DIR = Path(__file__).parent.resolve() +_ROOT_DIR = _THIS_DIR.parent.parent.resolve() +_TORCHAUDIO_DIR = _ROOT_DIR / 'torchaudio' + + +def _get_build(var, default=False): + if var not in os.environ: + return default + + val = os.environ.get(var, '0') + trues = ['1', 'true', 'TRUE', 'on', 'ON', 'yes', 'YES'] + falses = ['0', 'false', 'FALSE', 'off', 'OFF', 'no', 'NO'] + if val in trues: + return True + if val not in falses: + print( + f'WARNING: Unexpected environment variable value `{var}={val}`. ' + f'Expected one of {trues + falses}') + return False + + +_BUILD_SOX = False if platform.system() == 'Windows' else _get_build("BUILD_SOX", True) +_BUILD_KALDI = False if platform.system() == 'Windows' else _get_build("BUILD_KALDI", True) +_BUILD_RNNT = _get_build("BUILD_RNNT", True) +_USE_ROCM = _get_build("USE_ROCM", torch.cuda.is_available() and torch.version.hip is not None) +_USE_CUDA = _get_build("USE_CUDA", torch.cuda.is_available() and torch.version.hip is None) +_TORCH_CUDA_ARCH_LIST = os.environ.get('TORCH_CUDA_ARCH_LIST', None) + + +def get_ext_modules(): + return [ + Extension(name='torchaudio.lib.libtorchaudio', sources=[]), + Extension(name='torchaudio._torchaudio', sources=[]), + ] + + +# Based off of +# https://github.com/pybind/cmake_example/blob/580c5fd29d4651db99d8874714b07c0c49a53f8a/setup.py +class CMakeBuild(build_ext): + def run(self): + try: + subprocess.check_output(['cmake', '--version']) + except OSError: + raise RuntimeError("CMake is not available.") from None + super().run() + + def build_extension(self, ext): + # Since two library files (libtorchaudio and _torchaudio) need to be + # recognized by setuptools, we instantiate `Extension` twice. (see `get_ext_modules`) + # This leads to the situation where this `build_extension` method is called twice. 
+ # However, the following `cmake` command will build all of them at the same time, + # so, we do not need to perform `cmake` twice. + # Therefore we call `cmake` only for `torchaudio._torchaudio`. + if ext.name != 'torchaudio._torchaudio': + return + + extdir = os.path.abspath( + os.path.dirname(self.get_ext_fullpath(ext.name))) + + # required for auto-detection of auxiliary "native" libs + if not extdir.endswith(os.path.sep): + extdir += os.path.sep + + cfg = "Debug" if self.debug else "Release" + + cmake_args = [ + f"-DCMAKE_BUILD_TYPE={cfg}", + f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}", + f"-DCMAKE_INSTALL_PREFIX={extdir}", + '-DCMAKE_VERBOSE_MAKEFILE=ON', + f"-DPython_INCLUDE_DIR={distutils.sysconfig.get_python_inc()}", + f"-DBUILD_SOX:BOOL={'ON' if _BUILD_SOX else 'OFF'}", + f"-DBUILD_KALDI:BOOL={'ON' if _BUILD_KALDI else 'OFF'}", + f"-DBUILD_RNNT:BOOL={'ON' if _BUILD_RNNT else 'OFF'}", + "-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON", + f"-DUSE_ROCM:BOOL={'ON' if _USE_ROCM else 'OFF'}", + f"-DUSE_CUDA:BOOL={'ON' if _USE_CUDA else 'OFF'}", + ] + build_args = [ + '--target', 'install' + ] + # Pass CUDA architecture to cmake + if _TORCH_CUDA_ARCH_LIST is not None: + # Convert MAJOR.MINOR[+PTX] list to new style one + # defined at https://cmake.org/cmake/help/latest/prop_tgt/CUDA_ARCHITECTURES.html + _arches = _TORCH_CUDA_ARCH_LIST.replace('.', '').split(";") + _arches = [arch[:-4] if arch.endswith("+PTX") else f"{arch}-real" for arch in _arches] + cmake_args += [f"-DCMAKE_CUDA_ARCHITECTURES={';'.join(_arches)}"] + + # Default to Ninja + if 'CMAKE_GENERATOR' not in os.environ or platform.system() == 'Windows': + cmake_args += ["-GNinja"] + if platform.system() == 'Windows': + import sys + python_version = sys.version_info + cmake_args += [ + "-DCMAKE_C_COMPILER=cl", + "-DCMAKE_CXX_COMPILER=cl", + f"-DPYTHON_VERSION={python_version.major}.{python_version.minor}", + ] + + # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level + # across all generators. + if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: + # self.parallel is a Python 3 only way to set parallel jobs by hand + # using -j in the build_ext call, not supported by pip or PyPA-build. + if hasattr(self, "parallel") and self.parallel: + # CMake 3.12+ only. + build_args += ["-j{}".format(self.parallel)] + + if not os.path.exists(self.build_temp): + os.makedirs(self.build_temp) + + subprocess.check_call( + ["cmake", str(_ROOT_DIR)] + cmake_args, cwd=self.build_temp) + subprocess.check_call( + ["cmake", "--build", "."] + build_args, cwd=self.build_temp) + + def get_ext_filename(self, fullname): + ext_filename = super().get_ext_filename(fullname) + ext_filename_parts = ext_filename.split('.') + without_abi = ext_filename_parts[:-2] + ext_filename_parts[-1:] + ext_filename = '.'.join(without_abi) + return ext_filename diff --git a/build_tools/travis/install.sh b/build_tools/travis/install.sh new file mode 100644 index 0000000000000000000000000000000000000000..3586d5bd08f9369b2c0a7ffcc854829fbbc82445 --- /dev/null +++ b/build_tools/travis/install.sh @@ -0,0 +1,81 @@ +#!/bin/bash +# This script is meant to be called by the "install" step defined in +# .travis.yml. See http://docs.travis-ci.com/ for more details. +# The behavior of the script is controlled by environment variabled defined +# in the .travis.yml in the top level folder of the project. 
+ + set -e + + echo 'List files from cached directories' +if [ -d $HOME/download ]; then + echo 'download:' + ls $HOME/download +fi +if [ -d $HOME/.cache/pip ]; then + echo 'pip:' + ls $HOME/.cache/pip +fi + + # Deactivate the travis-provided virtual environment and setup a +# conda-based environment instead +deactivate + + # Add the miniconda bin directory to $PATH +export PATH=/home/travis/miniconda3/bin:$PATH +echo $PATH + + # Use the miniconda installer for setup of conda itself +pushd . +cd +mkdir -p download +cd download +if [[ ! -f /home/travis/miniconda3/bin/activate ]] +then + if [[ ! -f miniconda.sh ]] + then + wget http://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh \ + -O miniconda.sh + fi + chmod +x miniconda.sh && ./miniconda.sh -b -f + conda update --yes conda + echo "Creating environment to run tests in." + conda create -n testenv --yes python="$PYTHON_VERSION" +fi +cd .. +popd + + # Activate the python environment we created. +source activate testenv + + # Install requirements via pip in our conda environment +conda install -y pytorch cpuonly -c pytorch-nightly +pip install -r requirements.txt + + # Install the following only if running tests +if [[ "$SKIP_INSTALL" != "true" ]]; then + # TorchAudio CPP Extensions + python setup.py install +fi + +if [[ "$RUN_EXAMPLE_TESTS" == "true" ]]; then + # Install dependencies + pip install sentencepiece PyAudio + + if [[ ! -d $HOME/download/fairseq ]]; then + # Install fairseq from source + git clone https://github.com/pytorch/fairseq $HOME/download/fairseq + fi + + pushd $HOME/download/fairseq + pip install --editable . + popd + + mkdir -p $HOME/download/data + # Install dictionary, sentence piece model, and model + # These are cached so they are not downloaded if they already exist + wget -nc -O $HOME/download/data/dict.txt https://download.pytorch.org/models/audio/dict.txt || true + wget -nc -O $HOME/download/data/spm.model https://download.pytorch.org/models/audio/spm.model || true + wget -nc -O $HOME/download/data/model.pt https://download.pytorch.org/models/audio/checkpoint_avg_60_80.pt || true +fi + +echo "Finished installation" diff --git a/build_tools/travis/test_script.sh b/build_tools/travis/test_script.sh new file mode 100644 index 0000000000000000000000000000000000000000..e854a211b3735a7890a7fe7d2bc7db4a8869d4a5 --- /dev/null +++ b/build_tools/travis/test_script.sh @@ -0,0 +1,49 @@ +#!/bin/bash +# This script is meant to be called by the "script" step defined in +# .travis.yml. See http://docs.travis-ci.com/ for more details. +# The behavior of the script is controlled by environment variabled defined +# in the .travis.yml in the top level folder of the project. +set -e + +python --version +python -c 'import torch;print("torch:", torch.__version__)' + +run_tests() { + # find all the test files that match "test*.py" + TEST_FILES="$(find test -type f -name "test*.py" | sort)" + echo "Test files are:" + echo $TEST_FILES + + echo "Executing tests:" + EXIT_STATUS=0 + for FILE in $TEST_FILES; do + # run each file on a separate process. if one fails, just keep going and + # return the final exit status. + python -m pytest -v $FILE + STATUS=$? 
+ EXIT_STATUS="$(($EXIT_STATUS+STATUS))" + done + + echo "Done, exit status: $EXIT_STATUS" + exit $EXIT_STATUS +} + +if [[ "$RUN_FLAKE8" == "true" ]]; then + flake8 +fi + +if [[ "$SKIP_TESTS" != "true" ]]; then + echo "run_tests" + run_tests +fi + +if [[ "$RUN_EXAMPLE_TESTS" == "true" ]]; then + echo "run_example_tests" + pushd examples + ASR_MODEL_PATH=$HOME/download/data/model.pt \ + ASR_INPUT_FILE=interactive_asr/data/sample.wav \ + ASR_DATA_PATH=$HOME/download/data \ + ASR_USER_DIR=$HOME/download/fairseq/examples/speech_recognition \ + python -m unittest test/test_interactive_asr.py + popd +fi diff --git a/cmake/LoadHIP.cmake b/cmake/LoadHIP.cmake new file mode 100644 index 0000000000000000000000000000000000000000..41a537b5082e1ab647cd277d084cfbe4cea1f919 --- /dev/null +++ b/cmake/LoadHIP.cmake @@ -0,0 +1,234 @@ +set(PYTORCH_FOUND_HIP FALSE) + +if(NOT DEFINED ENV{ROCM_PATH}) + set(ROCM_PATH /opt/rocm) +else() + set(ROCM_PATH $ENV{ROCM_PATH}) +endif() + +# HIP_PATH +if(NOT DEFINED ENV{HIP_PATH}) + set(HIP_PATH ${ROCM_PATH}/hip) +else() + set(HIP_PATH $ENV{HIP_PATH}) +endif() + +if(NOT EXISTS ${HIP_PATH}) + return() +endif() + +# HCC_PATH +if(NOT DEFINED ENV{HCC_PATH}) + set(HCC_PATH ${ROCM_PATH}/hcc) +else() + set(HCC_PATH $ENV{HCC_PATH}) +endif() + +# HSA_PATH +if(NOT DEFINED ENV{HSA_PATH}) + set(HSA_PATH ${ROCM_PATH}/hsa) +else() + set(HSA_PATH $ENV{HSA_PATH}) +endif() + +# ROCBLAS_PATH +if(NOT DEFINED ENV{ROCBLAS_PATH}) + set(ROCBLAS_PATH ${ROCM_PATH}/rocblas) +else() + set(ROCBLAS_PATH $ENV{ROCBLAS_PATH}) +endif() + +# ROCFFT_PATH +if(NOT DEFINED ENV{ROCFFT_PATH}) + set(ROCFFT_PATH ${ROCM_PATH}/rocfft) +else() + set(ROCFFT_PATH $ENV{ROCFFT_PATH}) +endif() + +# HIPFFT_PATH +if(NOT DEFINED ENV{HIPFFT_PATH}) + set(HIPFFT_PATH ${ROCM_PATH}/hipfft) +else() + set(HIPFFT_PATH $ENV{HIPFFT_PATH}) +endif() + +# HIPSPARSE_PATH +if(NOT DEFINED ENV{HIPSPARSE_PATH}) + set(HIPSPARSE_PATH ${ROCM_PATH}/hipsparse) +else() + set(HIPSPARSE_PATH $ENV{HIPSPARSE_PATH}) +endif() + +# THRUST_PATH +if(DEFINED ENV{THRUST_PATH}) + set(THRUST_PATH $ENV{THRUST_PATH}) +else() + set(THRUST_PATH ${ROCM_PATH}/include) +endif() + +# HIPRAND_PATH +if(NOT DEFINED ENV{HIPRAND_PATH}) + set(HIPRAND_PATH ${ROCM_PATH}/hiprand) +else() + set(HIPRAND_PATH $ENV{HIPRAND_PATH}) +endif() + +# ROCRAND_PATH +if(NOT DEFINED ENV{ROCRAND_PATH}) + set(ROCRAND_PATH ${ROCM_PATH}/rocrand) +else() + set(ROCRAND_PATH $ENV{ROCRAND_PATH}) +endif() + +# MIOPEN_PATH +if(NOT DEFINED ENV{MIOPEN_PATH}) + set(MIOPEN_PATH ${ROCM_PATH}/miopen) +else() + set(MIOPEN_PATH $ENV{MIOPEN_PATH}) +endif() + +# RCCL_PATH +if(NOT DEFINED ENV{RCCL_PATH}) + set(RCCL_PATH ${ROCM_PATH}/rccl) +else() + set(RCCL_PATH $ENV{RCCL_PATH}) +endif() + +# ROCPRIM_PATH +if(NOT DEFINED ENV{ROCPRIM_PATH}) + set(ROCPRIM_PATH ${ROCM_PATH}/rocprim) +else() + set(ROCPRIM_PATH $ENV{ROCPRIM_PATH}) +endif() + +# HIPCUB_PATH +if(NOT DEFINED ENV{HIPCUB_PATH}) + set(HIPCUB_PATH ${ROCM_PATH}/hipcub) +else() + set(HIPCUB_PATH $ENV{HIPCUB_PATH}) +endif() + +# ROCTHRUST_PATH +if(NOT DEFINED ENV{ROCTHRUST_PATH}) + set(ROCTHRUST_PATH ${ROCM_PATH}/rocthrust) +else() + set(ROCTHRUST_PATH $ENV{ROCTHRUST_PATH}) +endif() + +# ROCTRACER_PATH +if(NOT DEFINED ENV{ROCTRACER_PATH}) + set(ROCTRACER_PATH ${ROCM_PATH}/roctracer) +else() + set(ROCTRACER_PATH $ENV{ROCTRACER_PATH}) +endif() + +if(NOT DEFINED ENV{PYTORCH_ROCM_ARCH}) + set(PYTORCH_ROCM_ARCH gfx803;gfx900;gfx906;gfx908) +else() + set(PYTORCH_ROCM_ARCH $ENV{PYTORCH_ROCM_ARCH}) +endif() + +# Add HIP to the CMAKE Module Path +set(CMAKE_MODULE_PATH 
${HIP_PATH}/cmake ${CMAKE_MODULE_PATH}) + +# Disable Asserts In Code (Can't use asserts on HIP stack.) +add_definitions(-DNDEBUG) + +macro(find_package_and_print_version PACKAGE_NAME) + find_package("${PACKAGE_NAME}" ${ARGN}) + message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}") +endmacro() + +# Find the HIP Package +find_package_and_print_version(HIP 1.0) + +if(HIP_FOUND) + set(PYTORCH_FOUND_HIP TRUE) + + # Find ROCM version for checks + file(READ "${ROCM_PATH}/.info/version-dev" ROCM_VERSION_DEV_RAW) + string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_DEV_MATCH ${ROCM_VERSION_DEV_RAW}) + if(ROCM_VERSION_DEV_MATCH) + set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1}) + set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2}) + set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3}) + set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}") + endif() + message("\n***** ROCm version from ${ROCM_PATH}/.info/version-dev ****\n") + message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}") + message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}") + message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}") + message("ROCM_VERSION_DEV_PATCH: ${ROCM_VERSION_DEV_PATCH}") + + message("\n***** Library versions from dpkg *****\n") + execute_process(COMMAND dpkg -l COMMAND grep rocm-dev COMMAND awk "{print $2 \" VERSION: \" $3}") + execute_process(COMMAND dpkg -l COMMAND grep rocm-libs COMMAND awk "{print $2 \" VERSION: \" $3}") + execute_process(COMMAND dpkg -l COMMAND grep hsakmt-roct COMMAND awk "{print $2 \" VERSION: \" $3}") + execute_process(COMMAND dpkg -l COMMAND grep rocr-dev COMMAND awk "{print $2 \" VERSION: \" $3}") + execute_process(COMMAND dpkg -l COMMAND grep -w hcc COMMAND awk "{print $2 \" VERSION: \" $3}") + execute_process(COMMAND dpkg -l COMMAND grep hip_base COMMAND awk "{print $2 \" VERSION: \" $3}") + execute_process(COMMAND dpkg -l COMMAND grep hip_hcc COMMAND awk "{print $2 \" VERSION: \" $3}") + + message("\n***** Library versions from cmake find_package *****\n") + + set(CMAKE_HCC_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG}) + set(CMAKE_HCC_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE}) + ### Remove setting of Flags when FindHIP.CMake PR #558 is accepted.### + + set(hip_DIR ${HIP_PATH}/lib/cmake/hip) + set(hsa-runtime64_DIR ${ROCM_PATH}/lib/cmake/hsa-runtime64) + set(AMDDeviceLibs_DIR ${ROCM_PATH}/lib/cmake/AMDDeviceLibs) + set(amd_comgr_DIR ${ROCM_PATH}/lib/cmake/amd_comgr) + set(rocrand_DIR ${ROCRAND_PATH}/lib/cmake/rocrand) + set(hiprand_DIR ${HIPRAND_PATH}/lib/cmake/hiprand) + set(rocblas_DIR ${ROCBLAS_PATH}/lib/cmake/rocblas) + set(miopen_DIR ${MIOPEN_PATH}/lib/cmake/miopen) + set(rocfft_DIR ${ROCFFT_PATH}/lib/cmake/rocfft) + set(hipfft_DIR ${HIPFFT_PATH}/lib/cmake/hipfft) + set(hipsparse_DIR ${HIPSPARSE_PATH}/lib/cmake/hipsparse) + set(rccl_DIR ${RCCL_PATH}/lib/cmake/rccl) + set(rocprim_DIR ${ROCPRIM_PATH}/lib/cmake/rocprim) + set(hipcub_DIR ${HIPCUB_PATH}/lib/cmake/hipcub) + set(rocthrust_DIR ${ROCTHRUST_PATH}/lib/cmake/rocthrust) + + find_package_and_print_version(hip REQUIRED) + find_package_and_print_version(hsa-runtime64 REQUIRED) + find_package_and_print_version(amd_comgr REQUIRED) + find_package_and_print_version(rocrand REQUIRED) + find_package_and_print_version(hiprand REQUIRED) + find_package_and_print_version(rocblas REQUIRED) + find_package_and_print_version(miopen REQUIRED) + find_package_and_print_version(rocfft REQUIRED) + #if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0") + find_package_and_print_version(hipfft REQUIRED) 
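+  # NOTE: the ROCm >= 4.1.0 version guard above is commented out, so hipfft is located unconditionally here.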
+ #endif() + find_package_and_print_version(hipsparse REQUIRED) + find_package_and_print_version(rccl) + find_package_and_print_version(rocprim REQUIRED) + find_package_and_print_version(hipcub REQUIRED) + find_package_and_print_version(rocthrust REQUIRED) + + if(HIP_COMPILER STREQUAL clang) + set(hip_library_name amdhip64) + else() + set(hip_library_name hip_hcc) + endif() + message("HIP library name: ${hip_library_name}") + + # TODO: hip_hcc has an interface include flag "-hc" which is only + # recognizable by hcc, but not gcc and clang. Right now in our + # setup, hcc is only used for linking, but it should be used to + # compile the *_hip.cc files as well. + find_library(PYTORCH_HIP_HCC_LIBRARIES ${hip_library_name} HINTS ${HIP_PATH}/lib) + # TODO: miopen_LIBRARIES should return fullpath to the library file, + # however currently it's just the lib name + find_library(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES} HINTS ${MIOPEN_PATH}/lib) + # TODO: rccl_LIBRARIES should return fullpath to the library file, + # however currently it's just the lib name + find_library(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES} HINTS ${RCCL_PATH}/lib) + # hiprtc is part of HIP + find_library(ROCM_HIPRTC_LIB ${hip_library_name} HINTS ${HIP_PATH}/lib) + # roctx is part of roctracer + find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCTRACER_PATH}/lib) + set(roctracer_INCLUDE_DIRS ${ROCTRACER_PATH}/include) +endif() diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..76b604393a23d02a1aa5e092fa227be01081260f --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,27 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = -W # converts warnings into error +SPHINXBUILD = sphinx-build +SPHINXPROJ = torchaudio +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +docset: html + doc2dash --name $(SPHINXPROJ) --icon $(SOURCEDIR)/_static/img/pytorch-logo-flame.png --enable-js --online-redirect-url http://pytorch.org/audio/ --force $(BUILDDIR)/html/ + + # Manually fix because Zeal doesn't deal well with `icon.png`-only at 2x resolution. + cp $(SPHINXPROJ).docset/icon.png $(SPHINXPROJ).docset/icon@2x.png + convert $(SPHINXPROJ).docset/icon@2x.png -resize 16x16 $(SPHINXPROJ).docset/icon.png + +.PHONY: help Makefile docset + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000000000000000000000000000000000000..224ce5d7ee8ef71d9d4f92abfdd0881f97f4c93b --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,36 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build +set SPHINXPROJ=torchaudio + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. 
+ echo.If you don't have Sphinx installed, grab it from
+ echo.http://sphinx-doc.org/
+ exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
+
+:end
+popd
diff --git a/docs/requirements.txt b/docs/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7bb247e114eeadc19fd6ecf4f061830af438a6a0
--- /dev/null
+++ b/docs/requirements.txt
@@ -0,0 +1,5 @@
+sphinx==3.5.4
+-e git+git://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
+sphinxcontrib.katex
+sphinxcontrib.bibtex
+matplotlib
diff --git a/docs/source/_static/img/pytorch-logo-dark.png b/docs/source/_static/img/pytorch-logo-dark.png
new file mode 100644
index 0000000000000000000000000000000000000000..b7a1ceb964af782b8a453b3eb6f8eb82b7ddbd49
Binary files /dev/null and b/docs/source/_static/img/pytorch-logo-dark.png differ
diff --git a/docs/source/_static/img/pytorch-logo-dark.svg b/docs/source/_static/img/pytorch-logo-dark.svg
new file mode 100644
index 0000000000000000000000000000000000000000..5e5300038589af8b9a88c09834c2fae7feb6f389
--- /dev/null
+++ b/docs/source/_static/img/pytorch-logo-dark.svg
@@ -0,0 +1,33 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/docs/source/_static/img/pytorch-logo-flame.png b/docs/source/_static/img/pytorch-logo-flame.png
new file mode 100644
index 0000000000000000000000000000000000000000..bad49bf30b4afe5bb34f76d73a03e39fc48bbedb
Binary files /dev/null and b/docs/source/_static/img/pytorch-logo-flame.png differ
diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html
new file mode 100644
index 0000000000000000000000000000000000000000..8727b6ce5dcaa3f722552a2df82f9014752efd2d
--- /dev/null
+++ b/docs/source/_templates/layout.html
@@ -0,0 +1,8 @@
+{% extends "!layout.html" %}
+
+{% block sidebartitle %}
+
+ {% include "searchbox.html" %}
+{% endblock %}
diff --git a/docs/source/backend.rst b/docs/source/backend.rst
new file mode 100644
index 0000000000000000000000000000000000000000..6eda225307ac460e67f048dfb049d2827c42032c
--- /dev/null
+++ b/docs/source/backend.rst
@@ -0,0 +1,92 @@
+.. _backend:
+
+torchaudio.backend
+==================
+
+Overview
+~~~~~~~~
+
+The :mod:`torchaudio.backend` module provides the implementations of the audio file I/O functions ``torchaudio.info``, ``torchaudio.load``, and ``torchaudio.save``.
+
+There are currently two implementations available.
+
+* :ref:`"sox_io" <sox_io_backend>` (default on Linux/macOS)
+* :ref:`"soundfile" <soundfile_backend>` (default on Windows)
+
+.. note::
+    Instead of calling functions in ``torchaudio.backend`` directly, please use ``torchaudio.info``, ``torchaudio.load``, and ``torchaudio.save`` with the proper backend set via :func:`torchaudio.set_audio_backend`.
+
+Availability
+------------
+
+The ``"sox_io"`` backend requires the C++ extension module, which is included in the Linux/macOS binary distributions. This backend is not available on Windows.
+
+The ``"soundfile"`` backend requires the ``SoundFile`` package. Please refer to `the SoundFile documentation `_ for installation instructions.
+
+Common Data Structure
+~~~~~~~~~~~~~~~~~~~~~
+
+Structures used to report the metadata of audio files.
+
+AudioMetaData
+-------------
+
+.. autoclass:: torchaudio.backend.common.AudioMetaData
+
+.. _sox_io_backend:
+
+Sox IO Backend
+~~~~~~~~~~~~~~
+
+The ``"sox_io"`` backend is available on Linux/macOS, where it is the default, and is not available on Windows.
+
+I/O functions of this backend support `TorchScript `_.
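+
+For example, the backend functions can be called from inside a TorchScript-compiled function. The snippet below is only a minimal sketch: the file path is a placeholder, and the one-second/16 kHz figures are assumptions, not part of the API.
+
+.. code::
+
+    import torch
+    from torchaudio.backend import sox_io_backend
+
+    @torch.jit.script
+    def load_first_second(path: str):
+        # Load the first second of a 16 kHz file via the sox_io backend.
+        waveform, sample_rate = sox_io_backend.load(path, frame_offset=0, num_frames=16000)
+        return waveform, sample_rate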
+ +You can switch from another backend to the ``sox_io`` backend with the following; + +.. code:: + + torchaudio.set_audio_backend("sox_io") + +info +---- + +.. autofunction:: torchaudio.backend.sox_io_backend.info + +load +---- + +.. autofunction:: torchaudio.backend.sox_io_backend.load + +save +---- + +.. autofunction:: torchaudio.backend.sox_io_backend.save + +.. _soundfile_backend: + +Soundfile Backend +~~~~~~~~~~~~~~~~~ + +The ``"soundfile"`` backend is available when `SoundFile `_ is installed. This backend is the default on Windows. + +You can switch from another backend to the ``"soundfile"`` backend with the following; + +.. code:: + + torchaudio.set_audio_backend("soundfile") + +info +---- + +.. autofunction:: torchaudio.backend.soundfile_backend.info + +load +---- + +.. autofunction:: torchaudio.backend.soundfile_backend.load + +save +---- + +.. autofunction:: torchaudio.backend.soundfile_backend.save diff --git a/docs/source/compliance.kaldi.rst b/docs/source/compliance.kaldi.rst new file mode 100644 index 0000000000000000000000000000000000000000..72827ca3fbf2ab09c44019ab0a426afb748df36e --- /dev/null +++ b/docs/source/compliance.kaldi.rst @@ -0,0 +1,31 @@ +.. role:: hidden + :class: hidden-section + +torchaudio.compliance.kaldi +============================ + +.. currentmodule:: torchaudio.compliance.kaldi + +The useful processing operations of kaldi_ can be performed with torchaudio. +Various functions with identical parameters are given so that torchaudio can +produce similar outputs. + +.. _kaldi: https://github.com/kaldi-asr/kaldi + +Functions +--------- + +:hidden:`spectrogram` +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: spectrogram + +:hidden:`fbank` +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: fbank + +:hidden:`mfcc` +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: mfcc diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..9bed75a7f6b039f09b228e2973ad6ac77a0712d1 --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# PyTorch documentation build configuration file, created by +# sphinx-quickstart on Fri Dec 23 13:31:47 2016. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +# import os +# import sys +# sys.path.insert(0, os.path.abspath('.')) +import pytorch_sphinx_theme + +# -- General configuration ------------------------------------------------ + +# If your documentation needs a minimal Sphinx version, state it here. +# +needs_sphinx = '1.6' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. 
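+# Note: the sphinx.ext.* extensions below ship with Sphinx itself; sphinxcontrib.katex and
+# sphinxcontrib.bibtex are third-party packages installed via docs/requirements.txt.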
+extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.autosummary', + 'sphinx.ext.doctest', + 'sphinx.ext.intersphinx', + 'sphinx.ext.todo', + 'sphinx.ext.coverage', + 'sphinx.ext.napoleon', + 'sphinx.ext.viewcode', + 'sphinxcontrib.katex', + 'sphinxcontrib.bibtex', +] + +# katex options +# +# + +katex_options = r''' +delimiters : [ + {left: "$$", right: "$$", display: true}, + {left: "\\(", right: "\\)", display: false}, + {left: "\\[", right: "\\]", display: true} +] +''' + +bibtex_bibfiles = ['refs.bib'] + +napoleon_use_ivar = True +napoleon_numpy_docstring = False +napoleon_google_docstring = True + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +# source_suffix = ['.rst', '.md'] +source_suffix = '.rst' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = 'Torchaudio' +copyright = '2018, Torchaudio Contributors' +author = 'Torchaudio Contributors' + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +# TODO: change to [:2] at v1.0 +version = '0.10.0 ' +# The full version, including alpha/beta/rc tags. +# TODO: verify this works as expected +release = '0.10.0' + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This patterns also effect to html_static_path and html_extra_path +exclude_patterns = [] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = True + + +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'pytorch_sphinx_theme' +html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +html_theme_options = { + 'pytorch_project': 'audio', + 'collapse_navigation': False, + 'display_version': True, + 'logo_only': True, + 'navigation_with_keys': True, + 'analytics_id': 'UA-117752657-2', +} + +html_logo = '_static/img/pytorch-logo-dark.svg' + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] +html_css_files = [ + 'https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css' +] + +# -- Options for HTMLHelp output ------------------------------------------ + +# Output file base name for HTML help builder. +htmlhelp_basename = 'TorchAudiodoc' + + +# -- Options for LaTeX output --------------------------------------------- + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). 
+ # + # 'papersize': 'letterpaper', + + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, 'pytorch.tex', 'Torchaudio Documentation', + 'Torch Contributors', 'manual'), +] + + +# -- Options for manual page output --------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + (master_doc, 'Torchaudio', 'Torchaudio Documentation', + [author], 1) +] + + +# -- Options for Texinfo output ------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + (master_doc, 'Torchaudio', 'Torchaudio Documentation', + author, 'Torchaudio', 'Load audio files into pytorch tensors.', + 'Miscellaneous'), +] + + +# Example configuration for intersphinx: refer to the Python standard library. +intersphinx_mapping = { + 'python': ('https://docs.python.org/', None), + 'numpy': ('https://docs.scipy.org/doc/numpy/', None), + 'torch': ('https://pytorch.org/docs/stable/', None), +} + +# -- A patch that prevents Sphinx from cross-referencing ivar tags ------- +# See http://stackoverflow.com/a/41184353/3343043 + +from docutils import nodes +from sphinx.util.docfields import TypedField +from sphinx import addnodes + + +def patched_make_field(self, types, domain, items, **kw): + # `kw` catches `env=None` needed for newer sphinx while maintaining + # backwards compatibility when passed along further down! 
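+    # handle_item below builds the field body for each (name, type) pair; the line
+    # marked "Patch" renders the parameter name as plain bold text instead of a
+    # cross-reference, which is the purpose of this override.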
+ + # type: (list, str, tuple) -> nodes.field + def handle_item(fieldarg, content): + par = nodes.paragraph() + par += addnodes.literal_strong('', fieldarg) # Patch: this line added + # par.extend(self.make_xrefs(self.rolename, domain, fieldarg, + # addnodes.literal_strong)) + if fieldarg in types: + par += nodes.Text(' (') + # NOTE: using .pop() here to prevent a single type node to be + # inserted twice into the doctree, which leads to + # inconsistencies later when references are resolved + fieldtype = types.pop(fieldarg) + if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text): + typename = u''.join(n.astext() for n in fieldtype) + typename = typename.replace('int', 'python:int') + typename = typename.replace('long', 'python:long') + typename = typename.replace('float', 'python:float') + typename = typename.replace('type', 'python:type') + par.extend(self.make_xrefs(self.typerolename, domain, typename, + addnodes.literal_emphasis, **kw)) + else: + par += fieldtype + par += nodes.Text(')') + par += nodes.Text(' -- ') + par += content + return par + + fieldname = nodes.field_name('', self.label) + if len(items) == 1 and self.can_collapse: + fieldarg, content = items[0] + bodynode = handle_item(fieldarg, content) + else: + bodynode = self.list_type() + for fieldarg, content in items: + bodynode += nodes.list_item('', handle_item(fieldarg, content)) + fieldbody = nodes.field_body('', bodynode) + return nodes.field('', fieldname, fieldbody) + +TypedField.make_field = patched_make_field diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst new file mode 100644 index 0000000000000000000000000000000000000000..8189bb82a4bf33518226c63809acc6d394a83493 --- /dev/null +++ b/docs/source/datasets.rst @@ -0,0 +1,121 @@ +torchaudio.datasets +==================== + +All datasets are subclasses of :class:`torch.utils.data.Dataset` +and have ``__getitem__`` and ``__len__`` methods implemented. +Hence, they can all be passed to a :class:`torch.utils.data.DataLoader` +which can load multiple samples parallelly using ``torch.multiprocessing`` workers. +For example: :: + + yesno_data = torchaudio.datasets.YESNO('.', download=True) + data_loader = torch.utils.data.DataLoader(yesno_data, + batch_size=1, + shuffle=True, + num_workers=args.nThreads) + +The following datasets are available: + +.. contents:: Datasets + :local: + +All the datasets have almost similar API. They all have two common arguments: +``transform`` and ``target_transform`` to transform the input and target respectively. + + +.. currentmodule:: torchaudio.datasets + + +CMUARCTIC +~~~~~~~~~ + +.. autoclass:: CMUARCTIC + :members: + :special-members: __getitem__ + + +CMUDict +~~~~~~~~~ + +.. autoclass:: CMUDict + :members: + :special-members: __getitem__ + + +COMMONVOICE +~~~~~~~~~~~ + +.. autoclass:: COMMONVOICE + :members: + :special-members: __getitem__ + + +GTZAN +~~~~~ + +.. autoclass:: GTZAN + :members: + :special-members: __getitem__ + + +LIBRISPEECH +~~~~~~~~~~~ + +.. autoclass:: LIBRISPEECH + :members: + :special-members: __getitem__ + + +LIBRITTS +~~~~~~~~ + +.. autoclass:: LIBRITTS + :members: + :special-members: __getitem__ + + +LJSPEECH +~~~~~~~~ + +.. autoclass:: LJSPEECH + :members: + :special-members: __getitem__ + + +SPEECHCOMMANDS +~~~~~~~~~~~~~~ + +.. autoclass:: SPEECHCOMMANDS + :members: + :special-members: __getitem__ + + +TEDLIUM +~~~~~~~~~~~~~~ + +.. autoclass:: TEDLIUM + :members: + :special-members: __getitem__ + + +VCTK +~~~~ + +.. 
autoclass:: VCTK + :members: + :special-members: __getitem__ + + +VCTK_092 +~~~~~~~~ + +.. autoclass:: VCTK_092 + :members: + :special-members: __getitem__ + + +YESNO +~~~~~ + +.. autoclass:: YESNO + :members: + :special-members: __getitem__ diff --git a/docs/source/functional.rst b/docs/source/functional.rst new file mode 100644 index 0000000000000000000000000000000000000000..5ff8ec695e97c872dd9430c1edbf3144bc4f6cfc --- /dev/null +++ b/docs/source/functional.rst @@ -0,0 +1,281 @@ +.. role:: hidden + :class: hidden-section + +torchaudio.functional +===================== + +.. currentmodule:: torchaudio.functional + +Functions to perform common audio operations. + +:hidden:`Utility` +~~~~~~~~~~~~~~~~~ + +amplitude_to_DB +--------------- + +.. autofunction:: amplitude_to_DB + +DB_to_amplitude +--------------- + +.. autofunction:: DB_to_amplitude + +create_fb_matrix +---------------- + +.. autofunction:: create_fb_matrix + +melscale_fbanks +--------------- + +.. autofunction:: melscale_fbanks + +linear_fbanks +------------- + +.. autofunction:: linear_fbanks + +create_dct +---------- + +.. autofunction:: create_dct + +mask_along_axis +--------------- + +.. autofunction:: mask_along_axis + +mask_along_axis_iid +------------------- + +.. autofunction:: mask_along_axis_iid + +mu_law_encoding +--------------- + +.. autofunction:: mu_law_encoding + +mu_law_decoding +--------------- + +.. autofunction:: mu_law_decoding + +apply_codec +----------- + +.. autofunction:: apply_codec + +resample +-------- + +.. autofunction:: resample + +:hidden:`Complex Utility` +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Utilities for pseudo complex tensor. This is not for the native complex dtype, such as `cfloat64`, but for tensors with real-value type and have extra dimension at the end for real and imaginary parts. + +angle +----- + +.. autofunction:: angle + +complex_norm +------------ + +.. autofunction:: complex_norm + +magphase +-------- + +.. autofunction:: magphase + +:hidden:`Filtering` +~~~~~~~~~~~~~~~~~~~ + + +allpass_biquad +-------------- + +.. autofunction:: allpass_biquad + +band_biquad +----------- + +.. autofunction:: band_biquad + +bandpass_biquad +--------------- + +.. autofunction:: bandpass_biquad + +bandreject_biquad +----------------- + +.. autofunction:: bandreject_biquad + +bass_biquad +----------- + +.. autofunction:: bass_biquad + +biquad +------ + +.. autofunction:: biquad + +contrast +-------- + +.. autofunction:: contrast + +dcshift +------- + +.. autofunction:: dcshift + +deemph_biquad +------------- + +.. autofunction:: deemph_biquad + + +dither +------ + +.. autofunction:: dither + +equalizer_biquad +---------------- + +.. autofunction:: equalizer_biquad + +filtfilt +-------- + +.. autofunction:: filtfilt + +flanger +------- + +.. autofunction:: flanger + +gain +---- + +.. autofunction:: gain + +highpass_biquad +--------------- + +.. autofunction:: highpass_biquad + +lfilter +------- + +.. autofunction:: lfilter + +lowpass_biquad +-------------- + +.. autofunction:: lowpass_biquad + +overdrive +--------- + +.. autofunction:: overdrive + +phaser +------ + +.. autofunction:: phaser + +riaa_biquad +----------- + +.. autofunction:: riaa_biquad + +treble_biquad +------------- + +.. autofunction:: treble_biquad + +:hidden:`Feature Extractions` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +:hidden:`vad` +------------- + +.. autofunction:: vad + +:hidden:`spectrogram` +--------------------- + +.. autofunction:: spectrogram + +:hidden:`inverse_spectrogram` +----------------------------- + +.. 
autofunction:: inverse_spectrogram + +:hidden:`griffinlim` +-------------------- + +.. autofunction:: griffinlim + +:hidden:`phase_vocoder` +----------------------- + +.. autofunction:: phase_vocoder + +:hidden:`pitch_shift` +--------------------- + +.. autofunction:: pitch_shift + +:hidden:`compute_deltas` +------------------------ + +.. autofunction:: compute_deltas + +:hidden:`detect_pitch_frequency` +-------------------------------- + +.. autofunction:: detect_pitch_frequency + +:hidden:`sliding_window_cmn` +---------------------------- + +.. autofunction:: sliding_window_cmn + +:hidden:`compute_kaldi_pitch` +----------------------------- + +.. autofunction:: compute_kaldi_pitch + +:hidden:`spectral_centroid` +--------------------------- + +.. autofunction:: spectral_centroid + +:hidden:`Loss` +~~~~~~~~~~~~~~ + +rnnt_loss +--------- + +.. autofunction:: rnnt_loss + +:hidden:`Metric` +~~~~~~~~~~~~~~~~ + +edit_distance +------------- + +.. autofunction:: edit_distance + +References +~~~~~~~~~~ + +.. footbibliography:: diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..89c9da5ce4a83d2436d9ab24e2e43162d79b43a0 --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,55 @@ +torchaudio +========== +This library is part of the `PyTorch +`_ project. PyTorch is an open source +machine learning framework. + +Features described in this documentation are classified by release status: + + *Stable:* These features will be maintained long-term and there should generally + be no major performance limitations or gaps in documentation. + We also expect to maintain backwards compatibility (although + breaking changes can happen and notice will be given one release ahead + of time). + + *Beta:* Features are tagged as Beta because the API may change based on + user feedback, because the performance needs to improve, or because + coverage across operators is not yet complete. For Beta features, we are + committing to seeing the feature through to the Stable classification. + We are not, however, committing to backwards compatibility. + + *Prototype:* These features are typically not available as part of + binary distributions like PyPI or Conda, except sometimes behind run-time + flags, and are at an early stage for feedback and testing. + + +The :mod:`torchaudio` package consists of I/O, popular datasets and common audio transformations. + +.. toctree:: + :maxdepth: 2 + :caption: Package Reference + + torchaudio + backend + functional + transforms + datasets + models + pipelines + sox_effects + compliance.kaldi + kaldi_io + utils + + +.. toctree:: + :maxdepth: 1 + :caption: PyTorch Libraries + + PyTorch + torchaudio + torchtext + torchvision + TorchElastic + TorchServe + PyTorch on XLA Devices diff --git a/docs/source/kaldi_io.rst b/docs/source/kaldi_io.rst new file mode 100644 index 0000000000000000000000000000000000000000..2744bcc89711d58c679e673c351a6efd1092980b --- /dev/null +++ b/docs/source/kaldi_io.rst @@ -0,0 +1,43 @@ +.. role:: hidden + :class: hidden-section + +torchaudio.kaldi_io +====================== + +.. currentmodule:: torchaudio.kaldi_io + +To use this module, the dependency kaldi_io_ needs to be installed. +This is a light wrapper around ``kaldi_io`` that returns :class:`torch.Tensor`. + +.. _kaldi_io: https://github.com/vesis84/kaldi-io-for-python + +Vectors +------- + +:hidden:`read_vec_int_ark` +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autofunction:: read_vec_int_ark + +:hidden:`read_vec_flt_scp` +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: read_vec_flt_scp + +:hidden:`read_vec_flt_ark` +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: read_vec_flt_ark + +Matrices +-------- + +:hidden:`read_mat_scp` +~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: read_mat_scp + +:hidden:`read_mat_ark` +~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: read_mat_ark diff --git a/docs/source/models.rst b/docs/source/models.rst new file mode 100644 index 0000000000000000000000000000000000000000..70738f8d768cf2b3df282bdde8ee3dcc507d7fcd --- /dev/null +++ b/docs/source/models.rst @@ -0,0 +1,128 @@ +.. role:: hidden + :class: hidden-section + +torchaudio.models +================= + +.. currentmodule:: torchaudio.models + +The models subpackage contains definitions of models for addressing common audio tasks. + + +ConvTasNet +~~~~~~~~~~ + +.. autoclass:: ConvTasNet + + .. automethod:: forward + + +DeepSpeech +~~~~~~~~~~ + +.. autoclass:: DeepSpeech + + .. automethod:: forward + + +Tacotron2 +~~~~~~~~~ + +.. autoclass:: Tacotron2 + + .. automethod:: forward + + .. automethod:: infer + +Wav2Letter +~~~~~~~~~~ + +.. autoclass:: Wav2Letter + + .. automethod:: forward + + +Wav2Vec2.0 / HuBERT +~~~~~~~~~~~~~~~~~~~ + +Model +----- + +Wav2Vec2Model +^^^^^^^^^^^^^ + +.. autoclass:: Wav2Vec2Model + + .. automethod:: extract_features + + .. automethod:: forward + +Factory Functions +----------------- + +wav2vec2_model +^^^^^^^^^^^^^^ + +.. autofunction:: wav2vec2_model + + +wav2vec2_base +^^^^^^^^^^^^^ + +.. autofunction:: wav2vec2_base + +wav2vec2_large +^^^^^^^^^^^^^^ + +.. autofunction:: wav2vec2_large + +wav2vec2_large_lv60k +^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: wav2vec2_large_lv60k + +hubert_base +^^^^^^^^^^^ + +.. autofunction:: hubert_base + +hubert_large +^^^^^^^^^^^^ + +.. autofunction:: hubert_large + +hubert_xlarge +^^^^^^^^^^^^^ + +.. autofunction:: hubert_xlarge + +Utility Functions +----------------- + +.. currentmodule:: torchaudio.models.wav2vec2.utils + +import_huggingface_model +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: import_huggingface_model + +import_fairseq_model +^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: import_fairseq_model + +.. currentmodule:: torchaudio.models + +WaveRNN +~~~~~~~ + +.. autoclass:: WaveRNN + + .. automethod:: forward + + .. automethod:: infer + +References +~~~~~~~~~~ + +.. footbibliography:: diff --git a/docs/source/pipelines.rst b/docs/source/pipelines.rst new file mode 100644 index 0000000000000000000000000000000000000000..962eae9f340489b2f5274d88a3ca785dc861d48f --- /dev/null +++ b/docs/source/pipelines.rst @@ -0,0 +1,238 @@ +torchaudio.pipelines +==================== + +.. currentmodule:: torchaudio.pipelines + +The pipelines subpackage contains API to access the models with pretrained weights, and information/helper functions associated the pretrained weights. + +wav2vec 2.0 / HuBERT - Representation Learning +---------------------------------------------- + +.. autoclass:: Wav2Vec2Bundle + :members: sample_rate + + .. automethod:: get_model + +WAV2VEC2_BASE +~~~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: WAV2VEC2_BASE + :no-value: + +WAV2VEC2_LARGE +~~~~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: WAV2VEC2_LARGE + :no-value: + +WAV2VEC2_LARGE_LV60K +~~~~~~~~~~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: WAV2VEC2_LARGE_LV60K + :no-value: + + +WAV2VEC2_XLSR53 +~~~~~~~~~~~~~~~ + +.. container:: py attribute + + .. 
autodata:: WAV2VEC2_XLSR53 + :no-value: + +HUBERT_BASE +~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: HUBERT_BASE + :no-value: + +HUBERT_LARGE +~~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: HUBERT_LARGE + :no-value: + +HUBERT_XLARGE +~~~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: HUBERT_XLARGE + :no-value: + +wav2vec 2.0 / HuBERT - Fine-tuned ASR +------------------------------------- + +.. autoclass:: Wav2Vec2ASRBundle + :members: sample_rate + + .. automethod:: get_model + + .. automethod:: get_labels + + +WAV2VEC2_ASR_BASE_10M +~~~~~~~~~~~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: WAV2VEC2_ASR_BASE_10M + :no-value: + +WAV2VEC2_ASR_BASE_100H +~~~~~~~~~~~~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: WAV2VEC2_ASR_BASE_100H + :no-value: + +WAV2VEC2_ASR_BASE_960H +~~~~~~~~~~~~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: WAV2VEC2_ASR_BASE_960H + :no-value: + +WAV2VEC2_ASR_LARGE_10M +~~~~~~~~~~~~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: WAV2VEC2_ASR_LARGE_10M + :no-value: + +WAV2VEC2_ASR_LARGE_100H +~~~~~~~~~~~~~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: WAV2VEC2_ASR_LARGE_100H + :no-value: + +WAV2VEC2_ASR_LARGE_960H +~~~~~~~~~~~~~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: WAV2VEC2_ASR_LARGE_960H + :no-value: + +WAV2VEC2_ASR_LARGE_LV60K_10M +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: WAV2VEC2_ASR_LARGE_LV60K_10M + :no-value: + +WAV2VEC2_ASR_LARGE_LV60K_100H +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: WAV2VEC2_ASR_LARGE_LV60K_100H + :no-value: + +WAV2VEC2_ASR_LARGE_LV60K_960H +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: WAV2VEC2_ASR_LARGE_LV60K_960H + :no-value: + +HUBERT_ASR_LARGE +~~~~~~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: HUBERT_ASR_LARGE + :no-value: + +HUBERT_ASR_XLARGE +~~~~~~~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: HUBERT_ASR_XLARGE + :no-value: + +Tacotron2 Text-To-Speech +------------------------ + +Tacotron2TTSBundle +~~~~~~~~~~~~~~~~~~ + +.. autoclass:: Tacotron2TTSBundle + + .. automethod:: get_text_processor + + .. automethod:: get_tacotron2 + + .. automethod:: get_vocoder + +Tacotron2TTSBundle - TextProcessor +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: torchaudio.pipelines::Tacotron2TTSBundle.TextProcessor + :members: tokens + :special-members: __call__ + + +Tacotron2TTSBundle - Vocoder +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: torchaudio.pipelines::Tacotron2TTSBundle.Vocoder + :members: sample_rate + :special-members: __call__ + + +TACOTRON2_WAVERNN_PHONE_LJSPEECH +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: TACOTRON2_WAVERNN_PHONE_LJSPEECH + :no-value: + + +TACOTRON2_WAVERNN_CHAR_LJSPEECH +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: TACOTRON2_WAVERNN_CHAR_LJSPEECH + :no-value: + +TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH + :no-value: + +TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. container:: py attribute + + .. autodata:: TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH + :no-value: + +References +---------- + +.. 
footbibliography:: diff --git a/docs/source/refs.bib b/docs/source/refs.bib new file mode 100644 index 0000000000000000000000000000000000000000..c76513de85e8631aa8d89edb89825e38def9a40d --- /dev/null +++ b/docs/source/refs.bib @@ -0,0 +1,216 @@ +@article{specaugment, + title={SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition}, + url={http://dx.doi.org/10.21437/Interspeech.2019-2680}, + DOI={10.21437/interspeech.2019-2680}, + journal={Interspeech 2019}, + publisher={ISCA}, + author={Park, Daniel S. and Chan, William and Zhang, Yu and Chiu, Chung-Cheng and Zoph, Barret and Cubuk, Ekin D. and Le, Quoc V.}, + year={2019}, + month={Sep} +} +@misc{ljspeech17, + author = {Keith Ito and Linda Johnson}, + title = {The LJ Speech Dataset}, + howpublished = {\url{https://keithito.com/LJ-Speech-Dataset/}}, + year = {2017} +} +@misc{conneau2020unsupervised, + title={Unsupervised Cross-lingual Representation Learning for Speech Recognition}, + author={Alexis Conneau and Alexei Baevski and Ronan Collobert and Abdelrahman Mohamed and Michael Auli}, + year={2020}, + eprint={2006.13979}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +@inproceedings{Gales2014SpeechRA, + title={Speech recognition and keyword spotting for low-resource languages: Babel project research at CUED}, + author={Mark John Francis Gales and Kate Knill and Anton Ragni and Shakti Prasad Rath}, + booktitle={SLTU}, + year={2014} +} +@misc{ardila2020common, + title={Common Voice: A Massively-Multilingual Speech Corpus}, + author={Rosana Ardila and Megan Branson and Kelly Davis and Michael Henretty and Michael Kohler and Josh Meyer and Reuben Morais and Lindsay Saunders and Francis M. Tyers and Gregor Weber}, + year={2020}, + eprint={1912.06670}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +@article{Pratap_2020, + title={MLS: A Large-Scale Multilingual Dataset for Speech Research}, + url={http://dx.doi.org/10.21437/Interspeech.2020-2826}, + DOI={10.21437/interspeech.2020-2826}, + journal={Interspeech 2020}, + publisher={ISCA}, + author={Pratap, Vineel and Xu, Qiantong and Sriram, Anuroop and Synnaeve, Gabriel and Collobert, Ronan}, + year={2020}, + month={Oct} +} +@INPROCEEDINGS{librilight, + author={J. {Kahn} and M. {Rivière} and W. {Zheng} and E. {Kharitonov} and Q. {Xu} and P. E. {Mazaré} and J. {Karadayi} and V. {Liptchinsky} and R. {Collobert} and C. {Fuegen} and T. {Likhomanenko} and G. {Synnaeve} and A. {Joulin} and A. {Mohamed} and E. 
{Dupoux}}, + booktitle={ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, + title={Libri-Light: A Benchmark for ASR with Limited or No Supervision}, + year={2020}, + pages={7669-7673}, + note = {\url{https://github.com/facebookresearch/libri-light}}, +} +@INPROCEEDINGS{7178964, + author={Panayotov, Vassil and Chen, Guoguo and Povey, Daniel and Khudanpur, Sanjeev}, + booktitle={2015 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, + title={Librispeech: An ASR corpus based on public domain audio books}, + year={2015}, + volume={}, + number={}, + pages={5206-5210}, + doi={10.1109/ICASSP.2015.7178964} +} +@inproceedings{ott2019fairseq, + title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling}, + author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli}, + booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations}, + year = {2019}, +} +@misc{baevski2020wav2vec, + title={wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations}, + author={Alexei Baevski and Henry Zhou and Abdelrahman Mohamed and Michael Auli}, + year={2020}, + eprint={2006.11477}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +@misc{hsu2021hubert, + title={HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units}, + author={Wei-Ning Hsu and Benjamin Bolte and Yao-Hung Hubert Tsai and Kushal Lakhotia and Ruslan Salakhutdinov and Abdelrahman Mohamed}, + year={2021}, + eprint={2106.07447}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +@misc{hannun2014deep, + title={Deep Speech: Scaling up end-to-end speech recognition}, + author={Awni Hannun and Carl Case and Jared Casper and Bryan Catanzaro and Greg Diamos and Erich Elsen and Ryan Prenger and Sanjeev Satheesh and Shubho Sengupta and Adam Coates and Andrew Y. Ng}, + year={2014}, + eprint={1412.5567}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +@misc{graves2012sequence, + title={Sequence Transduction with Recurrent Neural Networks}, + author={Alex Graves}, + year={2012}, + eprint={1211.3711}, + archivePrefix={arXiv}, + primaryClass={cs.NE} +} +@misc{collobert2016wav2letter, + title={Wav2Letter: an End-to-End ConvNet-based Speech Recognition System}, + author={Ronan Collobert and Christian Puhrsch and Gabriel Synnaeve}, + year={2016}, + eprint={1609.03193}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} +@misc{kalchbrenner2018efficient, + title={Efficient Neural Audio Synthesis}, + author={Nal Kalchbrenner and Erich Elsen and Karen Simonyan and Seb Noury and Norman Casagrande and Edward Lockhart and Florian Stimberg and Aaron van den Oord and Sander Dieleman and Koray Kavukcuoglu}, + year={2018}, + eprint={1802.08435}, + archivePrefix={arXiv}, + primaryClass={cs.SD} +} +@article{Luo_2019, + title={Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation}, + volume={27}, + ISSN={2329-9304}, + url={http://dx.doi.org/10.1109/TASLP.2019.2915167}, + DOI={10.1109/taslp.2019.2915167}, + number={8}, + journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing}, + publisher={Institute of Electrical and Electronics Engineers (IEEE)}, + author={Luo, Yi and Mesgarani, Nima}, + year={2019}, + month={Aug}, + pages={1256–1266} +} +@InProceedings{ brian_mcfee-proc-scipy-2015, + author = { {B}rian {M}c{F}ee and {C}olin {R}affel and {D}awen {L}iang and {D}aniel {P}.{W}. 
{E}llis and {M}att {M}c{V}icar and {E}ric {B}attenberg and {O}riol {N}ieto }, + title = { librosa: {A}udio and {M}usic {S}ignal {A}nalysis in {P}ython }, + booktitle = { {P}roceedings of the 14th {P}ython in {S}cience {C}onference }, + pages = { 18 - 24 }, + year = { 2015 }, + editor = { {K}athryn {H}uff and {J}ames {B}ergstra }, + doi = { 10.25080/Majora-7b98e3ed-003 } +} +@INPROCEEDINGS{6701851, + author={Perraudin, Nathanaël and Balazs, Peter and Søndergaard, Peter L.}, + booktitle={2013 IEEE Workshop on Applications of Signal Processing to Audio and Acoustics}, + title={A fast Griffin-Lim algorithm}, + year={2013}, + volume={}, + number={}, + pages={1-4}, + doi={10.1109/WASPAA.2013.6701851}} +@INPROCEEDINGS{1172092, + author={Griffin, D. and Jae Lim}, + booktitle={ICASSP '83. IEEE International Conference on Acoustics, Speech, and Signal Processing}, + title={Signal estimation from modified short-time Fourier transform}, + year={1983}, + volume={8}, + number={}, + pages={804-807}, + doi={10.1109/ICASSP.1983.1172092}} +@INPROCEEDINGS{6854049, + author={Ghahremani, Pegah and BabaAli, Bagher and Povey, Daniel and Riedhammer, Korbinian and Trmal, Jan and Khudanpur, Sanjeev}, + booktitle={2014 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, + title={A pitch extraction algorithm tuned for automatic speech recognition}, + year={2014}, + volume={}, + number={}, + pages={2494-2498}, + doi={10.1109/ICASSP.2014.6854049}} +@inproceedings{shen2018natural, + title={Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions}, + author={Shen, Jonathan and Pang, Ruoming and Weiss, Ron J and Schuster, Mike and Jaitly, Navdeep and Yang, Zongheng and Chen, Zhifeng and Zhang, Yu and Wang, Yuxuan and Skerrv-Ryan, Rj and others}, + booktitle={2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, + pages={4779--4783}, + year={2018}, + organization={IEEE} +} +@inproceedings{souden2009optimal, + title={On optimal frequency-domain multichannel linear filtering for noise reduction}, + author={Souden, Mehrez and Benesty, Jacob and Affes, Sofiene}, + booktitle={IEEE Transactions on audio, speech, and language processing}, + volume={18}, + number={2}, + pages={260--276}, + year={2009}, + publisher={IEEE} +} +@inproceedings{higuchi2016robust, + title={Robust MVDR beamforming using time-frequency masks for online/offline ASR in noise}, + author={Higuchi, Takuya and Ito, Nobutaka and Yoshioka, Takuya and Nakatani, Tomohiro}, + booktitle={2016 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, + pages={5210--5214}, + year={2016}, + organization={IEEE} +} +@article{mises1929praktische, + title={Praktische Verfahren der Gleichungsaufl{\"o}sung.}, + author={Mises, RV and Pollaczek-Geiringer, Hilda}, + journal={ZAMM-Journal of Applied Mathematics and Mechanics/Zeitschrift f{\"u}r Angewandte Mathematik und Mechanik}, + volume={9}, + number={1}, + pages={58--77}, + year={1929}, + publisher={Wiley Online Library} +} +@article{higuchi2017online, + title={Online MVDR beamformer based on complex Gaussian mixture model with spatial prior for noise robust ASR}, + author={Higuchi, Takuya and Ito, Nobutaka and Araki, Shoko and Yoshioka, Takuya and Delcroix, Marc and Nakatani, Tomohiro}, + journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing}, + volume={25}, + number={4}, + pages={780--793}, + year={2017}, + publisher={IEEE} +} diff --git a/docs/source/sox_effects.rst 
b/docs/source/sox_effects.rst new file mode 100644 index 0000000000000000000000000000000000000000..6eee11d8c7d97c4d5129837393d2402c0056fc60 --- /dev/null +++ b/docs/source/sox_effects.rst @@ -0,0 +1,33 @@ +.. _sox_effects: + +torchaudio.sox_effects +====================== + +.. currentmodule:: torchaudio.sox_effects + +Resource initialization / shutdown +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: init_sox_effects + +.. autofunction:: shutdown_sox_effects + +Listing supported effects +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: effect_names + +Applying effects +~~~~~~~~~~~~~~~~ + +Apply SoX effects chain on torch.Tensor or on file and load as torch.Tensor. + +Applying effects on Tensor +-------------------------- + +.. autofunction:: apply_effects_tensor + +Applying effects on file +------------------------ + +.. autofunction:: apply_effects_file diff --git a/docs/source/torchaudio.rst b/docs/source/torchaudio.rst new file mode 100644 index 0000000000000000000000000000000000000000..cb616d01aabad9eb46b96313c6c2b5b24b892740 --- /dev/null +++ b/docs/source/torchaudio.rst @@ -0,0 +1,32 @@ +torchaudio +========== + +I/O functionalities +~~~~~~~~~~~~~~~~~~~ + +Audio I/O functions are implemented in :ref:`torchaudio.backend` module, but for the ease of use, the following functions are made available on :mod:`torchaudio` module. There are different backends available and you can switch backends with :func:`set_audio_backend`. + +Refer to :ref:`backend` for the detail. + +.. function:: torchaudio.info(filepath: str, ...) + + Fetch meta data of an audio file. Refer to :ref:`backend` for the detail. + +.. function:: torchaudio.load(filepath: str, ...) + + Load audio file into torch.Tensor object. Refer to :ref:`backend` for the detail. + +.. function:: torchaudio.save(filepath: str, src: torch.Tensor, sample_rate: int, ...) + + Save torch.Tensor object into an audio format. Refer to :ref:`backend` for the detail. + +.. currentmodule:: torchaudio + +Backend Utilities +~~~~~~~~~~~~~~~~~ + +.. autofunction:: list_audio_backends + +.. autofunction:: get_audio_backend + +.. autofunction:: set_audio_backend diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst new file mode 100644 index 0000000000000000000000000000000000000000..f05c7eba9eedf465de12f5ede8c03c2a7ac3f25f --- /dev/null +++ b/docs/source/transforms.rst @@ -0,0 +1,211 @@ +.. role:: hidden + :class: hidden-section + +torchaudio.transforms +====================== + +.. currentmodule:: torchaudio.transforms + +Transforms are common audio transforms. They can be chained together using :class:`torch.nn.Sequential` + +:hidden:`Utility` +~~~~~~~~~~~~~~~~~~ + +:hidden:`AmplitudeToDB` +----------------------- + +.. autoclass:: AmplitudeToDB + + .. automethod:: forward + +:hidden:`MelScale` +------------------ + +.. autoclass:: MelScale + + .. automethod:: forward + +:hidden:`InverseMelScale` +------------------------- + +.. autoclass:: InverseMelScale + + .. automethod:: forward + +:hidden:`MuLawEncoding` +----------------------- + +.. autoclass:: MuLawEncoding + + .. automethod:: forward + +:hidden:`MuLawDecoding` +----------------------- + +.. autoclass:: MuLawDecoding + + .. automethod:: forward + +:hidden:`Resample` +------------------ + +.. autoclass:: Resample + + .. automethod:: forward + +:hidden:`FrequencyMasking` +-------------------------- + +.. autoclass:: FrequencyMasking + + .. automethod:: forward + +:hidden:`TimeMasking` +--------------------- + +.. autoclass:: TimeMasking + + .. 
automethod:: forward + +:hidden:`TimeStretch` +--------------------- + +.. autoclass:: TimeStretch + + .. automethod:: forward + +:hidden:`Fade` +-------------- + +.. autoclass:: Fade + + .. automethod:: forward + +:hidden:`Vol` +------------- + +.. autoclass:: Vol + + .. automethod:: forward + +:hidden:`Complex Utility` +~~~~~~~~~~~~~~~~~~~~~~~~~ + +:hidden:`ComplexNorm` +--------------------- + +.. autoclass:: ComplexNorm + + .. automethod:: forward + +:hidden:`Feature Extractions` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +:hidden:`Spectrogram` +--------------------- + +.. autoclass:: Spectrogram + + .. automethod:: forward + +:hidden:`InverseSpectrogram` +---------------------------- + +.. autoclass:: InverseSpectrogram + + .. automethod:: forward + +:hidden:`MelSpectrogram` +------------------------ + +.. autoclass:: MelSpectrogram + + .. automethod:: forward + +:hidden:`GriffinLim` +-------------------- + +.. autoclass:: GriffinLim + + .. automethod:: forward + +:hidden:`MFCC` +-------------- + +.. autoclass:: MFCC + + .. automethod:: forward + +:hidden:`LFCC` +-------------- + +.. autoclass:: LFCC + + .. automethod:: forward + +:hidden:`ComputeDeltas` +----------------------- + +.. autoclass:: ComputeDeltas + + .. automethod:: forward + +:hidden:`PitchShift` +-------------------- + +.. autoclass:: PitchShift + + .. automethod:: forward + +:hidden:`SlidingWindowCmn` +-------------------------- + +.. autoclass:: SlidingWindowCmn + + .. automethod:: forward + +:hidden:`SpectralCentroid` +-------------------------- + +.. autoclass:: SpectralCentroid + + .. automethod:: forward + +:hidden:`Vad` +------------- + +.. autoclass:: Vad + + .. automethod:: forward + +:hidden:`Loss` +~~~~~~~~~~~~~~ + +:hidden:`RNNTLoss` +------------------ + +.. autoclass:: RNNTLoss + + .. automethod:: forward + +:hidden:`Multi-channel` +~~~~~~~~~~~~~~~~~~~~~~~ + +:hidden:`PSD` +------------- + +.. autoclass:: PSD + + .. automethod:: forward + +:hidden:`MVDR` +-------------- + +.. autoclass:: MVDR + + .. automethod:: forward + +References +~~~~~~~~~~ + +.. footbibliography:: diff --git a/docs/source/utils.rst b/docs/source/utils.rst new file mode 100644 index 0000000000000000000000000000000000000000..dc5ad0fd73213e33f7b574864e309db08ff649d7 --- /dev/null +++ b/docs/source/utils.rst @@ -0,0 +1,11 @@ +torchaudio.utils +================ + +torchaudio.utils.sox_utils +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Utility module to configure libsox. +This affects functionalities in :ref:`Sox IO backend` and :ref:`Sox Effects`. + +.. 
automodule:: torchaudio.utils.sox_utils + :members: diff --git a/examples/beamforming/MVDR_tutorial.ipynb b/examples/beamforming/MVDR_tutorial.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..1cd69ee1f8a717832d0b8cc15be3de9ea8e117a1 --- /dev/null +++ b/examples/beamforming/MVDR_tutorial.ipynb @@ -0,0 +1,578 @@ +{ + "nbformat": 4, + "nbformat_minor": 2, + "metadata": { + "colab": { + "name": "Copy of Copy of torchaudio_MVDR_tutorial.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3.9.6 64-bit ('dev': conda)" + }, + "language_info": { + "name": "python", + "version": "3.9.6", + "mimetype": "text/x-python", + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "pygments_lexer": "ipython3", + "nbconvert_exporter": "python", + "file_extension": ".py" + }, + "interpreter": { + "hash": "6a702c257b9a40163843ba760790c17a6ddd2abeef8febce55475eea4b92c28c" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "\"Open" + ], + "metadata": { + "id": "xheYDPUcYGbp" + } + }, + { + "cell_type": "markdown", + "source": [ + "This is a tutorial on how to apply MVDR beamforming by using [torchaudio](https://github.com/pytorch/audio)\n", + "-----------\n", + "\n", + "The multi-channel audio example is selected from [ConferencingSpeech](https://github.com/ConferencingSpeech/ConferencingSpeech2021) dataset. \n", + "\n", + "```\n", + "original filename: SSB07200001\\#noise-sound-bible-0038\\#7.86_6.16_3.00_3.14_4.84_134.5285_191.7899_0.4735\\#15217\\#25.16333303751458\\#0.2101221178590021.wav\n", + "```\n", + "\n", + "Note:\n", + "- You need to use the nightly torchaudio in order to use the MVDR and InverseSpectrogram modules.\n", + "\n", + "\n", + "Steps\n", + "\n", + "- Ideal Ratio Mask (IRM) is generated by dividing the clean/noise magnitude by the mixture magnitude.\n", + "- We test all three solutions (``ref_channel``, ``stv_evd``, ``stv_power``) of torchaudio's MVDR module.\n", + "- We test the single-channel and multi-channel masks for MVDR beamforming. The multi-channel mask is averaged along channel dimension when computing the covariance matrices of speech and noise, respectively." + ], + "metadata": { + "id": "L6R0MXe5Wr19" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "!pip install --pre torchaudio -f https://download.pytorch.org/whl/nightly/torch_nightly.html --force" + ], + "outputs": [], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "juO6PE9XLctD", + "outputId": "8777ba14-da99-4c18-d80f-b070ad9861af" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "import torch\n", + "import torchaudio\n", + "import IPython.display as ipd" + ], + "outputs": [], + "metadata": { + "id": "T4u4unhFMMBG" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Load audios of mixture, reverberated clean speech, and dry clean speech." 
+ ], + "metadata": { + "id": "bDILVXkeg2s3" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "!curl -LJO https://github.com/nateanl/torchaudio_mvdr_tutorial/raw/main/wavs/mix.wav\n", + "!curl -LJO https://github.com/nateanl/torchaudio_mvdr_tutorial/raw/main/wavs/reverb_clean.wav\n", + "!curl -LJO https://github.com/nateanl/torchaudio_mvdr_tutorial/raw/main/wavs/clean.wav" + ], + "outputs": [], + "metadata": { + "id": "2XIyMa_VKv0c", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "404f46a6-e70c-4f80-af8d-d356408a9f18" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "mix, sr = torchaudio.load('mix.wav')\n", + "reverb_clean, sr2 = torchaudio.load('reverb_clean.wav')\n", + "clean, sr3 = torchaudio.load('clean.wav')\n", + "assert sr == sr2\n", + "noise = mix - reverb_clean" + ], + "outputs": [], + "metadata": { + "id": "iErB6UhQPtD3" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Note: The MVDR Module requires ``torch.cdouble`` dtype for noisy STFT. We need to convert the dtype of the waveforms to ``torch.double``" + ], + "metadata": { + "id": "Aq-x_fo5VkwL" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "mix = mix.to(torch.double)\n", + "noise = noise.to(torch.double)\n", + "clean = clean.to(torch.double)\n", + "reverb_clean = reverb_clean.to(torch.double)" + ], + "outputs": [], + "metadata": { + "id": "5c66pHcQV0P9" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Initilize the Spectrogram and InverseSpectrogram modules" + ], + "metadata": { + "id": "05D26we0V4P-" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "stft = torchaudio.transforms.Spectrogram(n_fft=1024, hop_length=256, return_complex=True, power=None)\n", + "istft = torchaudio.transforms.InverseSpectrogram(n_fft=1024, hop_length=256)" + ], + "outputs": [], + "metadata": { + "id": "NcGhD7_TUKd1" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Compute the complex-valued STFT of mixture, clean speech, and noise" + ], + "metadata": { + "id": "-dlJcuSNUCgA" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "spec_mix = stft(mix)\n", + "spec_clean = stft(clean)\n", + "spec_reverb_clean = stft(reverb_clean)\n", + "spec_noise = stft(noise)" + ], + "outputs": [], + "metadata": { + "id": "w1vO7w1BUKt4" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Generate the Ideal Ratio Mask (IRM)\n", + "Note: we found using the mask directly peforms better than using the square root of it. This is slightly different from the definition of IRM." 
+ ], + "metadata": { + "id": "8SBchrDhURK1" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "def get_irms(spec_clean, spec_noise, spec_mix):\n", + " mag_mix = spec_mix.abs() ** 2\n", + " mag_clean = spec_clean.abs() ** 2\n", + " mag_noise = spec_noise.abs() ** 2\n", + " irm_speech = mag_clean / (mag_clean + mag_noise)\n", + " irm_noise = mag_noise / (mag_clean + mag_noise)\n", + "\n", + " return irm_speech, irm_noise" + ], + "outputs": [], + "metadata": { + "id": "2gB63BoWUmHZ" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Note: We use reverberant clean speech as the target here, you can also set it to dry clean speech" + ], + "metadata": { + "id": "reGMDyNCaE7L" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "irm_speech, irm_noise = get_irms(spec_reverb_clean, spec_noise, spec_mix)" + ], + "outputs": [], + "metadata": { + "id": "HSTCGy_5Uqzx" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Apply MVDR beamforming by using multi-channel masks" + ], + "metadata": { + "id": "1R5I_TmSUbS0" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "results_multi = {}\n", + "for solution in ['ref_channel', 'stv_evd', 'stv_power']:\n", + " mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution=solution, multi_mask=True)\n", + " stft_est = mvdr(spec_mix, irm_speech, irm_noise)\n", + " est = istft(stft_est, length=mix.shape[-1])\n", + " results_multi[solution] = est" + ], + "outputs": [], + "metadata": { + "id": "SiWFZgCbadz7" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Apply MVDR beamforming by using single-channel masks \n", + "(We use the 1st channel as an example. The channel selection may depend on the design of the microphone array)" + ], + "metadata": { + "id": "Ukez6_lcUfna" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "results_single = {}\n", + "for solution in ['ref_channel', 'stv_evd', 'stv_power']:\n", + " mvdr = torchaudio.transforms.MVDR(ref_channel=0, solution=solution, multi_mask=False)\n", + " stft_est = mvdr(spec_mix, irm_speech[0], irm_noise[0])\n", + " est = istft(stft_est, length=mix.shape[-1])\n", + " results_single[solution] = est" + ], + "outputs": [], + "metadata": { + "id": "kLeNKsk-VLm5" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Compute Si-SDR scores" + ], + "metadata": { + "id": "uJjJNdYiUnf0" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "def si_sdr(estimate, reference, epsilon=1e-8):\n", + " estimate = estimate - estimate.mean()\n", + " reference = reference - reference.mean()\n", + " reference_pow = reference.pow(2).mean(axis=1, keepdim=True)\n", + " mix_pow = (estimate * reference).mean(axis=1, keepdim=True)\n", + " scale = mix_pow / (reference_pow + epsilon)\n", + "\n", + " reference = scale * reference\n", + " error = estimate - reference\n", + "\n", + " reference_pow = reference.pow(2)\n", + " error_pow = error.pow(2)\n", + "\n", + " reference_pow = reference_pow.mean(axis=1)\n", + " error_pow = error_pow.mean(axis=1)\n", + "\n", + " sisdr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)\n", + " return sisdr.item()" + ], + "outputs": [], + "metadata": { + "id": "MgmAJcyiU-FU" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Single-channel mask results" + ], + "metadata": { + "id": "3TCJEwTOUxci" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "for solution in results_single:\n", + " 
print(solution+\": \", si_sdr(results_single[solution][None,...], reverb_clean[0:1]))" + ], + "outputs": [], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NrUXXj98VVY7", + "outputId": "bc113347-70e3-47a9-8479-8aeeeca80abf" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Multi-channel mask results" + ], + "metadata": { + "id": "-7AnjM-gU3c8" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "for solution in results_multi:\n", + " print(solution+\": \", si_sdr(results_multi[solution][None,...], reverb_clean[0:1]))" + ], + "outputs": [], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "S_VINTnlXobM", + "outputId": "234b5615-63e7-44d8-f816-a6cc05999e52" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Display the mixture audio" + ], + "metadata": { + "id": "_vOK8vgmU_UP" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "print(\"Mixture speech\")\n", + "ipd.Audio(mix[0], rate=16000)" + ], + "outputs": [], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 92 + }, + "id": "QaKauQIHYctE", + "outputId": "674c7f9b-62a3-4298-81ac-d3ab1ee43cd7" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Display the noise" + ], + "metadata": { + "id": "R-QGGm87VFQI" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "print(\"Noise\")\n", + "ipd.Audio(noise[0], rate=16000)" + ], + "outputs": [], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 92 + }, + "id": "l1WgzxIZYhlk", + "outputId": "7b100679-b4a0-47ff-b30b-9f4cb9dca3d1" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Display the clean speech" + ], + "metadata": { + "id": "P3kB-jzpVKKu" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "print(\"Clean speech\")\n", + "ipd.Audio(clean[0], rate=16000)" + ], + "outputs": [], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 92 + }, + "id": "pwAWvlRAVJkT", + "outputId": "5e173a1b-2ba8-4797-8f3a-e41cbf05ac2b" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Display the enhanced audios¶" + ], + "metadata": { + "id": "RIlyzL1wVTnr" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "print(\"multi-channel mask, ref_channel solution\")\n", + "ipd.Audio(results_multi['ref_channel'], rate=16000)" + ], + "outputs": [], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 92 + }, + "id": "M3YQsledVIQ5", + "outputId": "43d9ee34-6933-401b-baf9-e4cdb7d79b63" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "print(\"multi-channel mask, stv_evd solution\")\n", + "ipd.Audio(results_multi['stv_evd'], rate=16000)" + ], + "outputs": [], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 92 + }, + "id": "UhYOHLvCVWBN", + "outputId": "761468ec-ebf9-4b31-ad71-bfa2e15fed37" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "print(\"multi-channel mask, stv_power solution\")\n", + "ipd.Audio(results_multi['stv_power'], rate=16000)" + ], + "outputs": [], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 92 + }, + "id": "9dv8VDtCVXzd", + "outputId": "1ae61ea3-d3c4-479f-faad-7439f942aac1" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "print(\"single-channel mask, ref_channel solution\")\n", + 
"ipd.Audio(results_single['ref_channel'], rate=16000)" + ], + "outputs": [], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 92 + }, + "id": "jCFUN890VZdh", + "outputId": "c0d2a928-5dd0-4584-b277-7838ac4a9e6b" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "print(\"single-channel mask, stv_evd solution\")\n", + "ipd.Audio(results_single['stv_evd'], rate=16000)" + ], + "outputs": [], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 92 + }, + "id": "hzlzagsKVbAv", + "outputId": "96af9e37-82ca-4544-9c08-421fe222bde4" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "print(\"single-channel mask, stv_power solution\")\n", + "ipd.Audio(results_single['stv_power'], rate=16000)" + ], + "outputs": [], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 92 + }, + "id": "A4igQpTnVctG", + "outputId": "cf968089-9274-4c1c-a1a5-32b220de0bf9" + } + } + ] +} \ No newline at end of file diff --git a/examples/interactive_asr/README.md b/examples/interactive_asr/README.md new file mode 100644 index 0000000000000000000000000000000000000000..39a5c53b75fab6f78e9c23571138ebc356da2d8d --- /dev/null +++ b/examples/interactive_asr/README.md @@ -0,0 +1,63 @@ +# asr-demo + +To run this demo, you need the following libraries +- [python3](https://www.python.org/download/releases/3.0/) +- [pyaudio](https://people.csail.mit.edu/hubert/pyaudio/) +- [torchaudio](https://github.com/pytorch/audio/tree/master/torchaudio) +- [pytorch](https://pytorch.org/) +- [librosa](https://librosa.github.io/librosa/) +- [fairseq](https://github.com/pytorch/fairseq) (clone the github repository) +and the following models +- [dictionary](https://download.pytorch.org/models/audio/dict.txt) +- [sentence piece model](https://download.pytorch.org/models/audio/spm.model) +- [model](https://download.pytorch.org/models/audio/checkpoint_avg_60_80.pt) + +## Installation + +We recommend that you use [conda](https://docs.conda.io/en/latest/miniconda.html) to install the dependencies when available. +```bash +# Assume that all commands are from the examples folder +cd examples + +# Install dependencies +conda install -c pytorch torchaudio +conda install -c conda-forge librosa +conda install pyaudio +pip install sentencepiece + +# Install fairseq from source +git clone https://github.com/pytorch/fairseq interactive_asr/fairseq +pushd interactive_asr/fairseq +export CFLAGS='-stdlib=libc++' # For Mac only +pip install --editable . 
+popd + +# Install dictionary, sentence piece model, and model +wget -O interactive_asr/data/dict.txt https://download.pytorch.org/models/audio/dict.txt +wget -O interactive_asr/data/spm.model https://download.pytorch.org/models/audio/spm.model +wget -O interactive_asr/data/model.pt https://download.pytorch.org/models/audio/checkpoint_avg_60_80.pt +``` + +## Run +On a file +```bash +INPUT_FILE=interactive_asr/data/sample.wav +python -m interactive_asr.asr interactive_asr/data --input_file $INPUT_FILE --max-tokens 10000000 --nbest 1 \ + --path interactive_asr/data/model.pt --beam 40 --task speech_recognition \ + --user-dir interactive_asr/fairseq/examples/speech_recognition +``` + +As a microphone +```bash +python -m interactive_asr.asr interactive_asr/data --max-tokens 10000000 --nbest 1 \ + --path interactive_asr/data/model.pt --beam 40 --task speech_recognition \ + --user-dir interactive_asr/fairseq/examples/speech_recognition +``` +To run the testcase associated with this example +```bash +ASR_MODEL_PATH=interactive_asr/data/model.pt \ +ASR_INPUT_FILE=interactive_asr/data/sample.wav \ +ASR_DATA_PATH=interactive_asr/data \ +ASR_USER_DIR=interactive_asr/fairseq/examples/speech_recognition \ +python -m unittest test/test_interactive_asr.py +``` diff --git a/examples/interactive_asr/__init__.py b/examples/interactive_asr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57e1c91022b1dc8774cdc47a8ff6a34707c00920 --- /dev/null +++ b/examples/interactive_asr/__init__.py @@ -0,0 +1,3 @@ +from . import utils, vad + +__all__ = ['utils', 'vad'] diff --git a/examples/interactive_asr/asr.py b/examples/interactive_asr/asr.py new file mode 100644 index 0000000000000000000000000000000000000000..8e2510f347b5d0d09900f9d8dc97b526c867a91d --- /dev/null +++ b/examples/interactive_asr/asr.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. +""" +Run inference for pre-processed data with a trained model. 
+""" + +import datetime as dt +import logging + +from fairseq import options + +from interactive_asr.utils import add_asr_eval_argument, setup_asr, get_microphone_transcription, transcribe_file + + +def main(args): + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + task, generator, models, sp, tgt_dict = setup_asr(args, logger) + + print("READY!") + if args.input_file: + transcription_time, transcription = transcribe_file(args, task, generator, models, sp, tgt_dict) + print("transcription:", transcription) + print("transcription_time:", transcription_time) + else: + for transcription in get_microphone_transcription(args, task, generator, models, sp, tgt_dict): + print( + "{}: {}".format( + dt.datetime.now().strftime("%H:%M:%S"), transcription[0][0] + ) + ) + + +def cli_main(): + parser = options.get_generation_parser() + parser = add_asr_eval_argument(parser) + args = options.parse_args_and_arch(parser) + main(args) + + +if __name__ == "__main__": + cli_main() diff --git a/examples/interactive_asr/data/sample.wav b/examples/interactive_asr/data/sample.wav new file mode 100644 index 0000000000000000000000000000000000000000..d52e38dbb09005e008ce437a5c6c683e040968d4 Binary files /dev/null and b/examples/interactive_asr/data/sample.wav differ diff --git a/examples/interactive_asr/utils.py b/examples/interactive_asr/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3aae6c6864b405648104214f839c1a7eb979c3a7 --- /dev/null +++ b/examples/interactive_asr/utils.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. +import os +import sys +import time + +import torch +import torchaudio +import sentencepiece as spm + +from fairseq import tasks +from fairseq.utils import load_ensemble_for_inference, import_user_module + +from interactive_asr.vad import get_microphone_chunks + + +def add_asr_eval_argument(parser): + parser.add_argument("--input_file", help="input file") + parser.add_argument("--ctc", action="store_true", help="decode a ctc model") + parser.add_argument("--rnnt", default=False, help="decode a rnnt model") + parser.add_argument("--kspmodel", default=None, help="sentence piece model") + parser.add_argument( + "--wfstlm", default=None, help="wfstlm on dictonary output units" + ) + parser.add_argument( + "--rnnt_decoding_type", + default="greedy", + help="wfstlm on dictonary output units", + ) + parser.add_argument( + "--lm_weight", + default=0.2, + help="weight for wfstlm while interpolating with neural score", + ) + parser.add_argument( + "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level" + ) + return parser + + +def check_args(args): + assert args.path is not None, "--path required for generation!" 
+ assert ( + not args.sampling or args.nbest == args.beam + ), "--sampling requires --nbest to be equal to --beam" + assert ( + args.replace_unk is None or args.raw_text + ), "--replace-unk requires a raw text dataset (--raw-text)" + + +def process_predictions(args, hypos, sp, tgt_dict): + res = [] + device = torch.device("cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu") + for hypo in hypos[: min(len(hypos), args.nbest)]: + hyp_pieces = tgt_dict.string(hypo["tokens"].int().to(device)) + hyp_words = sp.DecodePieces(hyp_pieces.split()) + res.append(hyp_words) + return res + + +def optimize_models(args, use_cuda, models): + """Optimize ensemble for generation + """ + for model in models: + model.make_generation_fast_( + beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, + need_attn=args.print_alignment, + ) + if args.fp16: + model.half() + if use_cuda: + model.cuda() + + +def calc_mean_invstddev(feature): + if len(feature.shape) != 2: + raise ValueError("We expect the input feature to be 2-D tensor") + mean = torch.mean(feature, dim=0) + var = torch.var(feature, dim=0) + # avoid division by ~zero + if (var < sys.float_info.epsilon).any(): + return mean, 1.0 / (torch.sqrt(var) + sys.float_info.epsilon) + return mean, 1.0 / torch.sqrt(var) + + +def calcMN(features): + mean, invstddev = calc_mean_invstddev(features) + res = (features - mean) * invstddev + return res + + +def transcribe(waveform, args, task, generator, models, sp, tgt_dict): + num_features = 80 + output = torchaudio.compliance.kaldi.fbank(waveform, num_mel_bins=num_features) + device = torch.device("cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu") + output_cmvn = calcMN(output.to(device).detach()) + + # size (m, n) + source = output_cmvn + frames_lengths = torch.LongTensor([source.size(0)]) + + # size (1, m, n). In general, if source is (x, m, n), then hypos is (x, ...) 
+ source.unsqueeze_(0) + sample = {"net_input": {"src_tokens": source, "src_lengths": frames_lengths}} + + hypos = task.inference_step(generator, models, sample) + + assert len(hypos) == 1 + transcription = [] + for i in range(len(hypos)): + # Process top predictions + hyp_words = process_predictions(args, hypos[i], sp, tgt_dict) + transcription.append(hyp_words) + + return transcription + + +def setup_asr(args, logger): + check_args(args) + import_user_module(args) + + if args.max_tokens is None and args.batch_size is None: + args.max_tokens = 30000 + logger.info(args) + + use_cuda = torch.cuda.is_available() and not args.cpu + + # Load dataset splits + task = tasks.setup_task(args) + + # Set dictionary + tgt_dict = task.target_dictionary + + if args.ctc or args.rnnt: + tgt_dict.add_symbol("") + if args.ctc: + logger.info("| decoding a ctc model") + if args.rnnt: + logger.info("| decoding a rnnt model") + + # Load ensemble + logger.info("| loading model(s) from {}".format(args.path)) + models, _model_args = load_ensemble_for_inference( + args.path.split(":"), + task, + model_arg_overrides=eval(args.model_overrides), # noqa + ) + optimize_models(args, use_cuda, models) + + # Initialize generator + generator = task.build_generator(models, args) + + sp = spm.SentencePieceProcessor() + sp.Load(os.path.join(args.data, "spm.model")) + return task, generator, models, sp, tgt_dict + + +def transcribe_file(args, task, generator, models, sp, tgt_dict): + path = args.input_file + if not os.path.exists(path): + raise FileNotFoundError("Audio file not found: {}".format(path)) + waveform, sample_rate = torchaudio.load_wav(path) + waveform = waveform.mean(0, True) + waveform = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=16000 + )(waveform) + + start = time.time() + transcription = transcribe( + waveform, args, task, generator, models, sp, tgt_dict + ) + transcription_time = time.time() - start + return transcription_time, transcription + + +def get_microphone_transcription(args, task, generator, models, sp, tgt_dict): + for (waveform, sample_rate) in get_microphone_chunks(): + waveform = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=16000 + )(waveform.reshape(1, -1)) + transcription = transcribe( + waveform, args, task, generator, models, sp, tgt_dict + ) + yield transcription diff --git a/examples/interactive_asr/vad.py b/examples/interactive_asr/vad.py new file mode 100644 index 0000000000000000000000000000000000000000..bc942032b009e366decfb84753abc605a2bdaa45 --- /dev/null +++ b/examples/interactive_asr/vad.py @@ -0,0 +1,269 @@ +#!/usr/bin/env python3 +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. +""" +Following `a simple but efficient real-time voice activity detection algorithm +`__. + +There are three criteria to decide if a frame contains speech: energy, most +dominant frequency, and spectral flatness. If any two of those are higher than +a minimum plus a threshold, then the frame contains speech. In the offline +case, the list of frames is postprocessed to remove too short silence and +speech sequences. In the online case here, inertia is added before switching +from speech to silence or vice versa. 
+""" + +from collections import deque + +import numpy as np +import torch +import queue + +import librosa +import pyaudio +import torchaudio + + +def compute_spectral_flatness(frame, epsilon=0.01): + # epsilon protects against log(0) + geometric_mean = torch.exp((frame + epsilon).log().mean(-1)) - epsilon + arithmetic_mean = frame.mean(-1) + return -10 * torch.log10(epsilon + geometric_mean / arithmetic_mean) + + +class VoiceActivityDetection: + def __init__( + self, + num_init_frames=30, + ignore_silent_count=4, + ignore_speech_count=1, + energy_prim_thresh=60, + frequency_prim_thresh=10, + spectral_flatness_prim_thresh=3, + verbose=False, + ): + + self.num_init_frames = num_init_frames + self.ignore_silent_count = ignore_silent_count + self.ignore_speech_count = ignore_speech_count + + self.energy_prim_thresh = energy_prim_thresh + self.frequency_prim_thresh = frequency_prim_thresh + self.spectral_flatness_prim_thresh = spectral_flatness_prim_thresh + + self.verbose = verbose + + self.speech_mark = True + self.silence_mark = False + + self.silent_count = 0 + self.speech_count = 0 + self.n = 0 + + if self.verbose: + self.energy_list = [] + self.frequency_list = [] + self.spectral_flatness_list = [] + + def iter(self, frame): + + frame_fft = torch.rfft(frame, 1) + amplitudes = torchaudio.functional.complex_norm(frame_fft) + + # Compute frame energy + energy = frame.pow(2).sum(-1) + + # Most dominant frequency component + frequency = amplitudes.argmax() + + # Spectral flatness measure + spectral_flatness = compute_spectral_flatness(amplitudes) + + if self.verbose: + self.energy_list.append(energy) + self.frequency_list.append(frequency) + self.spectral_flatness_list.append(spectral_flatness) + + if self.n == 0: + self.min_energy = energy + self.min_frequency = frequency + self.min_spectral_flatness = spectral_flatness + elif self.n < self.num_init_frames: + self.min_energy = min(energy, self.min_energy) + self.min_frequency = min(frequency, self.min_frequency) + self.min_spectral_flatness = min( + spectral_flatness, self.min_spectral_flatness + ) + + self.n += 1 + + # Add 1. to avoid log(0) + thresh_energy = self.energy_prim_thresh * torch.log(1.0 + self.min_energy) + thresh_frequency = self.frequency_prim_thresh + thresh_spectral_flatness = self.spectral_flatness_prim_thresh + + # Check all three conditions + + counter = 0 + if energy - self.min_energy >= thresh_energy: + counter += 1 + if frequency - self.min_frequency >= thresh_frequency: + counter += 1 + if spectral_flatness - self.min_spectral_flatness >= thresh_spectral_flatness: + counter += 1 + + # Detection + if counter > 1: + # Speech detected + self.speech_count += 1 + # Inertia against switching + if ( + self.n >= self.num_init_frames + and self.speech_count <= self.ignore_speech_count + ): + # Too soon to change + return self.silence_mark + else: + self.silent_count = 0 + return self.speech_mark + else: + # Silence detected + self.min_energy = ((self.silent_count * self.min_energy) + energy) / ( + self.silent_count + 1 + ) + self.silent_count += 1 + # Inertia against switching + if ( + self.n >= self.num_init_frames + and self.silent_count <= self.ignore_silent_count + ): + # Too soon to change + return self.speech_mark + else: + self.speech_count = 0 + return self.silence_mark + + +class MicrophoneStream: + """Opens a recording stream as a generator yielding the audio chunks.""" + + def __init__(self, device=None, rate=22050, chunk=2205): + """ + The 22050 is the librosa default, which is what our models were + trained on. 
The ratio of [chunk / rate] is the amount of time between + audio samples - for example, with these defaults, + an audio fragment will be processed every tenth of a second. + """ + self._rate = rate + self._chunk = chunk + self._device = device + + # Create a thread-safe buffer of audio data + self._buff = queue.Queue() + self.closed = True + + def __enter__(self): + self._audio_interface = pyaudio.PyAudio() + self._audio_stream = self._audio_interface.open( + # format=pyaudio.paInt16, + format=pyaudio.paFloat32, + # The API currently only supports 1-channel (mono) audio + # https://goo.gl/z757pE + channels=1, + rate=self._rate, + input=True, + frames_per_buffer=self._chunk, + input_device_index=self._device, + # Run the audio stream asynchronously to fill the buffer object. + # This is necessary so that the input device's buffer doesn't + # overflow while the calling thread makes network requests, etc. + stream_callback=self._fill_buffer, + ) + + self.closed = False + + return self + + def __exit__(self, type, value, traceback): + self._audio_stream.stop_stream() + self._audio_stream.close() + self.closed = True + # Signal the generator to terminate so that the client's + # streaming_recognize method will not block the process termination. + self._buff.put(None) + self._audio_interface.terminate() + + def _fill_buffer(self, in_data, frame_count, time_info, status_flags): + """Continuously collect data from the audio stream, into the buffer.""" + self._buff.put(in_data) + return None, pyaudio.paContinue + + def generator(self): + while not self.closed: + # Use a blocking get() to ensure there's at least one chunk of + # data, and stop iteration if the chunk is None, indicating the + # end of the audio stream. + chunk = self._buff.get() + if chunk is None: + return + data = [chunk] + + # Now consume whatever other data's still buffered. + while True: + try: + chunk = self._buff.get(block=False) + if chunk is None: + return + data.append(chunk) + except queue.Empty: + break + + ans = np.fromstring(b"".join(data), dtype=np.float32) + # yield uniform-sized chunks + ans = np.split(ans, np.shape(ans)[0] / self._chunk) + # Resample the audio to 22050, librosa default + for chunk in ans: + yield librosa.core.resample(chunk, self._rate, 22050) + + +def get_microphone_chunks( + min_to_cumulate=5, # 0.5 seconds + max_to_cumulate=100, # 10 seconds + precumulate=5, + max_to_visualize=100, +): + + vad = VoiceActivityDetection() + + cumulated = [] + precumulated = deque(maxlen=precumulate) + + with MicrophoneStream() as stream: + audio_generator = stream.generator() + chunk_length = stream._chunk + waveform = torch.zeros(max_to_visualize * chunk_length) + + for chunk in audio_generator: + # Is speech? 
+ + chunk = torch.tensor(chunk) + is_speech = vad.iter(chunk) + + # Cumulate speech + + if is_speech or cumulated: + cumulated.append(chunk) + else: + precumulated.append(chunk) + + if (not is_speech and len(cumulated) >= min_to_cumulate) or ( + len(cumulated) > max_to_cumulate + ): + waveform = torch.cat(list(precumulated) + cumulated, -1) + yield (waveform * stream._rate, stream._rate) + cumulated = [] + precumulated = deque(maxlen=precumulate) diff --git a/examples/libtorchaudio/.gitignore b/examples/libtorchaudio/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..d2c4084e758ea68d5c659951f6f15e8883e19c83 --- /dev/null +++ b/examples/libtorchaudio/.gitignore @@ -0,0 +1,4 @@ +build +data/output.wav +*.zip +output diff --git a/examples/libtorchaudio/CMakeLists.txt b/examples/libtorchaudio/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..b4cf58b375af2a967ad9da5b97f1f11615cb3e4f --- /dev/null +++ b/examples/libtorchaudio/CMakeLists.txt @@ -0,0 +1,17 @@ +cmake_minimum_required(VERSION 3.5) + +project(libtorchaudio-cpp-example) + +SET(BUILD_SOX ON CACHE BOOL "Build libsox into libtorchaudio") + +SET(BUILD_KALDI OFF CACHE BOOL "Build Kaldi into libtorchaudio") +SET(BUILD_RNNT ON CACHE BOOL "Build RNN transducer into libtorchaudio") +SET(BUILD_TORCHAUDIO_PYTHON_EXTENSION OFF CACHE BOOL "Build Python binding") + +find_package(Torch REQUIRED) +message("libtorchaudio CMakeLists: ${TORCH_CXX_FLAGS}") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") + +add_subdirectory(../.. libtorchaudio) +add_subdirectory(augmentation) +add_subdirectory(speech_recognition) diff --git a/examples/libtorchaudio/README.md b/examples/libtorchaudio/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cfed769cd4959447aa3d93f2c614b21e9061c918 --- /dev/null +++ b/examples/libtorchaudio/README.md @@ -0,0 +1,30 @@ +# Libtorchaudio Examples + +* [Augmentation](./augmentation) +* [Speech Recognition with wav2vec2.0](./speech_recognition) + +## Build + +The example applications in this directory depend on `libtorch` and `libtorchaudio`. +If you have a working `PyTorch`, you already have `libtorch`. +Please refer to [this tutorial](https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html) for the use of `libtorch` and TorchScript. + +`libtorchaudio` is the library of torchaudio's C++ components without Python component. +It is currently not distributed, and it will be built alongside with the applications. + +The following commands will build `libtorchaudio` and applications. + +```bash +git submodule update +mkdir build +cd build +cmake -GNinja \ + -DCMAKE_PREFIX_PATH="$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')" \ + -DBUILD_SOX=ON \ + -DBUILD_KALDI=OFF \ + -DBUILD_RNNT=ON \ + .. +cmake --build . +``` + +For the usages of each application, refer to the corresponding application directory. 
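+
+If CMake fails to find Torch, the usual cause is that `CMAKE_PREFIX_PATH` does not point at the CMake config shipped with the `torch` Python package. Below is a minimal check, assuming a working `torch` installation (the directory layout shown is the common one, but it may differ across distributions).
+
+```python
+# Print the prefix that -DCMAKE_PREFIX_PATH should point at and verify that
+# the libtorch CMake config is present under it.
+import os
+import torch
+
+prefix = torch.utils.cmake_prefix_path
+print("CMake prefix:", prefix)
+print("TorchConfig.cmake found:",
+      os.path.exists(os.path.join(prefix, "Torch", "TorchConfig.cmake")))
+```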
diff --git a/examples/libtorchaudio/augmentation/CMakeLists.txt b/examples/libtorchaudio/augmentation/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..e9bfece93ade84632c1e4bbdec5d71986f2fae10 --- /dev/null +++ b/examples/libtorchaudio/augmentation/CMakeLists.txt @@ -0,0 +1,3 @@ +add_executable(augment main.cpp) +target_link_libraries(augment "${TORCH_LIBRARIES}" "${TORCHAUDIO_LIBRARY}") +set_property(TARGET augment PROPERTY CXX_STANDARD 14) diff --git a/examples/libtorchaudio/augmentation/README.md b/examples/libtorchaudio/augmentation/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b78248af6bcfbaed87f8b7126a7d41fc151be0d0 --- /dev/null +++ b/examples/libtorchaudio/augmentation/README.md @@ -0,0 +1,36 @@ +# Augmentation + +This example demonstrates how you can use torchaudio's I/O features and augmentations in C++ application. + +**NOTE** +This example uses `"sox_io"` backend, thus does not work on Windows. + +## Steps +### 1. Create augmentation pipeline TorchScript file. + +First, we implement our data process pipeline as a regular Python, and save it as a TorchScript object. +We will load and execute it in our C++ application. The C++ code is found in [`main.cpp`](./main.cpp). + +```python +python create_jittable_pipeline.py \ + --rir-path "../data/rir.wav" \ + --output-path "./pipeline.zip" +``` + +### 2. Build the application + +Please refer to [the top level README.md](../README.md) + +### 3. Run the application + +Now we run the C++ application `augment`, with the TorchScript object we created in Step.1 and an input audio file. + +In [the top level directory](../) + +```bash +input_audio_file="./data/input.wav" +./build/augmentation/augment ./augmentation/pipeline.zip "${input_audio_file}" "output.wav" +``` + +When you give a clean speech file, the output audio sounds like it's a phone conversation. + diff --git a/examples/libtorchaudio/augmentation/create_jittable_pipeline.py b/examples/libtorchaudio/augmentation/create_jittable_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..ac33a837da8e2153c3a620c3c25a25a039be5a2f --- /dev/null +++ b/examples/libtorchaudio/augmentation/create_jittable_pipeline.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +""" +Create a data preprocess pipeline that can be run with libtorchaudio +""" +import os +import argparse + +import torch +import torchaudio + + +class Pipeline(torch.nn.Module): + """Example audio process pipeline. + + This example load waveform from a file then apply effects and save it to a file. + """ + def __init__(self, rir_path: str): + super().__init__() + rir, sample_rate = torchaudio.load(rir_path) + self.register_buffer('rir', rir) + self.rir_sample_rate: int = sample_rate + + def forward(self, input_path: str, output_path: str): + torchaudio.sox_effects.init_sox_effects() + + # 1. load audio + waveform, sample_rate = torchaudio.load(input_path) + + # 2. Add background noise + alpha = 0.01 + waveform = alpha * torch.randn_like(waveform) + (1 - alpha) * waveform + + # 3. Reample the RIR filter to much the audio sample rate + rir, _ = torchaudio.sox_effects.apply_effects_tensor( + self.rir, self.rir_sample_rate, effects=[["rate", str(sample_rate)]]) + rir = rir / torch.norm(rir, p=2) + rir = torch.flip(rir, [1]) + + # 4. 
Apply RIR filter + waveform = torch.nn.functional.pad(waveform, (rir.shape[1] - 1, 0)) + waveform = torch.nn.functional.conv1d(waveform[None, ...], rir[None, ...])[0] + + # Save + torchaudio.save(output_path, waveform, sample_rate) + + +def _create_jit_pipeline(rir_path, output_path): + module = torch.jit.script(Pipeline(rir_path)) + print("*" * 40) + print("* Pipeline code") + print("*" * 40) + print() + print(module.code) + print("*" * 40) + module.save(output_path) + + +def _get_path(*paths): + return os.path.join(os.path.dirname(__file__), *paths) + + +def _parse_args(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--rir-path", + default=_get_path("..", "data", "rir.wav"), + help="Audio dara for room impulse response." + ) + parser.add_argument( + "--output-path", + default=_get_path("pipeline.zip"), + help="Output JIT file." + ) + return parser.parse_args() + + +def _main(): + args = _parse_args() + _create_jit_pipeline(args.rir_path, args.output_path) + + +if __name__ == '__main__': + _main() diff --git a/examples/libtorchaudio/augmentation/main.cpp b/examples/libtorchaudio/augmentation/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fe45d28aa9491952e784756a0cdc240ad171e91b --- /dev/null +++ b/examples/libtorchaudio/augmentation/main.cpp @@ -0,0 +1,21 @@ +#include + +int main(int argc, char* argv[]) { + if (argc !=4) { + std::cerr << "Usage: " << argv[0] << " " << std::endl; + return -1; + } + + torch::jit::script::Module module; + std::cout << "Loading module from: " << argv[1] << std::endl; + try { + module = torch::jit::load(argv[1]); + } catch (const c10::Error &error) { + std::cerr << "Failed to load the module:" << error.what() << std::endl; + return -1; + } + + std::cout << "Performing the process ..." << std::endl; + module.forward({c10::IValue(argv[2]), c10::IValue(argv[3])}); + std::cout << "Done." << std::endl; +} diff --git a/examples/libtorchaudio/build.sh b/examples/libtorchaudio/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..ac51caf34c85b8aa3d28c23369e45917d08796ac --- /dev/null +++ b/examples/libtorchaudio/build.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash + +set -eux + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +build_dir="${this_dir}/build" + +mkdir -p "${build_dir}" +cd "${build_dir}" + +git submodule update +cmake -GNinja \ + -DCMAKE_PREFIX_PATH="$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')" \ + -DBUILD_SOX=ON \ + -DBUILD_KALDI=OFF \ + .. +cmake --build . diff --git a/examples/libtorchaudio/data/README.md b/examples/libtorchaudio/data/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6e2ff19097846020b08568cac86b48e57b8ca3b2 --- /dev/null +++ b/examples/libtorchaudio/data/README.md @@ -0,0 +1,5 @@ +The files in this directory are originated from [VOiCES](https://iqtlabs.github.io/voices/) dataset, which is licensed under Creative Commos BY 4.0. They are modified to fit into the tutorial. 
+ +* `input.wav`: `VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav` + +* `rir.wav`: `VOiCES_devkit/distant-16k/room-response/rm1/impulse/Lab41-SRI-VOiCES-rm1-impulse-mc01-stu-clo.wav` diff --git a/examples/libtorchaudio/data/input.wav b/examples/libtorchaudio/data/input.wav new file mode 100644 index 0000000000000000000000000000000000000000..004a33532ea2547c10c0074b967733ba91edd9f8 Binary files /dev/null and b/examples/libtorchaudio/data/input.wav differ diff --git a/examples/libtorchaudio/data/rir.wav b/examples/libtorchaudio/data/rir.wav new file mode 100644 index 0000000000000000000000000000000000000000..c9e6d836818ff5ea8cfaaadeab69316561b32465 Binary files /dev/null and b/examples/libtorchaudio/data/rir.wav differ diff --git a/examples/libtorchaudio/speech_recognition/CMakeLists.txt b/examples/libtorchaudio/speech_recognition/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9da6ab2914977f4e3f36b1d05ca8bab435be7a3d --- /dev/null +++ b/examples/libtorchaudio/speech_recognition/CMakeLists.txt @@ -0,0 +1,6 @@ +add_executable(transcribe transcribe.cpp) +add_executable(transcribe_list transcribe_list.cpp) +target_link_libraries(transcribe "${TORCH_LIBRARIES}" "${TORCHAUDIO_LIBRARY}") +target_link_libraries(transcribe_list "${TORCH_LIBRARIES}" "${TORCHAUDIO_LIBRARY}") +set_property(TARGET transcribe PROPERTY CXX_STANDARD 14) +set_property(TARGET transcribe_list PROPERTY CXX_STANDARD 14) diff --git a/examples/libtorchaudio/speech_recognition/README.md b/examples/libtorchaudio/speech_recognition/README.md new file mode 100644 index 0000000000000000000000000000000000000000..124e754ef44f85eb0d82a860a887844543c1de85 --- /dev/null +++ b/examples/libtorchaudio/speech_recognition/README.md @@ -0,0 +1,187 @@ +# Speech Recognition with wav2vec2.0 + +This example demonstarates how you can use torchaudio's I/O features and models to run speech recognition in C++ application. + +**NOTE** +This example uses `"sox_io"` backend for loading audio, which does not work on Windows. To make it work on +Windows, you need to replace the part of loading audio and converting it to Tensor object. + +## 1. Create a transcription pipeline TorchScript file + +We will create a TorchScript that performs the following processes; + +1. Load audio from a file. +1. Pass audio to encoder which produces the sequence of probability distribution on labels. +1. Pass the encoder output to decoder which generates transcripts. + +For building decoder, we borrow the pre-trained weights published by `fairseq` and/or Hugging Face Transformers, then convert it `torchaudio`'s format, which supports TorchScript. + +### 1.1. From `fairseq` + +For `fairseq` models, you can download pre-trained weights +You can download a model from [`fairseq` repository](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec). Here, we will use `Base / 960h` model. You also need to download [the letter dictionary file](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec#evaluating-a-ctc-model). + +For the decoder part, we use [simple_ctc](https://github.com/mthrok/ctcdecode), which also supports TorchScript. + +```bash +mkdir -p pipeline-fairseq +python build_pipeline_from_fairseq.py \ + --model-file "wav2vec_small_960.pt" \ + --dict-dir \ + --output-path "./pipeline-fairseq/" +``` + +The above command should create the following TorchScript object files in the output directory. 
+ +``` +decoder.zip encoder.zip loader.zip +``` + +* `loader.zip` loads audio file and generate waveform Tensor. +* `encoder.zip` receives waveform Tensor and generates the sequence of probability distribution over the label. +* `decoder.zip` receives the probability distribution over the label and generates a transcript. + +### 1.2. From Hugging Face Transformers + + +[Hugging Face Transformers](https://huggingface.co/transformers/index.html) and [Hugging Face Model Hub](https://huggingface.co/models) provides `wav2vec2.0` models fine-tuned on variety of datasets and languages. + +We can also import the model published on Hugging Face Hub and run it in our C++ application. +In the following example, we will try the Geremeny model, ([facebook/wav2vec2-large-xlsr-53-german](https://huggingface.co/facebook/wav2vec2-large-xlsr-53-german/tree/main)) on [VoxForge Germany dataset](http://www.voxforge.org/de/downloads). + +```bash +mkdir -p pipeline-hf +python build_pipeline_from_huggingface_transformers.py \ + --model facebook/wav2vec2-large-xlsr-53-german \ + --output-path ./pipeline-hf/ +``` + +The resulting TorchScript object files should be same as the `fairseq` example. + +## 2. Build the application + +Please refer to [the top level README.md](../README.md) + +## 3. Run the application + +Now we run the C++ application [`transcribe`](./transcribe.cpp), with the TorchScript object we created in Step.1.1. and an input audio file. + +```bash +../build/speech_recognition/transcribe ./pipeline-fairseq ../data/input.wav +``` + +This will output something like the following. + +``` +Loading module from: ./pipeline/loader.zip +Loading module from: ./pipeline/encoder.zip +Loading module from: ./pipeline/decoder.zip +Loading the audio +Running inference +Generating the transcription +I HAD THAT CURIOSITY BESIDE ME AT THIS MOMENT +Done. +``` + +## 4. Evaluate the pipeline on Librispeech dataset + +Let's evaluate this word error rate (WER) of this application using [Librispeech dataset](https://www.openslr.org/12). + +### 4.1. Create a list of audio paths + +For the sake of simplifying our C++ code, we will first parse the Librispeech dataset to get the list of audio path + +```bash +python parse_librispeech.py /LibriSpeech/test-clean ./flist.txt +``` + +The list should look like the following; + +```bash +head flist.txt + +1089-134691-0000 /LibriSpeech/test-clean/1089/134691/1089-134691-0000.flac HE COULD WAIT NO LONGER +``` + +### 4.2. Run the transcription + +[`transcribe_list`](./transcribe_list.cpp) processes the input flist list and feed the audio path one by one to the pipeline, then generate reference file and hypothesis file. + +```bash +../build/speech_recognition/transcribe_list ./pipeline-fairseq ./flist.txt +``` + +### 4.3. Score WER + +You need `sclite` for this step. You can download the code from [SCTK repository](https://github.com/usnistgov/SCTK). + +```bash +# in the output directory +sclite -r ref.trn -h hyp.trn -i wsj -o pralign -o sum +``` + +WER can be found in the resulting `hyp.trn.sys`. Check out the column that starts with `Sum/Avg` the first column of the third block is `100 - WER`. + +In our test, we got the following results. + +| model | Fine Tune | test-clean | test-other | +|:-----------------------------------------:|----------:|:----------:|:----------:| +| Base
`wav2vec_small_960` | 960h | 3.1 | 7.7 | +| Large
`wav2vec_big_960` | 960h | 2.6 | 5.9 | +| Large (LV-60)
`wav2vec2_vox_960h_new` | 960h | 2.9 | 6.2 | +| Large (LV-60) + Self Training
`wav2vec_vox_960h_pl` | 960h | 1.9 | 4.5 | + + +You can also check `hyp.trn.pra` file to see what errors were made. + +``` +id: (3528-168669-0005) +Scores: (#C #S #D #I) 7 1 0 0 +REF: there is a stone to be RAISED heavy +HYP: there is a stone to be RACED heavy +Eval: S +``` + +## 5. Evaluate the pipeline on VoxForge dataset + +Now we use the pipeline we created in step 1.2. This time with German language dataset from VoxForge. + +### 5.1. Create a list of audio paths + +Download an archive from http://www.repository.voxforge1.org/downloads/de/Trunk/Audio/Main/16kHz_16bit/, and extract it to your local file system, then run the following to generate the file list. + +```bash +python parse_voxforge.py > ./flist-de.txt +``` + +The list should look like + +```bash +head flist-de.txt +de5-001 /datasets/voxforge/de/guenter-20140214-afn/wav/de5-001.wav ES SOLL ETWA FÜNFZIGTAUSEND VERSCHIEDENE SORTEN GEBEN +``` + +### 5.2. Run the application and score WER + +This process is same as the Librispeech example. We just use the pipeline with the Germany model and file list of Germany dataset. Refer to the corresponding ssection in Librispeech evaluation.. + +```bash +../build/speech_recognition/transcribe_list ./pipeline-hf ./flist-de.txt +``` + +Then + +```bash +# in the output directory +sclite -r ref.trn -h hyp.trn -i wsj -o pralign -o sum +``` + +You can find the detail of evalauation result in PRA. + +``` +id: (guenter-20140214-afn/mfc/de5-012) +Scores: (#C #S #D #I) 4 1 1 0 +REF: die ausgaben kÖnnen gigantisch STEIGE N +HYP: die ausgaben kÖnnen gigantisch ****** STEIGEN +Eval: D S +``` diff --git a/examples/libtorchaudio/speech_recognition/build_pipeline_from_fairseq.py b/examples/libtorchaudio/speech_recognition/build_pipeline_from_fairseq.py new file mode 100644 index 0000000000000000000000000000000000000000..a6da0ae1e1f7d0857f4f2af4e4683e692fd54eed --- /dev/null +++ b/examples/libtorchaudio/speech_recognition/build_pipeline_from_fairseq.py @@ -0,0 +1,182 @@ +#!/usr/bin/evn python3 +"""Build Speech Recognition pipeline based on fairseq's wav2vec2.0 and dump it to TorchScript file. + +To use this script, you need `fairseq`. +""" +import os +import argparse +import logging + +import torch +from torch.utils.mobile_optimizer import optimize_for_mobile +import torchaudio +from torchaudio.models.wav2vec2.utils.import_fairseq import import_fairseq_model +import fairseq + +from greedy_decoder import Decoder + +_LG = logging.getLogger(__name__) + + +def _parse_args(): + parser = argparse.ArgumentParser( + description=__doc__, + ) + parser.add_argument( + '--model-file', + required=True, + help='Path to the input pretrained weight file.' + ) + parser.add_argument( + '--dict-dir', + help=( + 'Path to the directory in which `dict.ltr.txt` file is found. ' + 'Required only when the model is finetuned.' + ) + ) + parser.add_argument( + '--output-path', + help='Path to the directory, where the TorchScript-ed pipelines are saved.', + ) + parser.add_argument( + '--test-file', + help='Path to a test audio file.', + ) + parser.add_argument( + '--debug', + action='store_true', + help=( + 'When enabled, individual components are separately tested ' + 'for the numerical compatibility and TorchScript compatibility.' + ) + ) + parser.add_argument( + '--quantize', + action='store_true', + help='Apply quantization to model.' + ) + parser.add_argument( + '--optimize-for-mobile', + action='store_true', + help='Apply optmization for mobile.' 
+ ) + return parser.parse_args() + + +class Loader(torch.nn.Module): + def forward(self, audio_path: str) -> torch.Tensor: + waveform, sample_rate = torchaudio.load(audio_path) + if sample_rate != 16000: + waveform = torchaudio.functional.resample(waveform, float(sample_rate), 16000.) + return waveform + + +class Encoder(torch.nn.Module): + def __init__(self, encoder: torch.nn.Module): + super().__init__() + self.encoder = encoder + + def forward(self, waveform: torch.Tensor) -> torch.Tensor: + result, _ = self.encoder(waveform) + return result[0] + + +def _get_decoder(): + labels = [ + "", + "", + "", + "", + "|", + "E", + "T", + "A", + "O", + "N", + "I", + "H", + "S", + "R", + "D", + "L", + "U", + "M", + "W", + "C", + "F", + "G", + "Y", + "P", + "B", + "V", + "K", + "'", + "X", + "J", + "Q", + "Z", + ] + return Decoder(labels) + + +def _load_fairseq_model(input_file, data_dir=None): + overrides = {} + if data_dir: + overrides['data'] = data_dir + + model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [input_file], arg_overrides=overrides + ) + model = model[0] + return model + + +def _get_model(model_file, dict_dir): + original = _load_fairseq_model(model_file, dict_dir) + model = import_fairseq_model(original.w2v_encoder) + return model + + +def _main(): + args = _parse_args() + _init_logging(args.debug) + loader = Loader() + model = _get_model(args.model_file, args.dict_dir).eval() + encoder = Encoder(model) + decoder = _get_decoder() + _LG.info(encoder) + + if args.quantize: + _LG.info('Quantizing the model') + model.encoder.transformer.pos_conv_embed.__prepare_scriptable__() + encoder = torch.quantization.quantize_dynamic( + encoder, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8) + _LG.info(encoder) + + # test + if args.test_file: + _LG.info('Testing with %s', args.test_file) + waveform = loader(args.test_file) + emission = encoder(waveform) + transcript = decoder(emission) + _LG.info(transcript) + + torch.jit.script(loader).save(os.path.join(args.output_path, 'loader.zip')) + torch.jit.script(decoder).save(os.path.join(args.output_path, 'decoder.zip')) + scripted = torch.jit.script(encoder) + if args.optimize_for_mobile: + scripted = optimize_for_mobile(scripted) + scripted.save(os.path.join(args.output_path, 'encoder.zip')) + + +def _init_logging(debug=False): + level = logging.DEBUG if debug else logging.INFO + format_ = ( + '%(message)s' if not debug else + '%(asctime)s: %(levelname)7s: %(funcName)10s: %(message)s' + ) + logging.basicConfig(level=level, format=format_) + + +if __name__ == '__main__': + _main() diff --git a/examples/libtorchaudio/speech_recognition/build_pipeline_from_huggingface_transformers.py b/examples/libtorchaudio/speech_recognition/build_pipeline_from_huggingface_transformers.py new file mode 100644 index 0000000000000000000000000000000000000000..10323d96f784422a8fc5ed3178bf594d17d597dc --- /dev/null +++ b/examples/libtorchaudio/speech_recognition/build_pipeline_from_huggingface_transformers.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +import argparse +import logging +import os + +import torch +import torchaudio +from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model +from greedy_decoder import Decoder + +_LG = logging.getLogger(__name__) + + +def _parse_args(): + parser = argparse.ArgumentParser( + description=__doc__, + ) + parser.add_argument( + '--model', + required=True, + help='Path to the input pretrained weight file.' 
+ ) + parser.add_argument( + '--output-path', + help='Path to the directory, where the Torchscript-ed pipelines are saved.', + ) + parser.add_argument( + '--test-file', + help='Path to a test audio file.', + ) + parser.add_argument( + '--quantize', + action='store_true', + help='Quantize the model.', + ) + parser.add_argument( + '--debug', + action='store_true', + help=( + 'When enabled, individual components are separately tested ' + 'for the numerical compatibility and TorchScript compatibility.' + ) + ) + return parser.parse_args() + + +class Loader(torch.nn.Module): + def forward(self, audio_path: str) -> torch.Tensor: + waveform, sample_rate = torchaudio.load(audio_path) + if sample_rate != 16000: + waveform = torchaudio.functional.resample(waveform, float(sample_rate), 16000.) + return waveform + + +class Encoder(torch.nn.Module): + def __init__(self, encoder: torch.nn.Module): + super().__init__() + self.encoder = encoder + + def forward(self, waveform: torch.Tensor) -> torch.Tensor: + result, _ = self.encoder(waveform) + return result[0] + + +def _get_model(model_id): + from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor + tokenizer = Wav2Vec2Processor.from_pretrained(model_id).tokenizer + labels = [k for k, v in sorted(tokenizer.get_vocab().items(), key=lambda kv: kv[1])] + original = Wav2Vec2ForCTC.from_pretrained(model_id) + model = import_huggingface_model(original) + return model.eval(), labels + + +def _get_decoder(labels): + return Decoder(labels) + + +def _main(): + args = _parse_args() + _init_logging(args.debug) + _LG.info('Loading model: %s', args.model) + model, labels = _get_model(args.model) + _LG.info('Labels: %s', labels) + _LG.info('Building pipeline') + loader = Loader() + encoder = Encoder(model) + decoder = _get_decoder(labels) + _LG.info(encoder) + + if args.quantize: + _LG.info('Quantizing the model') + model.encoder.transformer.pos_conv_embed.__prepare_scriptable__() + encoder = torch.quantization.quantize_dynamic( + encoder, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8) + _LG.info(encoder) + + # test + if args.test_file: + _LG.info('Testing with %s', args.test_file) + waveform = loader(args.test_file) + emission = encoder(waveform) + transcript = decoder(emission) + _LG.info(transcript) + + torch.jit.script(loader).save(os.path.join(args.output_path, 'loader.zip')) + torch.jit.script(encoder).save(os.path.join(args.output_path, 'encoder.zip')) + torch.jit.script(decoder).save(os.path.join(args.output_path, 'decoder.zip')) + + +def _init_logging(debug=False): + level = logging.DEBUG if debug else logging.INFO + format_ = ( + '%(message)s' if not debug else + '%(asctime)s: %(levelname)7s: %(funcName)10s: %(message)s' + ) + logging.basicConfig(level=level, format=format_) + + +if __name__ == '__main__': + _main() diff --git a/examples/libtorchaudio/speech_recognition/greedy_decoder.py b/examples/libtorchaudio/speech_recognition/greedy_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3303c330704d76cc0e92ccf762f965cdaaf26f90 --- /dev/null +++ b/examples/libtorchaudio/speech_recognition/greedy_decoder.py @@ -0,0 +1,28 @@ +import torch + + +class Decoder(torch.nn.Module): + def __init__(self, labels): + super().__init__() + self.labels = labels + + def forward(self, logits: torch.Tensor) -> str: + """Given a sequence logits over labels, get the best path string + + Args: + logits (Tensor): Logit tensors. Shape `[num_seq, num_label]`. 
+ + Returns: + str: The resulting transcript + """ + best_path = torch.argmax(logits, dim=-1) # [num_seq,] + best_path = torch.unique_consecutive(best_path, dim=-1) + hypothesis = '' + for i in best_path: + char = self.labels[i] + if char in ['', '']: + continue + if char == '|': + char = ' ' + hypothesis += char + return hypothesis diff --git a/examples/libtorchaudio/speech_recognition/parse_librispeech.py b/examples/libtorchaudio/speech_recognition/parse_librispeech.py new file mode 100644 index 0000000000000000000000000000000000000000..dcd7aaf938eac0dfccf2401cd83c4a8aaa658f05 --- /dev/null +++ b/examples/libtorchaudio/speech_recognition/parse_librispeech.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +"""Parse a directory contains Librispeech dataset. + +Recursively search for "*.trans.txt" file in the given directory and print out + +`\\t\\t` + +example: python parse_librispeech.py LibriSpeech/test-clean + + 1089-134691-0000\t/LibriSpeech/test-clean/1089/134691/1089-134691-0000.flac\tHE COULD WAIT NO LONGER + ... + +Dataset can be obtained from https://www.openslr.org/12 +""" +import argparse +from pathlib import Path + + +def _parse_args(): + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + 'input_dir', + type=Path, + help='Directory where `*.trans.txt` files are searched.' + ) + return parser.parse_args() + + +def _parse_transcript(path): + with open(path) as trans_fileobj: + for line in trans_fileobj: + line = line.strip() + if line: + yield line.split(' ', maxsplit=1) + + +def _parse_directory(root_dir: Path): + for trans_file in root_dir.glob('**/*.trans.txt'): + trans_dir = trans_file.parent + for id_, transcription in _parse_transcript(trans_file): + audio_path = trans_dir / f'{id_}.flac' + yield id_, audio_path, transcription + + +def _main(): + args = _parse_args() + for id_, path, transcription in _parse_directory(args.input_dir): + print(f'{id_}\t{path}\t{transcription}') + + +if __name__ == '__main__': + _main() diff --git a/examples/libtorchaudio/speech_recognition/parse_voxforge.py b/examples/libtorchaudio/speech_recognition/parse_voxforge.py new file mode 100644 index 0000000000000000000000000000000000000000..ea88c60851649e144fa501876d9293bca47591fc --- /dev/null +++ b/examples/libtorchaudio/speech_recognition/parse_voxforge.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python +"""Parse a directory contains VoxForge dataset. + +Recursively search for "PROMPTS" file in the given directory and print out + +`\\t\\t` + +example: python parse_voxforge.py voxforge/de/Helge-20150608-aku + + de5-001\t/datasets/voxforge/de/guenter-20140214-afn/wav/de5-001.wav\tES SOLL ETWA FÜNFZIGTAUSEND VERSCHIEDENE SORTEN GEBEN + ... + +Dataset can be obtained from http://www.repository.voxforge1.org/downloads/de/Trunk/Audio/Main/16kHz_16bit/ +""" # noqa: E501 +import os +import argparse +from pathlib import Path + + +def _parse_args(): + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + 'input_dir', + type=Path, + help='Directory where `*.trans.txt` files are searched.' 
+ ) + return parser.parse_args() + + +def _parse_prompts(path): + base_dir = path.parent.parent + with open(path) as trans_fileobj: + for line in trans_fileobj: + line = line.strip() + if not line: + continue + + id_, transcript = line.split(' ', maxsplit=1) + if not transcript: + continue + + transcript = transcript.upper() + filename = id_.split('/')[-1] + audio_path = base_dir / 'wav' / f'{filename}.wav' + if os.path.exists(audio_path): + yield id_, audio_path, transcript + + +def _parse_directory(root_dir: Path): + for prompt_file in root_dir.glob('**/PROMPTS'): + try: + yield from _parse_prompts(prompt_file) + except UnicodeDecodeError: + pass + + +def _main(): + args = _parse_args() + for id_, path, transcription in _parse_directory(args.input_dir): + print(f'{id_}\t{path}\t{transcription}') + + +if __name__ == '__main__': + _main() diff --git a/examples/libtorchaudio/speech_recognition/transcribe.cpp b/examples/libtorchaudio/speech_recognition/transcribe.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e6d65b17e4250fe4ba3d505128be3c716becba6c --- /dev/null +++ b/examples/libtorchaudio/speech_recognition/transcribe.cpp @@ -0,0 +1,38 @@ +#include + +int main(int argc, char* argv[]) { + if (argc != 3) { + std::cerr << "Usage: " << argv[0] << " " << std::endl; + return -1; + } + + torch::jit::script::Module loader, encoder, decoder; + std::cout << "Loading module from: " << argv[1] << std::endl; + try { + loader = torch::jit::load(std::string(argv[1]) + "/loader.zip"); + } catch (const c10::Error &error) { + std::cerr << "Failed to load the module:" << error.what() << std::endl; + return -1; + } + try { + encoder = torch::jit::load(std::string(argv[1]) + "/encoder.zip"); + } catch (const c10::Error &error) { + std::cerr << "Failed to load the module:" << error.what() << std::endl; + return -1; + } + try { + decoder = torch::jit::load(std::string(argv[1]) + "/decoder.zip"); + } catch (const c10::Error &error) { + std::cerr << "Failed to load the module:" << error.what() << std::endl; + return -1; + } + + std::cout << "Loading the audio" << std::endl; + auto waveform = loader.forward({c10::IValue(argv[2])}); + std::cout << "Running inference" << std::endl; + auto emission = encoder.forward({waveform}); + std::cout << "Generating the transcription" << std::endl; + auto result = decoder.forward({emission}); + std::cout << result.toString()->string() << std::endl; + std::cout << "Done." 
<< std::endl; +} diff --git a/examples/libtorchaudio/speech_recognition/transcribe_list.cpp b/examples/libtorchaudio/speech_recognition/transcribe_list.cpp new file mode 100644 index 0000000000000000000000000000000000000000..458a98f568bd4080ec1c598e1e7a8df4a6fb3242 --- /dev/null +++ b/examples/libtorchaudio/speech_recognition/transcribe_list.cpp @@ -0,0 +1,66 @@ +#include +#include + + +int main(int argc, char* argv[]) { + if (argc != 4) { + std::cerr << "Usage: " << argv[0] << " \n" << std::endl; + std::cerr << " is `\t\t`" << std::endl; + return -1; + } + + torch::jit::script::Module loader, encoder, decoder; + std::cout << "Loading module from: " << argv[1] << std::endl; + try { + loader = torch::jit::load(std::string(argv[1]) + "/loader.zip"); + } catch (const c10::Error &error) { + std::cerr << "Failed to load the module:" << error.what() << std::endl; + return -1; + } + try { + encoder = torch::jit::load(std::string(argv[1]) + "/encoder.zip"); + } catch (const c10::Error &error) { + std::cerr << "Failed to load the module:" << error.what() << std::endl; + return -1; + } + try { + decoder = torch::jit::load(std::string(argv[1]) + "/decoder.zip"); + } catch (const c10::Error &error) { + std::cerr << "Failed to load the module:" << error.what() << std::endl; + return -1; + } + + std::ifstream input_file(argv[2]); + std::string output_dir(argv[3]); + std::ofstream output_ref(output_dir + "/ref.trn"); + std::ofstream output_hyp(output_dir + "/hyp.trn"); + std::string line; + std::chrono::milliseconds t_encode(0); + std::chrono::milliseconds t_decode(0); + while(std::getline(input_file, line)) { + std::istringstream iline(line); + std::string id; + std::string path; + std::string reference; + std::getline(iline, id, '\t'); + std::getline(iline, path, '\t'); + std::getline(iline, reference, '\t'); + + auto waveform = loader.forward({c10::IValue(path)}); + std::chrono::steady_clock::time_point t0 = std::chrono::steady_clock::now(); + auto emission = encoder.forward({waveform}); + std::chrono::steady_clock::time_point t1 = std::chrono::steady_clock::now(); + auto result = decoder.forward({emission}); + std::chrono::steady_clock::time_point t2 = std::chrono::steady_clock::now(); + + t_encode += std::chrono::duration_cast(t1 - t0); + t_decode += std::chrono::duration_cast(t2 - t1); + + auto hypothesis = result.toString()->string(); + output_hyp << hypothesis << " (" << id << ")" << std::endl; + output_ref << reference << " (" << id << ")" << std::endl; + std::cout << id << '\t' << hypothesis << std::endl; + } + std::cout << "Time (encode): " << t_encode.count() << " [ms]" << std::endl; + std::cout << "Time (decode): " << t_decode.count() << " [ms]" << std::endl; +} diff --git a/examples/pipeline_tacotron2/README.md b/examples/pipeline_tacotron2/README.md new file mode 100644 index 0000000000000000000000000000000000000000..aa90d060f1c586b5d104b5066ecfab5763ba947b --- /dev/null +++ b/examples/pipeline_tacotron2/README.md @@ -0,0 +1,258 @@ +This is an example pipeline for text-to-speech using Tacotron2. + +Here is a [colab example](https://colab.research.google.com/drive/1MPcn1_G5lKozxZ7v8b9yucOD5X5cLK4j?usp=sharing) +that shows how the text-to-speech pipeline is used during inference with the built-in pretrained models. 
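+
+As a quick orientation, the sketch below strings the pieces together in Python. It is only a
+sketch: it assumes you run it from this example directory (so that the `text.text_preprocessing`
+helper is importable), that the DeepPhonemizer checkpoint described below has already been
+downloaded to `./en_us_cmudict_forward.pt`, and it uses the pretrained checkpoint names that
+appear later in this README. `inference.py` implements the full, configurable version.
+
+```python
+import torch
+from torchaudio.models import tacotron2
+from text.text_preprocessing import text_to_sequence
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# 1. Text preprocessing: turn the sentence into a list of symbol IDs.
+ids = text_to_sequence("Hello world!", "english_phonemes",
+                       phonemizer="DeepPhonemizer",
+                       checkpoint="./en_us_cmudict_forward.pt",
+                       cmudict_root="./")
+sequences = torch.tensor([ids], dtype=torch.long, device=device)
+lengths = torch.tensor([len(ids)], dtype=torch.long, device=device)
+
+# 2. Spectrogram generation with a pretrained Tacotron2 checkpoint.
+model = tacotron2("tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech").to(device).eval()
+with torch.no_grad():
+    mel_specgram, _, _ = model.infer(sequences, lengths)
+
+# 3. Time-domain conversion: pass mel_specgram to a vocoder
+#    (WaveRNN, Griffin-Lim, or WaveGlow); see inference.py for each option.
+```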
+ +## Install required packages + +Required packages +```bash +pip install librosa tqdm inflect joblib +``` + +To use tensorboard +```bash +pip install tensorboard pillow +``` + +## Training Tacotron2 with character as input + +The training of Tacotron2 can be invoked with the following command. + +```bash +python train.py \ + --learning-rate 1e-3 \ + --epochs 1501 \ + --anneal-steps 500 1000 1500 \ + --anneal-factor 0.1 \ + --batch-size 96 \ + --weight-decay 1e-6 \ + --grad-clip 1.0 \ + --text-preprocessor english_characters \ + --logging-dir ./logs \ + --checkpoint-path ./ckpt.pth \ + --dataset-path ./ +``` + +The training script will use all GPUs that is available, please set the +environment variable `CUDA_VISIBLE_DEVICES` if you don't want all GPUs to be used. +The newest checkpoint will be saved to `./ckpt.pth` and the checkpoint with the best validation +loss will be saved to `./best_ckpt.pth`. +The training log will be saved to `./logs/train.log` and the tensorboard results will also +be in `./logs`. + +If `./ckpt.pth` already exist, this script will automatically load the file and try to continue +training from the checkpoint. + +This command takes around 36 hours to train on 8 NVIDIA Tesla V100 GPUs. + +To train the Tacotron2 model to work with the [pretrained wavernn](https://pytorch.org/audio/main/models.html#id10) +with checkpoint_name `"wavernn_10k_epochs_8bits_ljspeech"`, please run the following command instead. + +```bash +python train.py + --learning-rate 1e-3 \ + --epochs 1501 \ + --anneal-steps 500 1000 1500 \ + --anneal-factor 0.1 \ + --sample-rate 22050 \ + --n-fft 2048 \ + --hop-length 275 \ + --win-length 1100 \ + --mel-fmin 40 \ + --mel-fmax 11025 \ + --batch-size 96 \ + --weight-decay 1e-6 \ + --grad-clip 1.0 \ + --text-preprocessor english_characters \ + --logging-dir ./wavernn_logs \ + --checkpoint-path ./ckpt_wavernn.pth \ + --dataset-path ./ +``` + + +## Training Tacotron2 with phoneme as input + +#### Dependencies + +This example use the [DeepPhonemizer](https://github.com/as-ideas/DeepPhonemizer) as +the phonemizer (the function to turn text into phonemes), +please install it with the following command (the code is tested with version 0.0.15). + +```bash +pip install deep-phonemizer==0.0.15 +``` + +Then download the model weights from [their website](https://github.com/as-ideas/DeepPhonemizer) + +The link to the checkpoint that is tested with this example is +[https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_forward.pt](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_forward.pt). + +#### Running training script + +The training of Tacotron2 with english phonemes as input can be invoked with the following command. + +```bash +python train.py \ + --workers 12 \ + --learning-rate 1e-3 \ + --epochs 1501 \ + --anneal-steps 500 1000 1500 \ + --anneal-factor 0.1 \ + --batch-size 96 \ + --weight-decay 1e-6 \ + --grad-clip 1.0 \ + --text-preprocessor english_phonemes \ + --phonemizer DeepPhonemizer \ + --phonemizer-checkpoint ./en_us_cmudict_forward.pt \ + --cmudict-root ./ \ + --logging-dir ./english_phonemes_logs \ + --checkpoint-path ./english_phonemes_ckpt.pth \ + --dataset-path ./ +``` + +Similar to the previous examples, this command will save the log in the directory `./english_phonemes_logs` +and the checkpoint will be saved to `./english_phonemes_ckpt.pth`. 
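+
+Before launching a long training run, you can sanity-check the downloaded DeepPhonemizer
+weights with a short snippet like the one below (a sketch; it assumes the checkpoint was saved
+as `./en_us_cmudict_forward.pt`, the same path passed to `--phonemizer-checkpoint`).
+
+```python
+from dp.phonemizer import Phonemizer
+
+# Load the checkpoint downloaded from the DeepPhonemizer repository (assumed path).
+phonemizer = Phonemizer.from_checkpoint("./en_us_cmudict_forward.pt")
+
+# Prints something like: [HH][AH][L][OW] [W][ER][L][D]!
+print(phonemizer("hello world!", lang="en_us"))
+```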
+ + +To train the Tacotron2 model with english phonemes that works with the +[pretrained wavernn](https://pytorch.org/audio/main/models.html#id10) +with checkpoint_name `"wavernn_10k_epochs_8bits_ljspeech"`, please run the following command. + +```bash +python train.py \ + --workers 12 \ + --learning-rate 1e-3 \ + --epochs 1501 \ + --anneal-steps 500 1000 1500 \ + --anneal-factor 0.1 \ + --sample-rate 22050 \ + --n-fft 2048 \ + --hop-length 275 \ + --win-length 1100 \ + --mel-fmin 40 \ + --mel-fmax 11025 \ + --batch-size 96 \ + --weight-decay 1e-6 \ + --grad-clip 1.0 \ + --text-preprocessor english_phonemes \ + --phonemizer DeepPhonemizer \ + --phonemizer-checkpoint ./en_us_cmudict_forward.pt \ + --cmudict-root ./ \ + --logging-dir ./english_phonemes_wavernn_logs \ + --checkpoint-path ./english_phonemes_wavernn_ckpt.pth \ + --dataset-path ./ +``` + + +## Text-to-speech pipeline + +Here we present an example of how to use Tacotron2 to generate audio from text. +The text-to-speech pipeline goes as follows: +1. text preprocessing: encoder the text into list of symbols (the symbols can represent characters, phonemes, etc.) +2. spectrogram generation: after retrieving the list of symbols, we feed this list to a Tacotron2 model and the model +will output the mel spectrogram. +3. time-domain conversion: when the mel spectrogram is generated, we need to convert it into audio with a vocoder. +Currently, there are three vocoders being supported in this script, which includes the +[WaveRNN](https://pytorch.org/audio/stable/models/wavernn.html), +[Griffin-Lim](https://pytorch.org/audio/stable/transforms.html#griffinlim), and +[Nvidia's WaveGlow](https://pytorch.org/hub/nvidia_deeplearningexamples_tacotron2/). + +The spectro parameters including `n-fft`, `mel-fmin`, `mel-fmax` should be set to the values +used during the training of Tacotron2. + + +#### Pretrained WaveRNN as the Vocoder + +The following command will generate a waveform to `./outputs.wav` +with the text "Hello world!" using WaveRNN as the vocoder. + +```bash +python inference.py --checkpoint-path ${model_path} \ + --vocoder wavernn \ + --n-fft 2048 \ + --mel-fmin 40 \ + --mel-fmax 11025 \ + --input-text "Hello world!" \ + --text-preprocessor english_characters \ + --output-path "./outputs.wav" +``` + +If you want to generate a waveform with a different text with phonemes +as the input to Tacotron2, please use the `--text-preprocessor english_phonemes`. +The following is an example. +(Remember to install the [DeepPhonemizer](https://github.com/as-ideas/DeepPhonemizer) +and download their pretrained weights. + +```bash +python inference.py --checkpoint-path ${model_path} \ + --vocoder wavernn \ + --n-fft 2048 \ + --mel-fmin 40 \ + --mel-fmax 11025 \ + --input-text "Hello world!" \ + --text-preprocessor english_phonemes \ + --phonimizer DeepPhonemizer \ + --phoimizer-checkpoint ./en_us_cmudict_forward.pt \ + --cmudict-root ./ \ + --output-path "./outputs.wav" +``` + +To use torchaudio pretrained models, please see the following example command. +For Tacotron2, we use the checkpoint named `"tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech"`, and +for WaveRNN, we use the checkpoint named `"wavernn_10k_epochs_8bits_ljspeech"`. +See https://pytorch.org/audio/stable/models.html for more checkpoint options for Tacotron2 and WaveRNN. 
+
+```bash
+python inference.py \
+    --checkpoint-path tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech \
+    --wavernn-checkpoint-path wavernn_10k_epochs_8bits_ljspeech \
+    --vocoder wavernn \
+    --n-fft 2048 \
+    --mel-fmin 40 \
+    --mel-fmax 11025 \
+    --input-text "Hello world!" \
+    --text-preprocessor english_phonemes \
+    --phonemizer DeepPhonemizer \
+    --phonemizer-checkpoint ./en_us_cmudict_forward.pt \
+    --cmudict-root ./ \
+    --output-path "./outputs.wav"
+```
+
+#### Griffin-Lim's algorithm as the Vocoder
+
+The following command will generate a waveform to `./outputs.wav`
+with the text "Hello world!" using Griffin-Lim's algorithm as the vocoder.
+
+```bash
+python inference.py --checkpoint-path ${model_path} \
+    --vocoder griffin_lim \
+    --n-fft 1024 \
+    --mel-fmin 0 \
+    --mel-fmax 8000 \
+    --input-text "Hello world!" \
+    --text-preprocessor english_characters \
+    --output-path "./outputs.wav"
+```
+
+
+#### Nvidia's WaveGlow as the Vocoder
+
+The following command will generate a waveform to `./outputs.wav`
+with the text `"Hello world!"` using Nvidia's WaveGlow as the vocoder.
+WaveGlow is loaded using the following Torch Hub API call.
+
+```python
+torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_waveglow', model_math='fp16')
+```
+
+```bash
+python inference.py --checkpoint-path ${model_path} \
+    --vocoder nvidia_waveglow \
+    --n-fft 1024 \
+    --mel-fmin 0 \
+    --mel-fmax 8000 \
+    --input-text "Hello world!" \
+    --text-preprocessor english_characters \
+    --output-path "./outputs.wav"
+```
diff --git a/examples/pipeline_tacotron2/datasets.py b/examples/pipeline_tacotron2/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5a68203836b66256ff61307dcc487708a717326
--- /dev/null
+++ b/examples/pipeline_tacotron2/datasets.py
@@ -0,0 +1,171 @@
+# *****************************************************************************
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#     * Redistributions of source code must retain the above copyright
+#       notice, this list of conditions and the following disclaimer.
+#     * Redistributions in binary form must reproduce the above copyright
+#       notice, this list of conditions and the following disclaimer in the
+#       documentation and/or other materials provided with the distribution.
+#     * Neither the name of the NVIDIA CORPORATION nor the
+#       names of its contributors may be used to endorse or promote products
+#       derived from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# +# ***************************************************************************** + +from typing import Tuple, Callable, List + +import torch +from torch import Tensor + +from torch.utils.data.dataset import random_split +from torchaudio.datasets import LJSPEECH + + +class SpectralNormalization(torch.nn.Module): + def forward(self, input): + return torch.log(torch.clamp(input, min=1e-5)) + + +class InverseSpectralNormalization(torch.nn.Module): + def forward(self, input): + return torch.exp(input) + + +class MapMemoryCache(torch.utils.data.Dataset): + r"""Wrap a dataset so that, whenever a new item is returned, it is saved to memory. + """ + + def __init__(self, dataset): + self.dataset = dataset + self._cache = [None] * len(dataset) + + def __getitem__(self, n): + if self._cache[n] is not None: + return self._cache[n] + + item = self.dataset[n] + self._cache[n] = item + + return item + + def __len__(self): + return len(self.dataset) + + +class Processed(torch.utils.data.Dataset): + def __init__(self, dataset, transforms, text_preprocessor): + self.dataset = dataset + self.transforms = transforms + self.text_preprocessor = text_preprocessor + + def __getitem__(self, key): + item = self.dataset[key] + return self.process_datapoint(item) + + def __len__(self): + return len(self.dataset) + + def process_datapoint(self, item): + melspec = self.transforms(item[0]) + text_norm = torch.IntTensor(self.text_preprocessor(item[2])) + return text_norm, torch.squeeze(melspec, 0) + + +def split_process_dataset(dataset: str, + file_path: str, + val_ratio: float, + transforms: Callable, + text_preprocessor: Callable[[str], List[int]], + ) -> Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]: + """Returns the Training and validation datasets. + + Args: + dataset (str): The dataset to use. Avaliable options: [`'ljspeech'`] + file_path (str): Path to the data. + val_ratio (float): Path to the data. + transforms (callable): A function/transform that takes in a waveform and + returns a transformed waveform (mel spectrogram in this example). + text_preprocess (callable): A function that takes in a string and + returns a list of integers representing each of the symbol in the string. + + Returns: + train_dataset (`torch.utils.data.Dataset`): The training set. + val_dataset (`torch.utils.data.Dataset`): The validation set. + """ + if dataset == 'ljspeech': + data = LJSPEECH(root=file_path, download=False) + + val_length = int(len(data) * val_ratio) + lengths = [len(data) - val_length, val_length] + train_dataset, val_dataset = random_split(data, lengths) + else: + raise ValueError(f"Expected datasets: `ljspeech`, but found {dataset}") + + train_dataset = Processed(train_dataset, transforms, text_preprocessor) + val_dataset = Processed(val_dataset, transforms, text_preprocessor) + + train_dataset = MapMemoryCache(train_dataset) + val_dataset = MapMemoryCache(val_dataset) + + return train_dataset, val_dataset + + +def text_mel_collate_fn(batch: Tuple[Tensor, Tensor], + n_frames_per_step: int = 1) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """The collate function padding and adjusting the data based on `n_frames_per_step`. + Modified from https://github.com/NVIDIA/DeepLearningExamples + + Args: + batch (tuple of two tensors): the first tensor is the mel spectrogram with shape + (n_batch, n_mels, n_frames), the second tensor is the text with shape (n_batch, ). + n_frames_per_step (int, optional): The number of frames to advance every step. 
+ + Returns: + text_padded (Tensor): The input text to Tacotron2 with shape (n_batch, max of ``text_lengths``). + text_lengths (Tensor): The length of each text with shape (n_batch). + mel_specgram_padded (Tensor): The target mel spectrogram + with shape (n_batch, n_mels, max of ``mel_specgram_lengths``) + mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape (n_batch). + gate_padded (Tensor): The ground truth gate output + with shape (n_batch, max of ``mel_specgram_lengths``) + """ + text_lengths, ids_sorted_decreasing = torch.sort( + torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True) + max_input_len = text_lengths[0] + + text_padded = torch.zeros((len(batch), max_input_len), dtype=torch.int64) + for i in range(len(ids_sorted_decreasing)): + text = batch[ids_sorted_decreasing[i]][0] + text_padded[i, :text.size(0)] = text + + # Right zero-pad mel-spec + num_mels = batch[0][1].size(0) + max_target_len = max([x[1].size(1) for x in batch]) + if max_target_len % n_frames_per_step != 0: + max_target_len += n_frames_per_step - max_target_len % n_frames_per_step + assert max_target_len % n_frames_per_step == 0 + + # include mel padded and gate padded + mel_specgram_padded = torch.zeros((len(batch), num_mels, max_target_len), dtype=torch.float32) + gate_padded = torch.zeros((len(batch), max_target_len), dtype=torch.float32) + mel_specgram_lengths = torch.LongTensor(len(batch)) + for i in range(len(ids_sorted_decreasing)): + mel = batch[ids_sorted_decreasing[i]][1] + mel_specgram_padded[i, :, :mel.size(1)] = mel + mel_specgram_lengths[i] = mel.size(1) + gate_padded[i, mel.size(1) - 1:] = 1 + + return text_padded, text_lengths, mel_specgram_padded, mel_specgram_lengths, gate_padded diff --git a/examples/pipeline_tacotron2/inference.py b/examples/pipeline_tacotron2/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..a89fffb5cfb2e130cdab627080674868177a3896 --- /dev/null +++ b/examples/pipeline_tacotron2/inference.py @@ -0,0 +1,353 @@ +""" +Text-to-speech pipeline using Tacotron2. +""" + +from functools import partial +import argparse +import os +import random +import sys + +import torch +import torchaudio +import numpy as np +from torchaudio.models import Tacotron2 +from torchaudio.models import tacotron2 as pretrained_tacotron2 + +from utils import prepare_input_sequence +from datasets import InverseSpectralNormalization +from text.text_preprocessing import ( + available_symbol_set, + available_phonemizers, + get_symbol_list, + text_to_sequence, +) + + +def parse_args(): + r""" + Parse commandline arguments. + """ + from torchaudio.models.tacotron2 import _MODEL_CONFIG_AND_URLS as tacotron2_config_and_urls + from torchaudio.models.wavernn import _MODEL_CONFIG_AND_URLS as wavernn_config_and_urls + + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + '--checkpoint-name', + type=str, + default=None, + choices=list(tacotron2_config_and_urls.keys()), + help='[string] The name of the checkpoint to load.' + ) + parser.add_argument( + '--checkpoint-path', + type=str, + default=None, + help='[string] Path to the checkpoint file.' + ) + parser.add_argument( + '--output-path', + type=str, + default="./audio.wav", + help='[string] Path to the output .wav file.' + ) + parser.add_argument( + '--input-text', + '-i', + type=str, + default="Hello world", + help='[string] Type in something here and TTS will generate it!' 
+ ) + parser.add_argument( + '--vocoder', + default='nvidia_waveglow', + choices=['griffin_lim', 'wavernn', 'nvidia_waveglow'], + type=str, + help="Select the vocoder to use.", + ) + parser.add_argument( + "--jit", + default=False, + action="store_true", + help="If used, the model and inference function is jitted." + ) + + preprocessor = parser.add_argument_group('text preprocessor setup') + preprocessor.add_argument( + '--text-preprocessor', + default='english_characters', + type=str, + choices=available_symbol_set, + help='select text preprocessor to use.' + ) + preprocessor.add_argument( + '--phonemizer', + default="DeepPhonemizer", + type=str, + choices=available_phonemizers, + help='select phonemizer to use, only used when text-preprocessor is "english_phonemes"' + ) + preprocessor.add_argument( + '--phonemizer-checkpoint', + default="./en_us_cmudict_forward.pt", + type=str, + help='the path or name of the checkpoint for the phonemizer, ' + 'only used when text-preprocessor is "english_phonemes"' + ) + preprocessor.add_argument( + '--cmudict-root', + default="./", + type=str, + help='the root directory for storing CMU dictionary files' + ) + + audio = parser.add_argument_group('audio parameters') + audio.add_argument( + '--sample-rate', + default=22050, + type=int, + help='Sampling rate' + ) + audio.add_argument( + '--n-fft', + default=1024, + type=int, + help='Filter length for STFT' + ) + audio.add_argument( + '--n-mels', + default=80, + type=int, + help='' + ) + audio.add_argument( + '--mel-fmin', + default=0.0, + type=float, + help='Minimum mel frequency' + ) + audio.add_argument( + '--mel-fmax', + default=8000.0, + type=float, + help='Maximum mel frequency' + ) + + # parameters for WaveRNN + wavernn = parser.add_argument_group('WaveRNN parameters') + wavernn.add_argument( + '--wavernn-checkpoint-name', + default="wavernn_10k_epochs_8bits_ljspeech", + choices=list(wavernn_config_and_urls.keys()), + help="Select the WaveRNN checkpoint." + ) + wavernn.add_argument( + "--wavernn-loss", + default="crossentropy", + choices=["crossentropy"], + type=str, + help="The type of loss the WaveRNN pretrained model is trained on.", + ) + wavernn.add_argument( + "--wavernn-no-batch-inference", + default=False, + action="store_true", + help="Don't use batch inference for WaveRNN inference." + ) + wavernn.add_argument( + "--wavernn-no-mulaw", + default=False, + action="store_true", + help="Don't use mulaw decoder to decode the signal." + ) + wavernn.add_argument( + "--wavernn-batch-timesteps", + default=11000, + type=int, + help="The time steps for each batch. Only used when batch inference is used", + ) + wavernn.add_argument( + "--wavernn-batch-overlap", + default=550, + type=int, + help="The overlapping time steps between batches. Only used when batch inference is used", + ) + + return parser + + +def unwrap_distributed(state_dict): + r"""torch.distributed.DistributedDataParallel wraps the model with an additional "module.". + This function unwraps this layer so that the weights can be loaded on models with a single GPU. + + Args: + state_dict: Original state_dict. + + Return: + unwrapped_state_dict: Unwrapped state_dict. 
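+
+    Example (illustrative):
+        >>> state_dict = {"module.linear.weight": torch.zeros(2)}
+        >>> unwrap_distributed(state_dict)
+        {'linear.weight': tensor([0., 0.])}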
+ """ + + return {k.replace('module.', ''): v for k, v in state_dict.items()} + + +def nvidia_waveglow_vocode(mel_specgram, device, jit=False): + waveglow = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_waveglow', model_math='fp16') + waveglow = waveglow.remove_weightnorm(waveglow) + waveglow = waveglow.to(device) + waveglow.eval() + + if args.jit: + raise ValueError("Vocoder option `nvidia_waveglow is not jittable.") + + with torch.no_grad(): + waveform = waveglow.infer(mel_specgram).cpu() + + return waveform + + +def wavernn_vocode(mel_specgram, wavernn_checkpoint_name, wavernn_loss, wavernn_no_mulaw, + wavernn_no_batch_inference, wavernn_batch_timesteps, wavernn_batch_overlap, + device, jit): + from torchaudio.models import wavernn + sys.path.append(os.path.join(os.path.dirname(__file__), "../pipeline_wavernn")) + from wavernn_inference_wrapper import WaveRNNInferenceWrapper + from processing import NormalizeDB + + wavernn_model = wavernn(wavernn_checkpoint_name).eval().to(device) + wavernn_inference_model = WaveRNNInferenceWrapper(wavernn_model) + + if jit: + wavernn_inference_model = torch.jit.script(wavernn_inference_model) + + # WaveRNN spectro setting for default checkpoint + # n_fft = 2048 + # n_mels = 80 + # win_length = 1100 + # hop_length = 275 + # f_min = 40 + # f_max = 11025 + + transforms = torch.nn.Sequential( + InverseSpectralNormalization(), + NormalizeDB(min_level_db=-100, normalization=True), + ) + mel_specgram = transforms(mel_specgram.cpu()) + + with torch.no_grad(): + waveform = wavernn_inference_model(mel_specgram.to(device), + loss_name=wavernn_loss, + mulaw=(not wavernn_no_mulaw), + batched=(not wavernn_no_batch_inference), + timesteps=wavernn_batch_timesteps, + overlap=wavernn_batch_overlap,) + return waveform.unsqueeze(0) + + +def griffin_lim_vocode(mel_specgram, n_fft, n_mels, sample_rate, mel_fmin, mel_fmax, jit, ): + from torchaudio.transforms import GriffinLim, InverseMelScale + + inv_norm = InverseSpectralNormalization() + inv_mel = InverseMelScale( + n_stft=(n_fft // 2 + 1), + n_mels=n_mels, + sample_rate=sample_rate, + f_min=mel_fmin, + f_max=mel_fmax, + mel_scale="slaney", + norm='slaney', + ) + griffin_lim = GriffinLim( + n_fft=n_fft, + power=1, + hop_length=256, + win_length=1024, + ) + + vocoder = torch.nn.Sequential( + inv_norm, + inv_mel, + griffin_lim + ) + + if jit: + vocoder = torch.jit.script(vocoder) + + waveform = vocoder(mel_specgram.cpu()) + return waveform + + +def main(args): + torch.manual_seed(0) + random.seed(0) + np.random.seed(0) + + device = "cuda" if torch.cuda.is_available() else "cpu" + + if args.checkpoint_path is None and args.checkpoint_name is None: + raise ValueError("Either --checkpoint-path or --checkpoint-name must be specified.") + elif args.checkpoint_path is not None and args.checkpoint_name is not None: + raise ValueError("Both --checkpoint-path and --checkpoint-name are specified, " + "can only specify one.") + + n_symbols = len(get_symbol_list(args.text_preprocessor)) + text_preprocessor = partial( + text_to_sequence, + symbol_list=args.text_preprocessor, + phonemizer=args.phonemizer, + checkpoint=args.phonemizer_checkpoint, + cmudict_root=args.cmudict_root, + ) + + if args.checkpoint_path is not None: + tacotron2 = Tacotron2(n_symbol=n_symbols) + tacotron2.load_state_dict( + unwrap_distributed(torch.load(args.checkpoint_path, map_location=device)['state_dict'])) + tacotron2 = tacotron2.to(device).eval() + elif args.checkpoint_name is not None: + tacotron2 = 
pretrained_tacotron2(args.checkpoint_name).to(device).eval() + + if n_symbols != tacotron2.n_symbols: + raise ValueError("the number of symbols for text_preprocessor ({n_symbols}) " + "should match the number of symbols for the" + "pretrained tacotron2 ({tacotron2.n_symbols}).") + + if args.jit: + tacotron2 = torch.jit.script(tacotron2) + + sequences, lengths = prepare_input_sequence([args.input_text], + text_processor=text_preprocessor) + sequences, lengths = sequences.long().to(device), lengths.long().to(device) + with torch.no_grad(): + mel_specgram, _, _ = tacotron2.infer(sequences, lengths) + + if args.vocoder == "nvidia_waveglow": + waveform = nvidia_waveglow_vocode(mel_specgram=mel_specgram, device=device, jit=args.jit) + + elif args.vocoder == "wavernn": + waveform = wavernn_vocode(mel_specgram=mel_specgram, + wavernn_checkpoint_name=args.wavernn_checkpoint_name, + wavernn_loss=args.wavernn_loss, + wavernn_no_mulaw=args.wavernn_no_mulaw, + wavernn_no_batch_inference=args.wavernn_no_batch_inference, + wavernn_batch_timesteps=args.wavernn_batch_timesteps, + wavernn_batch_overlap=args.wavernn_batch_overlap, + device=device, + jit=args.jit) + + elif args.vocoder == "griffin_lim": + waveform = griffin_lim_vocode(mel_specgram=mel_specgram, + n_fft=args.n_fft, + n_mels=args.n_mels, + sample_rate=args.sample_rate, + mel_fmin=args.mel_fmin, + mel_fmax=args.mel_fmax, + jit=args.jit) + + torchaudio.save(args.output_path, waveform, args.sample_rate) + + +if __name__ == "__main__": + parser = parse_args() + args, _ = parser.parse_known_args() + + main(args) diff --git a/examples/pipeline_tacotron2/loss.py b/examples/pipeline_tacotron2/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..38f4b8bbcfdbd53bedab36cbca7f6c730f8780e4 --- /dev/null +++ b/examples/pipeline_tacotron2/loss.py @@ -0,0 +1,82 @@ +# ***************************************************************************** +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the +# names of its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +# ***************************************************************************** + +from typing import Tuple + +from torch import nn, Tensor + + +class Tacotron2Loss(nn.Module): + """Tacotron2 loss function modified from: + https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/loss_function.py + """ + + def __init__(self): + super().__init__() + + self.mse_loss = nn.MSELoss(reduction="mean") + self.bce_loss = nn.BCEWithLogitsLoss(reduction="mean") + + def forward( + self, + model_outputs: Tuple[Tensor, Tensor, Tensor], + targets: Tuple[Tensor, Tensor], + ) -> Tuple[Tensor, Tensor, Tensor]: + r"""Pass the input through the Tacotron2 loss. + + The original implementation was introduced in + *Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions* + [:footcite:`shen2018natural`]. + + Args: + model_outputs (tuple of three Tensors): The outputs of the + Tacotron2. These outputs should include three items: + (1) the predicted mel spectrogram before the postnet (``mel_specgram``) + with shape (batch, mel, time). + (2) predicted mel spectrogram after the postnet (``mel_specgram_postnet``) + with shape (batch, mel, time), and + (3) the stop token prediction (``gate_out``) with shape (batch, ). + targets (tuple of two Tensors): The ground truth mel spectrogram (batch, mel, time) and + stop token with shape (batch, ). + + Returns: + mel_loss (Tensor): The mean MSE of the mel_specgram and ground truth mel spectrogram + with shape ``torch.Size([])``. + mel_postnet_loss (Tensor): The mean MSE of the mel_specgram_postnet and + ground truth mel spectrogram with shape ``torch.Size([])``. + gate_loss (Tensor): The mean binary cross entropy loss of + the prediction on the stop token with shape ``torch.Size([])``. + """ + mel_target, gate_target = targets[0], targets[1] + gate_target = gate_target.view(-1, 1) + + mel_specgram, mel_specgram_postnet, gate_out = model_outputs + gate_out = gate_out.view(-1, 1) + mel_loss = self.mse_loss(mel_specgram, mel_target) + mel_postnet_loss = self.mse_loss(mel_specgram_postnet, mel_target) + gate_loss = self.bce_loss(gate_out, gate_target) + return mel_loss, mel_postnet_loss, gate_loss diff --git a/examples/pipeline_tacotron2/text/__init__.py b/examples/pipeline_tacotron2/text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/pipeline_tacotron2/text/numbers.py b/examples/pipeline_tacotron2/text/numbers.py new file mode 100644 index 0000000000000000000000000000000000000000..d42e82e1349fe28812cb45844648959c3aa6a221 --- /dev/null +++ b/examples/pipeline_tacotron2/text/numbers.py @@ -0,0 +1,116 @@ +# ***************************************************************************** +# Copyright (c) 2017 Keith Ito +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ***************************************************************************** +""" +Modified from https://github.com/keithito/tacotron +""" + +import inflect +import re + + +_inflect = inflect.engine() +_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') +_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') +_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') +_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') +_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') +_number_re = re.compile(r'[0-9]+') + + +def _remove_commas(text: str) -> str: + return re.sub(_comma_number_re, lambda m: m.group(1).replace(',', ''), text) + + +def _expand_pounds(text: str) -> str: + return re.sub(_pounds_re, r'\1 pounds', text) + + +def _expand_dollars_repl_fn(m): + """The replacement function for expanding dollars.""" + match = m.group(1) + parts = match.split('.') + if len(parts) > 2: + return match + ' dollars' # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + if len(parts) > 1 and parts[1]: + if len(parts[1]) == 1: + # handle the case where we have one digit after the decimal point + cents = int(parts[1]) * 10 + else: + cents = int(parts[1]) + else: + cents = 0 + if dollars and cents: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + return '%s %s' % (dollars, dollar_unit) + elif cents: + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s' % (cents, cent_unit) + else: + return 'zero dollars' + + +def _expand_dollars(text: str) -> str: + return re.sub(_dollars_re, _expand_dollars_repl_fn, text) + + +def _expand_decimal_point(text: str) -> str: + return re.sub(_decimal_number_re, lambda m: m.group(1).replace('.', ' point '), text) + + +def _expand_ordinal(text: str) -> str: + return re.sub(_ordinal_re, lambda m: _inflect.number_to_words(m.group(0)), text) + + +def _expand_number_repl_fn(m): + """The replacement function for expanding number.""" + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return 'two thousand' + elif num > 2000 and num < 2010: + return 'two thousand ' + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + ' hundred' + else: + return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') + else: + return _inflect.number_to_words(num, andword='') + + +def _expand_number(text: str) -> str: + return re.sub(_number_re, _expand_number_repl_fn, text) + + +def normalize_numbers(text: str) -> str: + text = _remove_commas(text) + text = _expand_pounds(text) + text = _expand_dollars(text) + text = _expand_decimal_point(text) + text = _expand_ordinal(text) + text = _expand_number(text) + return text diff --git a/examples/pipeline_tacotron2/text/text_preprocessing.py b/examples/pipeline_tacotron2/text/text_preprocessing.py new file mode 100644 index 
0000000000000000000000000000000000000000..8ca6ae497c392fd892a9442382b4b2b6384ac4bd --- /dev/null +++ b/examples/pipeline_tacotron2/text/text_preprocessing.py @@ -0,0 +1,164 @@ +# ***************************************************************************** +# Copyright (c) 2017 Keith Ito +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ***************************************************************************** +""" +Modified from https://github.com/keithito/tacotron +""" + +from typing import List, Union, Optional +import re + +from unidecode import unidecode +from torchaudio.datasets import CMUDict + +from .numbers import normalize_numbers + + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r'\s+') + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ + ('mrs', 'misess'), + ('mr', 'mister'), + ('dr', 'doctor'), + ('st', 'saint'), + ('co', 'company'), + ('jr', 'junior'), + ('maj', 'major'), + ('gen', 'general'), + ('drs', 'doctors'), + ('rev', 'reverend'), + ('lt', 'lieutenant'), + ('hon', 'honorable'), + ('sgt', 'sergeant'), + ('capt', 'captain'), + ('esq', 'esquire'), + ('ltd', 'limited'), + ('col', 'colonel'), + ('ft', 'fort'), +]] + +_pad = '_' +_punctuation = '!\'(),.:;? ' +_special = '-' +_letters = 'abcdefghijklmnopqrstuvwxyz' + +symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) +_phonemizer = None + + +available_symbol_set = set(["english_characters", "english_phonemes"]) +available_phonemizers = set(["DeepPhonemizer"]) + + +def get_symbol_list(symbol_list: str = "english_characters", + cmudict_root: Optional[str] = "./") -> List[str]: + if symbol_list == "english_characters": + return [_pad] + list(_special) + list(_punctuation) + list(_letters) + elif symbol_list == "english_phonemes": + return [_pad] + list(_special) + list(_punctuation) + CMUDict(cmudict_root).symbols + else: + raise ValueError(f"The `symbol_list` {symbol_list} is not supported." + f"Supported `symbol_list` includes {available_symbol_set}.") + + +def word_to_phonemes(sent: str, phonemizer: str, checkpoint: str) -> List[str]: + if phonemizer == "DeepPhonemizer": + from dp.phonemizer import Phonemizer + global _phonemizer + _other_symbols = ''.join(list(_special) + list(_punctuation)) + _phone_symbols_re = r'(\[[A-Z]+?\]|' + '[' + _other_symbols + '])' # [\[([A-Z]+?)\]|[-!'(),.:;? 
]] + + if _phonemizer is None: + # using a global variable so that we don't have to relode checkpoint + # everytime this function is called + _phonemizer = Phonemizer.from_checkpoint(checkpoint) + + # Example: + # sent = "hello world!" + # '[HH][AH][L][OW] [W][ER][L][D]!' + sent = _phonemizer(sent, lang='en_us') + + # ['[HH]', '[AH]', '[L]', '[OW]', ' ', '[W]', '[ER]', '[L]', '[D]', '!'] + ret = re.findall(_phone_symbols_re, sent) + + # ['HH', 'AH', 'L', 'OW', ' ', 'W', 'ER', 'L', 'D', '!'] + ret = [r.replace("[", "").replace("]", "") for r in ret] + + return ret + else: + raise ValueError(f"The `phonemizer` {phonemizer} is not supported. " + "Supported `symbol_list` includes `'DeepPhonemizer'`.") + + +def text_to_sequence(sent: str, + symbol_list: Union[str, List[str]] = "english_characters", + phonemizer: Optional[str] = "DeepPhonemizer", + checkpoint: Optional[str] = "./en_us_cmudict_forward.pt", + cmudict_root: Optional[str] = "./") -> List[int]: + r'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + + Args: + sent (str): The input sentence to convert to a sequence. + symbol_list (str or List of string, optional): When the input is a string, available options include + "english_characters" and "english_phonemes". When the input is a list of string, ``symbol_list`` will + directly be used as the symbol to encode. (Default: "english_characters") + phonemizer (str or None, optional): The phonemizer to use. Only used when ``symbol_list`` is "english_phonemes". + Available options include "DeepPhonemizer". (Default: "DeepPhonemizer") + checkpoint (str or None, optional): The path to the checkpoint of the phonemizer. Only used when + ``symbol_list`` is "english_phonemes". (Default: "./en_us_cmudict_forward.pt") + cmudict_root (str or None, optional): The path to the directory where the CMUDict dataset is found or + downloaded. Only used when ``symbol_list`` is "english_phonemes". (Default: "./") + + Returns: + List of integers corresponding to the symbols in the sentence. 
+ + Examples: + >>> text_to_sequence("hello world!", "english_characters") + [19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2] + >>> text_to_sequence("hello world!", "english_phonemes") + [54, 20, 65, 69, 11, 92, 44, 65, 38, 2] + ''' + if symbol_list == "english_phonemes": + if any(param is None for param in [phonemizer, checkpoint, cmudict_root]): + raise ValueError( + "When `symbol_list` is 'english_phonemes', " + "all of `phonemizer`, `checkpoint`, and `cmudict_root` must be provided.") + + sent = unidecode(sent) # convert to ascii + sent = sent.lower() # lower case + sent = normalize_numbers(sent) # expand numbers + for regex, replacement in _abbreviations: # expand abbreviations + sent = re.sub(regex, replacement, sent) + sent = re.sub(_whitespace_re, ' ', sent) # collapse whitespace + + if isinstance(symbol_list, list): + symbols = symbol_list + elif isinstance(symbol_list, str): + symbols = get_symbol_list(symbol_list, cmudict_root=cmudict_root) + if symbol_list == "english_phonemes": + sent = word_to_phonemes(sent, phonemizer=phonemizer, checkpoint=checkpoint) + + _symbol_to_id = {s: i for i, s in enumerate(symbols)} + + return [_symbol_to_id[s] for s in sent if s in _symbol_to_id] diff --git a/examples/pipeline_tacotron2/train.py b/examples/pipeline_tacotron2/train.py new file mode 100644 index 0000000000000000000000000000000000000000..4fe93000b8b4e4c9bd25a2221d46cec853abc3cc --- /dev/null +++ b/examples/pipeline_tacotron2/train.py @@ -0,0 +1,528 @@ +# ***************************************************************************** +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the +# names of its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +# ***************************************************************************** +""" +Modified from +https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/train.py +""" + +import argparse +from datetime import datetime +from functools import partial +import logging +import random +import os +from time import time + +import torch +import torchaudio +import torch.multiprocessing as mp +import torch.distributed as dist +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DataLoader +from torch.optim import Adam +from torchaudio.models import Tacotron2 +from tqdm import tqdm +import matplotlib.pyplot as plt +plt.switch_backend('agg') + +from datasets import text_mel_collate_fn, split_process_dataset, SpectralNormalization +from utils import save_checkpoint +from loss import Tacotron2Loss +from text.text_preprocessing import ( + available_symbol_set, + available_phonemizers, + get_symbol_list, + text_to_sequence, +) + + +logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s', + level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S') +logger = logging.getLogger(os.path.basename(__file__)) + + +def parse_args(parser): + """Parse commandline arguments.""" + + parser.add_argument("--dataset", default="ljspeech", choices=["ljspeech"], type=str, + help="select dataset to train with") + parser.add_argument('--logging-dir', type=str, default=None, + help='directory to save the log files') + parser.add_argument('--dataset-path', type=str, default='./', + help='path to dataset') + parser.add_argument("--val-ratio", default=0.1, type=float, + help="the ratio of waveforms for validation") + + parser.add_argument('--anneal-steps', nargs='*', + help='epochs after which decrease learning rate') + parser.add_argument('--anneal-factor', type=float, choices=[0.1, 0.3], default=0.1, + help='factor for annealing learning rate') + + parser.add_argument('--master-addr', default=None, type=str, + help='the address to use for distributed training') + parser.add_argument('--master-port', default=None, type=str, + help='the port to use for distributed training') + + preprocessor = parser.add_argument_group('text preprocessor setup') + preprocessor.add_argument('--text-preprocessor', default='english_characters', type=str, + choices=available_symbol_set, + help='select text preprocessor to use.') + preprocessor.add_argument('--phonemizer', type=str, choices=available_phonemizers, + help='select phonemizer to use, only used when text-preprocessor is "english_phonemes"') + preprocessor.add_argument('--phonemizer-checkpoint', type=str, + help='the path or name of the checkpoint for the phonemizer, ' + 'only used when text-preprocessor is "english_phonemes"') + preprocessor.add_argument('--cmudict-root', default="./", type=str, + help='the root directory for storing cmudictionary files') + + # training + training = parser.add_argument_group('training setup') + training.add_argument('--epochs', type=int, required=True, + help='number of total epochs to run') + training.add_argument('--checkpoint-path', type=str, default='', + help='checkpoint path. 
If a file exists, ' + 'the program will load it and resume training.') + training.add_argument('--workers', default=8, type=int, + help="number of data loading workers") + training.add_argument("--validate-and-checkpoint-freq", default=10, type=int, metavar="N", + help="validation and saving checkpoint frequency in epochs",) + training.add_argument("--logging-freq", default=10, type=int, metavar="N", + help="logging frequency in epochs") + + optimization = parser.add_argument_group('optimization setup') + optimization.add_argument('--learning-rate', default=1e-3, type=float, + help='initial learing rate') + optimization.add_argument('--weight-decay', default=1e-6, type=float, + help='weight decay') + optimization.add_argument('--batch-size', default=32, type=int, + help='batch size per GPU') + optimization.add_argument('--grad-clip', default=5.0, type=float, + help='clipping gradient with maximum gradient norm value') + + # model parameters + model = parser.add_argument_group('model parameters') + model.add_argument('--mask-padding', action='store_true', default=False, + help='use mask padding') + model.add_argument('--symbols-embedding-dim', default=512, type=int, + help='input embedding dimension') + + # encoder + model.add_argument('--encoder-embedding-dim', default=512, type=int, + help='encoder embedding dimension') + model.add_argument('--encoder-n-convolution', default=3, type=int, + help='number of encoder convolutions') + model.add_argument('--encoder-kernel-size', default=5, type=int, + help='encoder kernel size') + # decoder + model.add_argument('--n-frames-per-step', default=1, type=int, + help='number of frames processed per step (currently only 1 is supported)') + model.add_argument('--decoder-rnn-dim', default=1024, type=int, + help='number of units in decoder LSTM') + model.add_argument('--decoder-dropout', default=0.1, type=float, + help='dropout probability for decoder LSTM') + model.add_argument('--decoder-max-step', default=2000, type=int, + help='maximum number of output mel spectrograms') + model.add_argument('--decoder-no-early-stopping', action='store_true', default=False, + help='stop decoding only when all samples are finished') + + # attention model + model.add_argument('--attention-hidden-dim', default=128, type=int, + help='dimension of attention hidden representation') + model.add_argument('--attention-rnn-dim', default=1024, type=int, + help='number of units in attention LSTM') + model.add_argument('--attention-location-n-filter', default=32, type=int, + help='number of filters for location-sensitive attention') + model.add_argument('--attention-location-kernel-size', default=31, type=int, + help='kernel size for location-sensitive attention') + model.add_argument('--attention-dropout', default=0.1, type=float, + help='dropout probability for attention LSTM') + + model.add_argument('--prenet-dim', default=256, type=int, + help='number of ReLU units in prenet layers') + + # mel-post processing network parameters + model.add_argument('--postnet-n-convolution', default=5, type=float, + help='number of postnet convolutions') + model.add_argument('--postnet-kernel-size', default=5, type=float, + help='postnet kernel size') + model.add_argument('--postnet-embedding-dim', default=512, type=float, + help='postnet embedding dimension') + + model.add_argument('--gate-threshold', default=0.5, type=float, + help='probability threshold for stop token') + + # audio parameters + audio = parser.add_argument_group('audio parameters') + audio.add_argument('--sample-rate', 
default=22050, type=int, + help='Sampling rate') + audio.add_argument('--n-fft', default=1024, type=int, + help='Filter length for STFT') + audio.add_argument('--hop-length', default=256, type=int, + help='Hop (stride) length') + audio.add_argument('--win-length', default=1024, type=int, + help='Window length') + audio.add_argument('--n-mels', default=80, type=int, + help='') + audio.add_argument('--mel-fmin', default=0.0, type=float, + help='Minimum mel frequency') + audio.add_argument('--mel-fmax', default=8000.0, type=float, + help='Maximum mel frequency') + + return parser + + +def adjust_learning_rate(epoch, optimizer, learning_rate, + anneal_steps, anneal_factor): + """Adjust learning rate base on the initial setting.""" + p = 0 + if anneal_steps is not None: + for _, a_step in enumerate(anneal_steps): + if epoch >= int(a_step): + p = p + 1 + + if anneal_factor == 0.3: + lr = learning_rate * ((0.1 ** (p // 2)) * (1.0 if p % 2 == 0 else 0.3)) + else: + lr = learning_rate * (anneal_factor ** p) + + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +def to_gpu(x): + x = x.contiguous() + if torch.cuda.is_available(): + x = x.cuda(non_blocking=True) + return x + + +def batch_to_gpu(batch): + text_padded, text_lengths, mel_specgram_padded, mel_specgram_lengths, gate_padded = batch + text_padded = to_gpu(text_padded).long() + text_lengths = to_gpu(text_lengths).long() + mel_specgram_padded = to_gpu(mel_specgram_padded).float() + gate_padded = to_gpu(gate_padded).float() + mel_specgram_lengths = to_gpu(mel_specgram_lengths).long() + x = (text_padded, text_lengths, mel_specgram_padded, mel_specgram_lengths) + y = (mel_specgram_padded, gate_padded) + return x, y + + +def training_step(model, train_batch, batch_idx): + (text_padded, text_lengths, mel_specgram_padded, mel_specgram_lengths), y = batch_to_gpu(train_batch) + y_pred = model(text_padded, text_lengths, mel_specgram_padded, mel_specgram_lengths) + y[0].requires_grad = False + y[1].requires_grad = False + losses = Tacotron2Loss()(y_pred[:3], y) + return losses[0] + losses[1] + losses[2], losses + + +def validation_step(model, val_batch, batch_idx): + (text_padded, text_lengths, mel_specgram_padded, mel_specgram_lengths), y = batch_to_gpu(val_batch) + y_pred = model(text_padded, text_lengths, mel_specgram_padded, mel_specgram_lengths) + losses = Tacotron2Loss()(y_pred[:3], y) + return losses[0] + losses[1] + losses[2], losses + + +def reduce_tensor(tensor, world_size): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + if rt.is_floating_point(): + rt = rt / world_size + else: + rt = rt // world_size + return rt + + +def log_additional_info(writer, model, loader, epoch): + model.eval() + data = next(iter(loader)) + with torch.no_grad(): + (text_padded, text_lengths, mel_specgram_padded, mel_specgram_lengths), _ = batch_to_gpu(data) + y_pred = model(text_padded, text_lengths, mel_specgram_padded, mel_specgram_lengths) + mel_out, mel_out_postnet, gate_out, alignment = y_pred + + fig = plt.figure() + ax = plt.gca() + ax.imshow(mel_out[0].cpu().numpy()) + writer.add_figure("trn/mel_out", fig, epoch) + fig = plt.figure() + ax = plt.gca() + ax.imshow(mel_out_postnet[0].cpu().numpy()) + writer.add_figure("trn/mel_out_postnet", fig, epoch) + writer.add_image("trn/gate_out", torch.tile(gate_out[:1], (10, 1)), epoch, dataformats="HW") + writer.add_image("trn/alignment", alignment[0], epoch, dataformats="HW") + + +def get_datasets(args): + text_preprocessor = partial( + text_to_sequence, + 
symbol_list=args.text_preprocessor, + phonemizer=args.phonemizer, + checkpoint=args.phonemizer_checkpoint, + cmudict_root=args.cmudict_root, + ) + + transforms = torch.nn.Sequential( + torchaudio.transforms.MelSpectrogram( + sample_rate=args.sample_rate, + n_fft=args.n_fft, + win_length=args.win_length, + hop_length=args.hop_length, + f_min=args.mel_fmin, + f_max=args.mel_fmax, + n_mels=args.n_mels, + mel_scale='slaney', + normalized=False, + power=1, + norm='slaney', + ), + SpectralNormalization() + ) + trainset, valset = split_process_dataset( + args.dataset, args.dataset_path, args.val_ratio, transforms, text_preprocessor) + return trainset, valset + + +def train(rank, world_size, args): + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + if rank == 0 and args.logging_dir: + if not os.path.isdir(args.logging_dir): + os.makedirs(args.logging_dir) + filehandler = logging.FileHandler(os.path.join(args.logging_dir, 'train.log')) + filehandler.setLevel(logging.INFO) + logger.addHandler(filehandler) + + writer = SummaryWriter(log_dir=args.logging_dir) + else: + writer = None + + torch.manual_seed(0) + + torch.cuda.set_device(rank) + + symbols = get_symbol_list(args.text_preprocessor) + + model = Tacotron2( + mask_padding=args.mask_padding, + n_mels=args.n_mels, + n_symbol=len(symbols), + n_frames_per_step=args.n_frames_per_step, + symbol_embedding_dim=args.symbols_embedding_dim, + encoder_embedding_dim=args.encoder_embedding_dim, + encoder_n_convolution=args.encoder_n_convolution, + encoder_kernel_size=args.encoder_kernel_size, + decoder_rnn_dim=args.decoder_rnn_dim, + decoder_max_step=args.decoder_max_step, + decoder_dropout=args.decoder_dropout, + decoder_early_stopping=(not args.decoder_no_early_stopping), + attention_rnn_dim=args.attention_rnn_dim, + attention_hidden_dim=args.attention_hidden_dim, + attention_location_n_filter=args.attention_location_n_filter, + attention_location_kernel_size=args.attention_location_kernel_size, + attention_dropout=args.attention_dropout, + prenet_dim=args.prenet_dim, + postnet_n_convolution=args.postnet_n_convolution, + postnet_kernel_size=args.postnet_kernel_size, + postnet_embedding_dim=args.postnet_embedding_dim, + gate_threshold=args.gate_threshold, + ).cuda(rank) + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank]) + + optimizer = Adam(model.parameters(), lr=args.learning_rate) + + best_loss = float("inf") + start_epoch = 0 + + if args.checkpoint_path and os.path.isfile(args.checkpoint_path): + logger.info(f"Checkpoint: loading '{args.checkpoint_path}'") + map_location = {'cuda:%d' % 0: 'cuda:%d' % rank} + checkpoint = torch.load(args.checkpoint_path, map_location=map_location) + + start_epoch = checkpoint["epoch"] + best_loss = checkpoint["best_loss"] + + model.load_state_dict(checkpoint["state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + + logger.info( + f"Checkpoint: loaded '{args.checkpoint_path}' at epoch {checkpoint['epoch']}" + ) + + trainset, valset = get_datasets(args) + + train_sampler = torch.utils.data.distributed.DistributedSampler( + trainset, + shuffle=True, + num_replicas=world_size, + rank=rank, + ) + val_sampler = torch.utils.data.distributed.DistributedSampler( + valset, + shuffle=False, + num_replicas=world_size, + rank=rank, + ) + + loader_params = { + "batch_size": args.batch_size, + "num_workers": args.workers, + "prefetch_factor": 1024, + 'persistent_workers': True, + "shuffle": False, + "pin_memory": 
True, + "drop_last": False, + "collate_fn": partial(text_mel_collate_fn, n_frames_per_step=args.n_frames_per_step), + } + + train_loader = DataLoader(trainset, sampler=train_sampler, **loader_params) + val_loader = DataLoader(valset, sampler=val_sampler, **loader_params) + dist.barrier() + + for epoch in range(start_epoch, args.epochs): + start = time() + + model.train() + trn_loss, counts = 0, 0 + + if rank == 0: + iterator = tqdm(enumerate(train_loader), desc=f"Epoch {epoch}", total=len(train_loader)) + else: + iterator = enumerate(train_loader) + + for i, batch in iterator: + adjust_learning_rate(epoch, optimizer, args.learning_rate, + args.anneal_steps, args.anneal_factor) + + model.zero_grad() + + loss, losses = training_step(model, batch, i) + + loss.backward() + torch.nn.utils.clip_grad_norm_( + model.parameters(), args.grad_clip) + + optimizer.step() + + if rank == 0 and writer: + global_iters = epoch * len(train_loader) + writer.add_scalar("trn/mel_loss", losses[0], global_iters) + writer.add_scalar("trn/mel_postnet_loss", losses[1], global_iters) + writer.add_scalar("trn/gate_loss", losses[2], global_iters) + + trn_loss += loss * len(batch[0]) + counts += len(batch[0]) + + trn_loss = trn_loss / counts + + trn_loss = reduce_tensor(trn_loss, world_size) + if rank == 0: + logger.info(f"[Epoch: {epoch}] time: {time()-start}; trn_loss: {trn_loss}") + if writer: + writer.add_scalar("trn_loss", trn_loss, epoch) + + if ((epoch + 1) % args.validate_and_checkpoint_freq == 0) or (epoch == args.epochs - 1): + + val_start_time = time() + model.eval() + + val_loss, counts = 0, 0 + iterator = tqdm(enumerate(val_loader), desc=f"[Rank: {rank}; Epoch: {epoch}; Eval]", total=len(val_loader)) + + with torch.no_grad(): + for val_batch_idx, val_batch in iterator: + val_loss = val_loss + validation_step(model, val_batch, val_batch_idx)[0] * len(val_batch[0]) + counts = counts + len(val_batch[0]) + val_loss = val_loss / counts + + val_loss = reduce_tensor(val_loss, world_size) + if rank == 0 and writer: + writer.add_scalar("val_loss", val_loss, epoch) + log_additional_info(writer, model, val_loader, epoch) + + if rank == 0: + is_best = val_loss < best_loss + best_loss = min(val_loss, best_loss) + logger.info(f"[Rank: {rank}, Epoch: {epoch}; Eval] time: {time()-val_start_time}; val_loss: {val_loss}") + logger.info(f"[Epoch: {epoch}] Saving checkpoint to {args.checkpoint_path}") + save_checkpoint( + { + "epoch": epoch + 1, + "state_dict": model.state_dict(), + "best_loss": best_loss, + "optimizer": optimizer.state_dict(), + }, + is_best, + args.checkpoint_path, + ) + + dist.destroy_process_group() + + +def main(args): + logger.info("Start time: {}".format(str(datetime.now()))) + + torch.manual_seed(0) + random.seed(0) + + if args.master_addr is not None: + os.environ['MASTER_ADDR'] = args.master_addr + elif 'MASTER_ADDR' not in os.environ: + os.environ['MASTER_ADDR'] = 'localhost' + + if args.master_port is not None: + os.environ['MASTER_PORT'] = args.master_port + elif 'MASTER_PORT' not in os.environ: + os.environ['MASTER_PORT'] = '17778' + + device_counts = torch.cuda.device_count() + + logger.info(f"# available GPUs: {device_counts}") + + # download dataset is not already downloaded + if args.dataset == 'ljspeech': + if not os.path.exists(os.path.join(args.dataset_path, 'LJSpeech-1.1')): + from torchaudio.datasets import LJSPEECH + LJSPEECH(root=args.dataset_path, download=True) + + if device_counts == 1: + train(0, 1, args) + else: + mp.spawn(train, args=(device_counts, args, ), + 
nprocs=device_counts, join=True) + + logger.info(f"End time: {datetime.now()}") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='PyTorch Tacotron 2 Training') + parser = parse_args(parser) + args, _ = parser.parse_known_args() + + main(args) diff --git a/examples/pipeline_tacotron2/utils.py b/examples/pipeline_tacotron2/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dca668a17ad297d1430e18c0689d64ed94f983a4 --- /dev/null +++ b/examples/pipeline_tacotron2/utils.py @@ -0,0 +1,76 @@ +# ***************************************************************************** +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the +# names of its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# ***************************************************************************** + +import logging +import os +import shutil +from typing import List, Tuple, Callable + +import torch +from torch import Tensor + + +def save_checkpoint(state, is_best, filename): + r"""Save the model to a temporary file first, then copy it to filename, + in case signals interrupt the torch.save() process. + """ + torch.save(state, filename) + logging.info(f"Checkpoint saved to {filename}") + + if is_best: + path, best_filename = os.path.split(filename) + best_filename = os.path.join(path, "best_" + best_filename) + shutil.copyfile(filename, best_filename) + logging.info(f"Current best checkpoint saved to {best_filename}") + + +def pad_sequences(batch: List[Tensor]) -> Tuple[Tensor, Tensor]: + r"""Right zero-pad all one-hot text sequences to max input length. + + Modified from https://github.com/NVIDIA/DeepLearningExamples. 
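+
+    Args:
+        batch (List[Tensor]): encoded text sequences (1-D LongTensors of symbol ids).
+
+    Returns:
+        (Tensor, Tensor): the zero-padded ``(batch, max_input_len)`` text tensor and
+            the corresponding input lengths, both sorted by decreasing length.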
+ """ + input_lengths, ids_sorted_decreasing = torch.sort( + torch.LongTensor([len(x) for x in batch]), dim=0, descending=True) + max_input_len = input_lengths[0] + + text_padded = torch.LongTensor(len(batch), max_input_len) + text_padded.zero_() + for i in range(len(ids_sorted_decreasing)): + text = batch[ids_sorted_decreasing[i]] + text_padded[i, :text.size(0)] = text + + return text_padded, input_lengths + + +def prepare_input_sequence(texts: List[str], + text_processor: Callable[[str], List[int]]) -> Tuple[Tensor, Tensor]: + d = [] + for text in texts: + d.append(torch.IntTensor(text_processor(text)[:])) + + text_padded, input_lengths = pad_sequences(d) + return text_padded, input_lengths diff --git a/examples/pipeline_wav2letter/README.md b/examples/pipeline_wav2letter/README.md new file mode 100644 index 0000000000000000000000000000000000000000..afecf5c204b4821d68bfc733f823d16b9a5842a4 --- /dev/null +++ b/examples/pipeline_wav2letter/README.md @@ -0,0 +1,50 @@ +This is an example pipeline for speech recognition using a greedy or Viterbi CTC decoder, along with the Wav2Letter model trained on LibriSpeech, see [Wav2Letter: an End-to-End ConvNet-based Speech Recognition System](https://arxiv.org/pdf/1609.03193.pdf). Wav2Letter and LibriSpeech are available in torchaudio. + +### Usage + +More information about each command line parameters is available with the `--help` option. An example can be invoked as follows. +```bash +DATASET_ROOT = // +DATASET_FOLDER_IN_ARCHIVE = 'LibriSpeech' + +python main.py \ + --reduce-lr-valid \ + --dataset-root "${DATASET_ROOT}" \ + --dataset-folder-in-archive "${DATASET_FOLDER_IN_ARCHIVE}" \ + --dataset-train train-clean-100 train-clean-360 train-other-500 \ + --dataset-valid dev-clean \ + --batch-size 128 \ + --learning-rate .6 \ + --momentum .8 \ + --weight-decay .00001 \ + --clip-grad 0. \ + --gamma .99 \ + --hop-length 160 \ + --win-length 400 \ + --n-bins 13 \ + --normalize \ + --optimizer adadelta \ + --scheduler reduceonplateau \ + --epochs 40 +``` + +With these default parameters, we get 13.3 %CER and 41.9 %WER on dev-clean after 40 epochs (character and word error rates, respectively) while training on train-clean. The tail of the output is the following. + +```json +... 
+{"name": "train", "epoch": 40, "batch char error": 925, "batch char total": 22563, "batch char error rate": 0.040996321411159865, "epoch char error": 1135098.0, "epoch char total": 23857713.0, "epoch char error rate": 0.047577821059378154, "batch word error": 791, "batch word total": 4308, "batch word error rate": 0.18361188486536675, "epoch word error": 942906.0, "epoch word total": 4569507.0, "epoch word error rate": 0.20634742435015418, "lr": 0.06, "batch size": 128, "n_channel": 13, "n_time": 1685, "dataset length": 132096.0, "iteration": 1032.0, "loss": 0.07428030669689178, "cumulative loss": 90.47326805442572, "average loss": 0.08766789540157531, "iteration time": 1.9895553588867188, "epoch time": 2036.8874564170837} +{"name": "train", "epoch": 40, "batch char error": 1131, "batch char total": 24260, "batch char error rate": 0.0466199505358615, "epoch char error": 1136229.0, "epoch char total": 23881973.0, "epoch char error rate": 0.04757684802675223, "batch word error": 957, "batch word total": 4657, "batch word error rate": 0.2054971011380717, "epoch word error": 943863.0, "epoch word total": 4574164.0, "epoch word error rate": 0.20634655862798099, "lr": 0.06, "batch size": 128, "n_channel": 13, "n_time": 1641, "dataset length": 132224.0, "iteration": 1033.0, "loss": 0.08775319904088974, "cumulative loss": 90.5610212534666, "average loss": 0.08766797798012256, "iteration time": 2.108018159866333, "epoch time": 2038.99547457695} +{"name": "train", "epoch": 40, "batch char error": 1099, "batch char total": 23526, "batch char error rate": 0.0467142735696676, "epoch char error": 1137328.0, "epoch char total": 23905499.0, "epoch char error rate": 0.04757599914563591, "batch word error": 936, "batch word total": 4544, "batch word error rate": 0.20598591549295775, "epoch word error": 944799.0, "epoch word total": 4578708.0, "epoch word error rate": 0.20634620071863066, "lr": 0.06, "batch size": 128, "n_channel": 13, "n_time": 1682, "dataset length": 132352.0, "iteration": 1034.0, "loss": 0.0791337713599205, "cumulative loss": 90.64015502482653, "average loss": 0.08765972439538348, "iteration time": 2.0329701900482178, "epoch time": 2041.0284447669983} +{"name": "train", "epoch": 40, "batch char error": 1023, "batch char total": 22399, "batch char error rate": 0.045671681771507655, "epoch char error": 1138351.0, "epoch char total": 23927898.0, "epoch char error rate": 0.04757421650660664, "batch word error": 863, "batch word total": 4318, "batch word error rate": 0.1998610467809171, "epoch word error": 945662.0, "epoch word total": 4583026.0, "epoch word error rate": 0.20634009058643787, "lr": 0.06, "batch size": 128, "n_channel": 13, "n_time": 1644, "dataset length": 132480.0, "iteration": 1035.0, "loss": 0.07874362915754318, "cumulative loss": 90.71889865398407, "average loss": 0.08765110981061262, "iteration time": 1.9106628894805908, "epoch time": 2042.9391076564789} +{"name": "validation", "epoch": 40, "cumulative loss": 12.095281183719635, "dataset length": 2688.0, "iteration": 21.0, "batch char error": 1867, "batch char total": 14792, "batch char error rate": 0.12621687398593834, "epoch char error": 37119.0, "epoch char total": 280923.0, "epoch char error rate": 0.13213229247872194, "batch word error": 1155, "batch word total": 2841, "batch word error rate": 0.4065469904963041, "epoch word error": 22601.0, "epoch word total": 54008.0, "epoch word error rate": 0.418475040734706, "average loss": 0.575965770653316, "validation time": 24.185853481292725} +``` +As can be seen in the 
output above, the information reported at each iteration and epoch (e.g. loss, character error rate, word error rate) is printed to standard output in the form of one json per line. One way to import the output in python with pandas is by saving the standard output to a file, and then using `pandas.read_json(filename, lines=True)`. + +## Structure of pipeline + +* `main.py` -- the entry point +* `ctc_decoders.py` -- the greedy CTC decoder +* `datasets.py` -- the function to split and process librispeech, a collate factory function +* `languagemodels.py` -- a class to encode and decode strings +* `metrics.py` -- the levenshtein edit distance +* `utils.py` -- functions to log metrics, save checkpoint, and count parameters diff --git a/examples/pipeline_wav2letter/ctc_decoders.py b/examples/pipeline_wav2letter/ctc_decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..b4f155d6faf57f9dff34429f9508dff4959c00a5 --- /dev/null +++ b/examples/pipeline_wav2letter/ctc_decoders.py @@ -0,0 +1,15 @@ +from torch import topk + + +class GreedyDecoder: + def __call__(self, outputs): + """Greedy Decoder. Returns highest probability of class labels for each timestep + + Args: + outputs (torch.Tensor): shape (input length, batch size, number of classes (including blank)) + + Returns: + torch.Tensor: class labels per time step. + """ + _, indices = topk(outputs, k=1, dim=-1) + return indices[..., 0] diff --git a/examples/pipeline_wav2letter/datasets.py b/examples/pipeline_wav2letter/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..79b05b2c5bbb40ffc07ba5639228ccb5aad5ad5c --- /dev/null +++ b/examples/pipeline_wav2letter/datasets.py @@ -0,0 +1,113 @@ +import torch +from torchaudio.datasets import LIBRISPEECH + + +class MapMemoryCache(torch.utils.data.Dataset): + """ + Wrap a dataset so that, whenever a new item is returned, it is saved to memory. 
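+    Subsequent accesses to the same index are served from the in-memory cache
+    instead of recomputing the item from the wrapped dataset.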
+ """ + + def __init__(self, dataset): + self.dataset = dataset + self._cache = [None] * len(dataset) + + def __getitem__(self, n): + if self._cache[n] is not None: + return self._cache[n] + + item = self.dataset[n] + self._cache[n] = item + + return item + + def __len__(self): + return len(self.dataset) + + +class Processed(torch.utils.data.Dataset): + def __init__(self, dataset, transforms, encode): + self.dataset = dataset + self.transforms = transforms + self.encode = encode + + def __getitem__(self, key): + item = self.dataset[key] + return self.process_datapoint(item) + + def __len__(self): + return len(self.dataset) + + def process_datapoint(self, item): + transformed = item[0] + target = item[2].lower() + + transformed = self.transforms(transformed) + transformed = transformed[0, ...].transpose(0, -1) + + target = self.encode(target) + target = torch.tensor(target, dtype=torch.long, device=transformed.device) + + return transformed, target + + +def split_process_librispeech( + datasets, transforms, language_model, root, folder_in_archive, +): + def create(tags, cache=True): + + if isinstance(tags, str): + tags = [tags] + if isinstance(transforms, list): + transform_list = transforms + else: + transform_list = [transforms] + + data = torch.utils.data.ConcatDataset( + [ + Processed( + LIBRISPEECH( + root, tag, folder_in_archive=folder_in_archive, download=False, + ), + transform, + language_model.encode, + ) + for tag, transform in zip(tags, transform_list) + ] + ) + + data = MapMemoryCache(data) + return data + + # For performance, we cache all datasets + return tuple(create(dataset) for dataset in datasets) + + +def collate_factory(model_length_function, transforms=None): + + if transforms is None: + transforms = torch.nn.Sequential() + + def collate_fn(batch): + + tensors = [transforms(b[0]) for b in batch if b] + + tensors_lengths = torch.tensor( + [model_length_function(t) for t in tensors], + dtype=torch.long, + device=tensors[0].device, + ) + + tensors = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True) + tensors = tensors.transpose(1, -1) + + targets = [b[1] for b in batch if b] + target_lengths = torch.tensor( + [target.shape[0] for target in targets], + dtype=torch.long, + device=tensors.device, + ) + targets = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True) + + return tensors, targets, tensors_lengths, target_lengths + + return collate_fn diff --git a/examples/pipeline_wav2letter/languagemodels.py b/examples/pipeline_wav2letter/languagemodels.py new file mode 100644 index 0000000000000000000000000000000000000000..d66858ea9546675eef68f37ba0ccf73124852e9f --- /dev/null +++ b/examples/pipeline_wav2letter/languagemodels.py @@ -0,0 +1,38 @@ +import collections +import itertools + + +class LanguageModel: + def __init__(self, labels, char_blank, char_space): + + self.char_space = char_space + self.char_blank = char_blank + + labels = list(labels) + self.length = len(labels) + enumerated = list(enumerate(labels)) + flipped = [(sub[1], sub[0]) for sub in enumerated] + + d1 = collections.OrderedDict(enumerated) + d2 = collections.OrderedDict(flipped) + self.mapping = {**d1, **d2} + + def encode(self, iterable): + if isinstance(iterable, list): + return [self.encode(i) for i in iterable] + else: + return [self.mapping[i] + self.mapping[self.char_blank] for i in iterable] + + def decode(self, tensor): + if len(tensor) > 0 and isinstance(tensor[0], list): + return [self.decode(t) for t in tensor] + else: + # not idempotent, since clean string + x = (self.mapping[i] 
for i in tensor) + x = "".join(i for i, _ in itertools.groupby(x)) + x = x.replace(self.char_blank, "") + # x = x.strip() + return x + + def __len__(self): + return self.length diff --git a/examples/pipeline_wav2letter/main.py b/examples/pipeline_wav2letter/main.py new file mode 100644 index 0000000000000000000000000000000000000000..1668223067f408c5d1041aa92ff64582472a7019 --- /dev/null +++ b/examples/pipeline_wav2letter/main.py @@ -0,0 +1,663 @@ +import argparse +import logging +import os +import string +from datetime import datetime +from time import time + +import torch +import torchaudio +from torch.optim import SGD, Adadelta, Adam, AdamW +from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau +from torch.utils.data import DataLoader +from torchaudio.datasets.utils import bg_iterator +from torchaudio.functional import edit_distance +from torchaudio.models.wav2letter import Wav2Letter + +from ctc_decoders import GreedyDecoder +from datasets import collate_factory, split_process_librispeech +from languagemodels import LanguageModel +from transforms import Normalize, UnsqueezeFirst +from utils import MetricLogger, count_parameters, save_checkpoint + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--type", + metavar="T", + default="mfcc", + choices=["waveform", "mfcc"], + help="input type for model", + ) + parser.add_argument( + "--freq-mask", + default=0, + type=int, + metavar="N", + help="maximal width of frequency mask", + ) + parser.add_argument( + "--win-length", + default=400, + type=int, + metavar="N", + help="width of spectrogram window", + ) + parser.add_argument( + "--hop-length", + default=160, + type=int, + metavar="N", + help="width of spectrogram window", + ) + parser.add_argument( + "--time-mask", + default=0, + type=int, + metavar="N", + help="maximal width of time mask", + ) + parser.add_argument( + "--workers", + default=0, + type=int, + metavar="N", + help="number of data loading workers", + ) + parser.add_argument( + "--checkpoint", + default="", + type=str, + metavar="PATH", + help="path to latest checkpoint", + ) + parser.add_argument( + "--epochs", + default=200, + type=int, + metavar="N", + help="number of total epochs to run", + ) + parser.add_argument( + "--start-epoch", default=0, type=int, metavar="N", help="manual epoch number" + ) + parser.add_argument( + "--reduce-lr-valid", + action="store_true", + help="reduce learning rate based on validation loss", + ) + parser.add_argument( + "--normalize", action="store_true", help="normalize model input" + ) + parser.add_argument( + "--progress-bar", action="store_true", help="use progress bar while training" + ) + parser.add_argument( + "--decoder", + metavar="D", + default="greedy", + choices=["greedy"], + help="decoder to use", + ) + parser.add_argument( + "--batch-size", default=128, type=int, metavar="N", help="mini-batch size" + ) + parser.add_argument( + "--n-bins", + default=13, + type=int, + metavar="N", + help="number of bins in transforms", + ) + parser.add_argument( + "--optimizer", + metavar="OPT", + default="adadelta", + choices=["sgd", "adadelta", "adam", "adamw"], + help="optimizer to use", + ) + parser.add_argument( + "--scheduler", + metavar="S", + default="reduceonplateau", + choices=["exponential", "reduceonplateau"], + help="optimizer to use", + ) + parser.add_argument( + "--learning-rate", + default=0.6, + type=float, + metavar="LR", + help="initial learning rate", + ) + parser.add_argument( + "--gamma", + default=0.99, + type=float, + 
metavar="GAMMA", + help="learning rate exponential decay constant", + ) + parser.add_argument( + "--momentum", default=0.8, type=float, metavar="M", help="momentum" + ) + parser.add_argument( + "--weight-decay", default=1e-5, type=float, metavar="W", help="weight decay" + ) + parser.add_argument("--eps", metavar="EPS", type=float, default=1e-8) + parser.add_argument("--rho", metavar="RHO", type=float, default=0.95) + parser.add_argument("--clip-grad", metavar="NORM", type=float, default=0.0) + parser.add_argument( + "--dataset-root", + type=str, + help="specify dataset root folder", + ) + parser.add_argument( + "--dataset-folder-in-archive", + type=str, + help="specify dataset folder in archive", + ) + parser.add_argument( + "--dataset-train", + default=["train-clean-100"], + nargs="+", + type=str, + help="select which part of librispeech to train with", + ) + parser.add_argument( + "--dataset-valid", + default=["dev-clean"], + nargs="+", + type=str, + help="select which part of librispeech to validate with", + ) + parser.add_argument( + "--distributed", action="store_true", help="enable DistributedDataParallel" + ) + parser.add_argument("--seed", type=int, default=0, help="random seed") + parser.add_argument( + "--world-size", type=int, default=8, help="the world size to initiate DPP" + ) + parser.add_argument("--jit", action="store_true", help="if used, model is jitted") + + args = parser.parse_args() + logging.info(args) + return args + + +def setup_distributed(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + # initialize the process group + torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size) + + +def model_length_function(tensor): + if tensor.shape[1] == 1: + # waveform mode + return int(tensor.shape[0]) // 160 // 2 + 1 + return int(tensor.shape[0]) // 2 + 1 + + +def compute_error_rates(outputs, targets, decoder, language_model, metric): + output = outputs.transpose(0, 1).to("cpu") + output = decoder(output) + + # Compute CER + + output = language_model.decode(output.tolist()) + target = language_model.decode(targets.tolist()) + + print_length = 20 + for i in range(2): + # Print a few examples + output_print = output[i].ljust(print_length)[:print_length] + target_print = target[i].ljust(print_length)[:print_length] + logging.info("Target: %s Output: %s", target_print, output_print) + + cers = [edit_distance(t, o) for t, o in zip(target, output)] + cers = sum(cers) + n = sum(len(t) for t in target) + metric["batch char error"] = cers + metric["batch char total"] = n + metric["batch char error rate"] = cers / n + metric["epoch char error"] += cers + metric["epoch char total"] += n + metric["epoch char error rate"] = metric["epoch char error"] / metric["epoch char total"] + + # Compute WER + + output = [o.split(language_model.char_space) for o in output] + target = [t.split(language_model.char_space) for t in target] + + wers = [edit_distance(t, o) for t, o in zip(target, output)] + wers = sum(wers) + n = sum(len(t) for t in target) + metric["batch word error"] = wers + metric["batch word total"] = n + metric["batch word error rate"] = wers / n + metric["epoch word error"] += wers + metric["epoch word total"] += n + metric["epoch word error rate"] = metric["epoch word error"] / metric["epoch word total"] + + +def train_one_epoch( + model, + criterion, + optimizer, + scheduler, + data_loader, + decoder, + language_model, + device, + epoch, + clip_grad, + disable_logger=False, + reduce_lr_on_plateau=False, +): 
+ + model.train() + + metric = MetricLogger("train", disable=disable_logger) + metric["epoch"] = epoch + + for inputs, targets, tensors_lengths, target_lengths in bg_iterator( + data_loader, maxsize=2 + ): + + start = time() + inputs = inputs.to(device, non_blocking=True) + targets = targets.to(device, non_blocking=True) + + # keep batch first for data parallel + outputs = model(inputs).transpose(-1, -2).transpose(0, 1) + + # CTC + # outputs: input length, batch size, number of classes (including blank) + # targets: batch size, max target length + # input_lengths: batch size + # target_lengths: batch size + + loss = criterion(outputs, targets, tensors_lengths, target_lengths) + + optimizer.zero_grad() + loss.backward() + + if clip_grad > 0: + metric["gradient"] = torch.nn.utils.clip_grad_norm_( + model.parameters(), clip_grad + ) + + optimizer.step() + + compute_error_rates(outputs, targets, decoder, language_model, metric) + + try: + metric["lr"] = scheduler.get_last_lr()[0] + except AttributeError: + metric["lr"] = optimizer.param_groups[0]["lr"] + + metric["batch size"] = len(inputs) + metric["n_channel"] = inputs.shape[1] + metric["n_time"] = inputs.shape[-1] + metric["dataset length"] += metric["batch size"] + metric["iteration"] += 1 + metric["loss"] = loss.item() + metric["cumulative loss"] += metric["loss"] + metric["average loss"] = metric["cumulative loss"] / metric["iteration"] + metric["iteration time"] = time() - start + metric["epoch time"] += metric["iteration time"] + metric() + + if reduce_lr_on_plateau and isinstance(scheduler, ReduceLROnPlateau): + scheduler.step(metric["average loss"]) + elif not isinstance(scheduler, ReduceLROnPlateau): + scheduler.step() + + +def evaluate( + model, + criterion, + data_loader, + decoder, + language_model, + device, + epoch, + disable_logger=False, +): + + with torch.no_grad(): + + model.eval() + start = time() + metric = MetricLogger("validation", disable=disable_logger) + metric["epoch"] = epoch + + for inputs, targets, tensors_lengths, target_lengths in bg_iterator( + data_loader, maxsize=2 + ): + + inputs = inputs.to(device, non_blocking=True) + targets = targets.to(device, non_blocking=True) + + # keep batch first for data parallel + outputs = model(inputs).transpose(-1, -2).transpose(0, 1) + + # CTC + # outputs: input length, batch size, number of classes (including blank) + # targets: batch size, max target length + # input_lengths: batch size + # target_lengths: batch size + + metric["cumulative loss"] += criterion( + outputs, targets, tensors_lengths, target_lengths + ).item() + + metric["dataset length"] += len(inputs) + metric["iteration"] += 1 + + compute_error_rates(outputs, targets, decoder, language_model, metric) + + metric["average loss"] = metric["cumulative loss"] / metric["iteration"] + metric["validation time"] = time() - start + metric() + + return metric["average loss"] + + +def main(rank, args): + + # Distributed setup + + if args.distributed: + setup_distributed(rank, args.world_size) + + not_main_rank = args.distributed and rank != 0 + + logging.info("Start time: %s", datetime.now()) + + # Explicitly set seed to make sure models created in separate processes + # start from same random weights and biases + torch.manual_seed(args.seed) + + # Empty CUDA cache + torch.cuda.empty_cache() + + # Change backend for flac files + torchaudio.set_audio_backend("soundfile") + + # Transforms + + melkwargs = { + "n_fft": args.win_length, + "n_mels": args.n_bins, + "hop_length": args.hop_length, + } + + sample_rate_original = 
16000 + + if args.type == "mfcc": + transforms = torch.nn.Sequential( + torchaudio.transforms.MFCC( + sample_rate=sample_rate_original, + n_mfcc=args.n_bins, + melkwargs=melkwargs, + ), + ) + num_features = args.n_bins + elif args.type == "waveform": + transforms = torch.nn.Sequential(UnsqueezeFirst()) + num_features = 1 + else: + raise ValueError("Model type not supported") + + if args.normalize: + transforms = torch.nn.Sequential(transforms, Normalize()) + + augmentations = torch.nn.Sequential() + if args.freq_mask: + augmentations = torch.nn.Sequential( + augmentations, + torchaudio.transforms.FrequencyMasking(freq_mask_param=args.freq_mask), + ) + if args.time_mask: + augmentations = torch.nn.Sequential( + augmentations, + torchaudio.transforms.TimeMasking(time_mask_param=args.time_mask), + ) + + # Text preprocessing + + char_blank = "*" + char_space = " " + char_apostrophe = "'" + labels = char_blank + char_space + char_apostrophe + string.ascii_lowercase + language_model = LanguageModel(labels, char_blank, char_space) + + # Dataset + + training, validation = split_process_librispeech( + [args.dataset_train, args.dataset_valid], + [transforms, transforms], + language_model, + root=args.dataset_root, + folder_in_archive=args.dataset_folder_in_archive, + ) + + # Decoder + + if args.decoder == "greedy": + decoder = GreedyDecoder() + else: + raise ValueError("Selected decoder not supported") + + # Model + + model = Wav2Letter( + num_classes=language_model.length, + input_type=args.type, + num_features=num_features, + ) + + if args.jit: + model = torch.jit.script(model) + + if args.distributed: + n = torch.cuda.device_count() // args.world_size + devices = list(range(rank * n, (rank + 1) * n)) + model = model.to(devices[0]) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=devices) + else: + devices = ["cuda" if torch.cuda.is_available() else "cpu"] + model = model.to(devices[0], non_blocking=True) + model = torch.nn.DataParallel(model) + + n = count_parameters(model) + logging.info("Number of parameters: %s", n) + + # Optimizer + + if args.optimizer == "adadelta": + optimizer = Adadelta( + model.parameters(), + lr=args.learning_rate, + weight_decay=args.weight_decay, + eps=args.eps, + rho=args.rho, + ) + elif args.optimizer == "sgd": + optimizer = SGD( + model.parameters(), + lr=args.learning_rate, + momentum=args.momentum, + weight_decay=args.weight_decay, + ) + elif args.optimizer == "adam": + optimizer = Adam( + model.parameters(), + lr=args.learning_rate, + momentum=args.momentum, + weight_decay=args.weight_decay, + ) + elif args.optimizer == "adamw": + optimizer = AdamW( + model.parameters(), + lr=args.learning_rate, + momentum=args.momentum, + weight_decay=args.weight_decay, + ) + else: + raise ValueError("Selected optimizer not supported") + + if args.scheduler == "exponential": + scheduler = ExponentialLR(optimizer, gamma=args.gamma) + elif args.scheduler == "reduceonplateau": + scheduler = ReduceLROnPlateau(optimizer, patience=10, threshold=1e-3) + else: + raise ValueError("Selected scheduler not supported") + + criterion = torch.nn.CTCLoss( + blank=language_model.mapping[char_blank], zero_infinity=False + ) + + # Data Loader + + collate_fn_train = collate_factory(model_length_function, augmentations) + collate_fn_valid = collate_factory(model_length_function) + + loader_training_params = { + "num_workers": args.workers, + "pin_memory": True, + "shuffle": True, + "drop_last": True, + } + loader_validation_params = loader_training_params.copy() + 
loader_validation_params["shuffle"] = False + + loader_training = DataLoader( + training, + batch_size=args.batch_size, + collate_fn=collate_fn_train, + **loader_training_params, + ) + loader_validation = DataLoader( + validation, + batch_size=args.batch_size, + collate_fn=collate_fn_valid, + **loader_validation_params, + ) + + # Setup checkpoint + + best_loss = 1.0 + + load_checkpoint = args.checkpoint and os.path.isfile(args.checkpoint) + + if args.distributed: + torch.distributed.barrier() + + if load_checkpoint: + logging.info("Checkpoint: loading %s", args.checkpoint) + checkpoint = torch.load(args.checkpoint) + + args.start_epoch = checkpoint["epoch"] + best_loss = checkpoint["best_loss"] + + model.load_state_dict(checkpoint["state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + scheduler.load_state_dict(checkpoint["scheduler"]) + + logging.info( + "Checkpoint: loaded '%s' at epoch %s", args.checkpoint, checkpoint["epoch"] + ) + else: + logging.info("Checkpoint: not found") + + save_checkpoint( + { + "epoch": args.start_epoch, + "state_dict": model.state_dict(), + "best_loss": best_loss, + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + }, + False, + args.checkpoint, + not_main_rank, + ) + + if args.distributed: + torch.distributed.barrier() + + torch.autograd.set_detect_anomaly(False) + + for epoch in range(args.start_epoch, args.epochs): + + logging.info("Epoch: %s", epoch) + + train_one_epoch( + model, + criterion, + optimizer, + scheduler, + loader_training, + decoder, + language_model, + devices[0], + epoch, + args.clip_grad, + not_main_rank, + not args.reduce_lr_valid, + ) + + loss = evaluate( + model, + criterion, + loader_validation, + decoder, + language_model, + devices[0], + epoch, + not_main_rank, + ) + + if args.reduce_lr_valid and isinstance(scheduler, ReduceLROnPlateau): + scheduler.step(loss) + + is_best = loss < best_loss + best_loss = min(loss, best_loss) + save_checkpoint( + { + "epoch": epoch + 1, + "state_dict": model.state_dict(), + "best_loss": best_loss, + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + }, + is_best, + args.checkpoint, + not_main_rank, + ) + + logging.info("End time: %s", datetime.now()) + + if args.distributed: + torch.distributed.destroy_process_group() + + +def spawn_main(main, args): + if args.distributed: + torch.multiprocessing.spawn( + main, args=(args,), nprocs=args.world_size, join=True + ) + else: + main(0, args) + + +if __name__ == "__main__": + + logging.basicConfig(level=logging.INFO) + args = parse_args() + spawn_main(main, args) diff --git a/examples/pipeline_wav2letter/transforms.py b/examples/pipeline_wav2letter/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..f1d9115c87315a494a3edad31e00674d0710da37 --- /dev/null +++ b/examples/pipeline_wav2letter/transforms.py @@ -0,0 +1,11 @@ +import torch + + +class Normalize(torch.nn.Module): + def forward(self, tensor): + return (tensor - tensor.mean(-1, keepdim=True)) / tensor.std(-1, keepdim=True) + + +class UnsqueezeFirst(torch.nn.Module): + def forward(self, tensor): + return tensor.unsqueeze(0) diff --git a/examples/pipeline_wav2letter/utils.py b/examples/pipeline_wav2letter/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7cd07a2a80002661ae12f44c68deeec73f05a38e --- /dev/null +++ b/examples/pipeline_wav2letter/utils.py @@ -0,0 +1,55 @@ +import json +import logging +import os +import shutil +from collections import defaultdict + +import torch + + 
+class MetricLogger(defaultdict): + def __init__(self, name, print_freq=1, disable=False): + super().__init__(lambda: 0.0) + self.disable = disable + self.print_freq = print_freq + self._iter = 0 + self["name"] = name + + def __str__(self): + return json.dumps(self) + + def __call__(self): + self._iter = (self._iter + 1) % self.print_freq + if not self.disable and not self._iter: + print(self, flush=True) + + +def save_checkpoint(state, is_best, filename, disable): + """ + Save the model to a temporary file first, + then copy it to filename, in case the signal interrupts + the torch.save() process. + """ + + if disable: + return + + if filename == "": + return + + tempfile = filename + ".temp" + + # Remove tempfile in case interuption during the copying from tempfile to filename + if os.path.isfile(tempfile): + os.remove(tempfile) + + torch.save(state, tempfile) + if os.path.isfile(tempfile): + os.rename(tempfile, filename) + if is_best: + shutil.copyfile(filename, "model_best.pth.tar") + logging.warning("Checkpoint: saved") + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) diff --git a/examples/pipeline_wavernn/README.md b/examples/pipeline_wavernn/README.md new file mode 100644 index 0000000000000000000000000000000000000000..170f4f89e7a291f988a33def9b5f8572c124e28f --- /dev/null +++ b/examples/pipeline_wavernn/README.md @@ -0,0 +1,47 @@ +This is an example vocoder pipeline using the WaveRNN model trained with LJSpeech. WaveRNN model is based on the implementation from [this repository](https://github.com/fatchord/WaveRNN). The original implementation was +introduced in "Efficient Neural Audio Synthesis". WaveRNN and LJSpeech are available in torchaudio. + +### Usage + +An example can be invoked as follows. +``` +python main.py \ + --batch-size 256 \ + --learning-rate 1e-4 \ + --n-freq 80 \ + --loss 'crossentropy' \ + --n-bits 8 \ +``` + +For inference, an example can be invoked as follows. +Please refer to the [documentation](https://pytorch.org/audio/master/models.html#id10) for +available checkpoints. +``` +python inference.py \ + --checkpoint-name wavernn_10k_epochs_8bits_ljspeech \ + --output-wav-path ./output.wav +``` + +This example would generate a file named `output.wav` in the current working directory. + +### Output + +The information reported at each iteration and epoch (e.g. loss) is printed to standard output in the form of one json per line. Here is an example python function to parse the output if redirected to a file. +```python +def read_json(filename): + """ + Convert the standard output saved to filename into a pandas dataframe for analysis. 
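+
+    Assumes one JSON record per line of the file, as produced by the training script.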
+ """ + + import pandas + import json + + with open(filename, "r") as f: + data = f.read() + + # pandas doesn't read single quotes for json + data = data.replace("'", '"') + + data = [json.loads(l) for l in data.splitlines()] + return pandas.DataFrame(data) +``` diff --git a/examples/pipeline_wavernn/datasets.py b/examples/pipeline_wavernn/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..ba836690b484ab848853776df30b99c264626723 --- /dev/null +++ b/examples/pipeline_wavernn/datasets.py @@ -0,0 +1,121 @@ +import random + +import torch +from torch.utils.data.dataset import random_split +from torchaudio.datasets import LJSPEECH, LIBRITTS +from torchaudio.transforms import MuLawEncoding + +from processing import bits_to_normalized_waveform, normalized_waveform_to_bits + + +class MapMemoryCache(torch.utils.data.Dataset): + r"""Wrap a dataset so that, whenever a new item is returned, it is saved to memory. + """ + + def __init__(self, dataset): + self.dataset = dataset + self._cache = [None] * len(dataset) + + def __getitem__(self, n): + if self._cache[n] is not None: + return self._cache[n] + + item = self.dataset[n] + self._cache[n] = item + + return item + + def __len__(self): + return len(self.dataset) + + +class Processed(torch.utils.data.Dataset): + def __init__(self, dataset, transforms): + self.dataset = dataset + self.transforms = transforms + + def __getitem__(self, key): + item = self.dataset[key] + return self.process_datapoint(item) + + def __len__(self): + return len(self.dataset) + + def process_datapoint(self, item): + specgram = self.transforms(item[0]) + return item[0].squeeze(0), specgram + + +def split_process_dataset(args, transforms): + if args.dataset == 'ljspeech': + data = LJSPEECH(root=args.file_path, download=False) + + val_length = int(len(data) * args.val_ratio) + lengths = [len(data) - val_length, val_length] + train_dataset, val_dataset = random_split(data, lengths) + + elif args.dataset == 'libritts': + train_dataset = LIBRITTS(root=args.file_path, url='train-clean-100', download=False) + val_dataset = LIBRITTS(root=args.file_path, url='dev-clean', download=False) + + else: + raise ValueError(f"Expected dataset: `ljspeech` or `libritts`, but found {args.dataset}") + + train_dataset = Processed(train_dataset, transforms) + val_dataset = Processed(val_dataset, transforms) + + train_dataset = MapMemoryCache(train_dataset) + val_dataset = MapMemoryCache(val_dataset) + + return train_dataset, val_dataset + + +def collate_factory(args): + def raw_collate(batch): + + pad = (args.kernel_size - 1) // 2 + + # input waveform length + wave_length = args.hop_length * args.seq_len_factor + # input spectrogram length + spec_length = args.seq_len_factor + pad * 2 + + # max start postion in spectrogram + max_offsets = [x[1].shape[-1] - (spec_length + pad * 2) for x in batch] + + # random start postion in spectrogram + spec_offsets = [random.randint(0, offset) for offset in max_offsets] + # random start postion in waveform + wave_offsets = [(offset + pad) * args.hop_length for offset in spec_offsets] + + waveform_combine = [ + x[0][wave_offsets[i]: wave_offsets[i] + wave_length + 1] + for i, x in enumerate(batch) + ] + specgram = [ + x[1][:, spec_offsets[i]: spec_offsets[i] + spec_length] + for i, x in enumerate(batch) + ] + + specgram = torch.stack(specgram) + waveform_combine = torch.stack(waveform_combine) + + waveform = waveform_combine[:, :wave_length] + target = waveform_combine[:, 1:] + + # waveform: [-1, 1], target: [0, 2**bits-1] if loss = 
'crossentropy' + if args.loss == "crossentropy": + + if args.mulaw: + mulaw_encode = MuLawEncoding(2 ** args.n_bits) + waveform = mulaw_encode(waveform) + target = mulaw_encode(target) + + waveform = bits_to_normalized_waveform(waveform, args.n_bits) + + else: + target = normalized_waveform_to_bits(target, args.n_bits) + + return waveform.unsqueeze(1), specgram.unsqueeze(1), target.unsqueeze(1) + + return raw_collate diff --git a/examples/pipeline_wavernn/inference.py b/examples/pipeline_wavernn/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..08b608a4f90ce7934b8b6d3485ad3188301d915d --- /dev/null +++ b/examples/pipeline_wavernn/inference.py @@ -0,0 +1,88 @@ +import argparse + +import torch +import torchaudio +from torchaudio.transforms import MelSpectrogram +from torchaudio.models import wavernn +from torchaudio.models.wavernn import _MODEL_CONFIG_AND_URLS +from torchaudio.datasets import LJSPEECH + +from wavernn_inference_wrapper import WaveRNNInferenceWrapper +from processing import NormalizeDB + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--output-wav-path", default="./output.wav", type=str, metavar="PATH", + help="The path to output the reconstructed wav file.", + ) + parser.add_argument( + "--jit", default=False, action="store_true", + help="If used, the model and inference function is jitted." + ) + parser.add_argument( + "--no-batch-inference", default=False, action="store_true", + help="Don't use batch inference." + ) + parser.add_argument( + "--no-mulaw", default=False, action="store_true", + help="Don't use mulaw decoder to decoder the signal." + ) + parser.add_argument( + "--checkpoint-name", default="wavernn_10k_epochs_8bits_ljspeech", + choices=list(_MODEL_CONFIG_AND_URLS.keys()), + help="Select the WaveRNN checkpoint." + ) + parser.add_argument( + "--batch-timesteps", default=100, type=int, + help="The time steps for each batch. Only used when batch inference is used", + ) + parser.add_argument( + "--batch-overlap", default=5, type=int, + help="The overlapping time steps between batches. 
Only used when batch inference is used", + ) + args = parser.parse_args() + return args + + +def main(args): + device = "cuda" if torch.cuda.is_available() else "cpu" + waveform, sample_rate, _, _ = LJSPEECH("./", download=True)[0] + + mel_kwargs = { + 'sample_rate': sample_rate, + 'n_fft': 2048, + 'f_min': 40., + 'n_mels': 80, + 'win_length': 1100, + 'hop_length': 275, + 'mel_scale': 'slaney', + 'norm': 'slaney', + 'power': 1, + } + transforms = torch.nn.Sequential( + MelSpectrogram(**mel_kwargs), + NormalizeDB(min_level_db=-100, normalization=True), + ) + mel_specgram = transforms(waveform) + + wavernn_model = wavernn(args.checkpoint_name).eval().to(device) + wavernn_inference_model = WaveRNNInferenceWrapper(wavernn_model) + + if args.jit: + wavernn_inference_model = torch.jit.script(wavernn_inference_model) + + with torch.no_grad(): + output = wavernn_inference_model(mel_specgram.to(device), + mulaw=(not args.no_mulaw), + batched=(not args.no_batch_inference), + timesteps=args.batch_timesteps, + overlap=args.batch_overlap,) + + torchaudio.save(args.output_wav_path, output, sample_rate=sample_rate) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/pipeline_wavernn/losses.py b/examples/pipeline_wavernn/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..a4494b05fbda565af261e0a1877d7705ea134d98 --- /dev/null +++ b/examples/pipeline_wavernn/losses.py @@ -0,0 +1,119 @@ +import math + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +class LongCrossEntropyLoss(nn.Module): + r""" CrossEntropy loss + """ + + def __init__(self): + super(LongCrossEntropyLoss, self).__init__() + + def forward(self, output, target): + output = output.transpose(1, 2) + target = target.long() + + criterion = nn.CrossEntropyLoss() + return criterion(output, target) + + +class MoLLoss(nn.Module): + r""" Discretized mixture of logistic distributions loss + + Adapted from wavenet vocoder + (https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py) + Explanation of loss (https://github.com/Rayhane-mamah/Tacotron-2/issues/155) + + Args: + y_hat (Tensor): Predicted output (n_batch x n_time x n_channel) + y (Tensor): Target (n_batch x n_time x 1) + num_classes (int): Number of classes + log_scale_min (float): Log scale minimum value + reduce (bool): If True, the losses are averaged or summed for each minibatch + + Returns + Tensor: loss + """ + + def __init__(self, num_classes=65536, log_scale_min=None, reduce=True): + super(MoLLoss, self).__init__() + self.num_classes = num_classes + self.log_scale_min = log_scale_min + self.reduce = reduce + + def forward(self, y_hat, y): + y = y.unsqueeze(-1) + + if self.log_scale_min is None: + self.log_scale_min = math.log(1e-14) + + assert y_hat.dim() == 3 + assert y_hat.size(-1) % 3 == 0 + + nr_mix = y_hat.size(-1) // 3 + + # unpack parameters (n_batch, n_time, num_mixtures) x 3 + logit_probs = y_hat[:, :, :nr_mix] + means = y_hat[:, :, nr_mix: 2 * nr_mix] + log_scales = torch.clamp( + y_hat[:, :, 2 * nr_mix: 3 * nr_mix], min=self.log_scale_min + ) + + # (n_batch x n_time x 1) to (n_batch x n_time x num_mixtures) + y = y.expand_as(means) + + centered_y = y - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_y + 1.0 / (self.num_classes - 1)) + cdf_plus = torch.sigmoid(plus_in) + min_in = inv_stdv * (centered_y - 1.0 / (self.num_classes - 1)) + cdf_min = torch.sigmoid(min_in) + + # log probability for edge case of 0 (before scaling) + # 
equivalent: torch.log(F.sigmoid(plus_in)) + log_cdf_plus = plus_in - F.softplus(plus_in) + + # log probability for edge case of 255 (before scaling) + # equivalent: (1 - F.sigmoid(min_in)).log() + log_one_minus_cdf_min = -F.softplus(min_in) + + # probability for all other cases + cdf_delta = cdf_plus - cdf_min + + mid_in = inv_stdv * centered_y + # log probability in the center of the bin, to be used in extreme cases + log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in) + + inner_inner_cond = (cdf_delta > 1e-5).float() + + inner_inner_out = inner_inner_cond * torch.log( + torch.clamp(cdf_delta, min=1e-12) + ) + (1.0 - inner_inner_cond) * ( + log_pdf_mid - math.log((self.num_classes - 1) / 2) + ) + inner_cond = (y > 0.999).float() + inner_out = ( + inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out + ) + cond = (y < -0.999).float() + log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out + + log_probs = log_probs + F.log_softmax(logit_probs, -1) + + if self.reduce: + return -torch.mean(_log_sum_exp(log_probs)) + else: + return -_log_sum_exp(log_probs).unsqueeze(-1) + + +def _log_sum_exp(x): + r""" Numerically stable log_sum_exp implementation that prevents overflow + """ + + axis = len(x.size()) - 1 + m, _ = torch.max(x, dim=axis) + m2, _ = torch.max(x, dim=axis, keepdim=True) + return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py new file mode 100644 index 0000000000000000000000000000000000000000..3d561f091c313149e08d0c890f2c6cf524a4cf8d --- /dev/null +++ b/examples/pipeline_wavernn/main.py @@ -0,0 +1,399 @@ +import argparse +import logging +import os +from collections import defaultdict +from datetime import datetime +from time import time +from typing import List + +import torch +import torchaudio +from torch.optim import Adam +from torch.utils.data import DataLoader +from torchaudio.datasets.utils import bg_iterator +from torchaudio.models.wavernn import WaveRNN + +from datasets import collate_factory, split_process_dataset +from losses import LongCrossEntropyLoss, MoLLoss +from processing import NormalizeDB +from utils import MetricLogger, count_parameters, save_checkpoint + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--workers", + default=4, + type=int, + metavar="N", + help="number of data loading workers", + ) + parser.add_argument( + "--checkpoint", + default="", + type=str, + metavar="PATH", + help="path to latest checkpoint", + ) + parser.add_argument( + "--epochs", + default=8000, + type=int, + metavar="N", + help="number of total epochs to run", + ) + parser.add_argument( + "--start-epoch", default=0, type=int, metavar="N", help="manual epoch number" + ) + parser.add_argument( + "--print-freq", + default=10, + type=int, + metavar="N", + help="print frequency in epochs", + ) + parser.add_argument( + "--dataset", + default="ljspeech", + choices=["ljspeech", "libritts"], + type=str, + help="select dataset to train with", + ) + parser.add_argument( + "--batch-size", default=256, type=int, metavar="N", help="mini-batch size" + ) + parser.add_argument( + "--learning-rate", default=1e-4, type=float, metavar="LR", help="learning rate", + ) + parser.add_argument("--clip-grad", metavar="NORM", type=float, default=4.0) + parser.add_argument( + "--mulaw", + default=True, + action="store_true", + help="if used, waveform is mulaw encoded", + ) + parser.add_argument( + "--jit", default=False, action="store_true", help="if 
used, model is jitted" + ) + parser.add_argument( + "--upsample-scales", + default=[5, 5, 11], + type=List[int], + help="the list of upsample scales", + ) + parser.add_argument( + "--n-bits", default=8, type=int, help="the bits of output waveform", + ) + parser.add_argument( + "--sample-rate", + default=22050, + type=int, + help="the rate of audio dimensions (samples per second)", + ) + parser.add_argument( + "--hop-length", + default=275, + type=int, + help="the number of samples between the starts of consecutive frames", + ) + parser.add_argument( + "--win-length", default=1100, type=int, help="the length of the STFT window", + ) + parser.add_argument( + "--f-min", default=40.0, type=float, help="the minimum frequency", + ) + parser.add_argument( + "--min-level-db", + default=-100, + type=float, + help="the minimum db value for spectrogam normalization", + ) + parser.add_argument( + "--n-res-block", default=10, type=int, help="the number of ResBlock in stack", + ) + parser.add_argument( + "--n-rnn", default=512, type=int, help="the dimension of RNN layer", + ) + parser.add_argument( + "--n-fc", default=512, type=int, help="the dimension of fully connected layer", + ) + parser.add_argument( + "--kernel-size", + default=5, + type=int, + help="the number of kernel size in the first Conv1d layer", + ) + parser.add_argument( + "--n-freq", default=80, type=int, help="the number of spectrogram bins to use", + ) + parser.add_argument( + "--n-hidden-melresnet", + default=128, + type=int, + help="the number of hidden dimensions of resblock in melresnet", + ) + parser.add_argument( + "--n-output-melresnet", default=128, type=int, help="the output dimension of melresnet", + ) + parser.add_argument( + "--n-fft", default=2048, type=int, help="the number of Fourier bins", + ) + parser.add_argument( + "--loss", + default="crossentropy", + choices=["crossentropy", "mol"], + type=str, + help="the type of loss", + ) + parser.add_argument( + "--seq-len-factor", + default=5, + type=int, + help="the length of each waveform to process per batch = hop_length * seq_len_factor", + ) + parser.add_argument( + "--val-ratio", + default=0.1, + type=float, + help="the ratio of waveforms for validation", + ) + parser.add_argument( + "--file-path", default="", type=str, help="the path of audio files", + ) + parser.add_argument( + "--normalization", default=True, action="store_true", help="if True, spectrogram is normalized", + ) + + args = parser.parse_args() + return args + + +def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch): + + model.train() + + sums = defaultdict(lambda: 0.0) + start1 = time() + + metric = MetricLogger("train_iteration") + metric["epoch"] = epoch + + for waveform, specgram, target in bg_iterator(data_loader, maxsize=2): + + start2 = time() + + waveform = waveform.to(device) + specgram = specgram.to(device) + target = target.to(device) + + output = model(waveform, specgram) + output, target = output.squeeze(1), target.squeeze(1) + + loss = criterion(output, target) + loss_item = loss.item() + sums["loss"] += loss_item + metric["loss"] = loss_item + + optimizer.zero_grad() + loss.backward() + + if args.clip_grad > 0: + gradient = torch.nn.utils.clip_grad_norm_( + model.parameters(), args.clip_grad + ) + sums["gradient"] += gradient.item() + metric["gradient"] = gradient.item() + + optimizer.step() + + metric["iteration"] = sums["iteration"] + metric["time"] = time() - start2 + metric() + sums["iteration"] += 1 + + avg_loss = sums["loss"] / len(data_loader) + + metric = 
MetricLogger("train_epoch") + metric["epoch"] = epoch + metric["loss"] = sums["loss"] / len(data_loader) + metric["gradient"] = avg_loss + metric["time"] = time() - start1 + metric() + + +def validate(model, criterion, data_loader, device, epoch): + + with torch.no_grad(): + + model.eval() + sums = defaultdict(lambda: 0.0) + start = time() + + for waveform, specgram, target in bg_iterator(data_loader, maxsize=2): + + waveform = waveform.to(device) + specgram = specgram.to(device) + target = target.to(device) + + output = model(waveform, specgram) + output, target = output.squeeze(1), target.squeeze(1) + + loss = criterion(output, target) + sums["loss"] += loss.item() + + avg_loss = sums["loss"] / len(data_loader) + + metric = MetricLogger("validation") + metric["epoch"] = epoch + metric["loss"] = avg_loss + metric["time"] = time() - start + metric() + + return avg_loss + + +def main(args): + + devices = ["cuda" if torch.cuda.is_available() else "cpu"] + + logging.info("Start time: {}".format(str(datetime.now()))) + + melkwargs = { + "n_fft": args.n_fft, + "power": 1, + "hop_length": args.hop_length, + "win_length": args.win_length, + } + + transforms = torch.nn.Sequential( + torchaudio.transforms.MelSpectrogram( + sample_rate=args.sample_rate, + n_mels=args.n_freq, + f_min=args.f_min, + mel_scale='slaney', + norm='slaney', + **melkwargs, + ), + NormalizeDB(min_level_db=args.min_level_db, normalization=args.normalization), + ) + + train_dataset, val_dataset = split_process_dataset(args, transforms) + + loader_training_params = { + "num_workers": args.workers, + "pin_memory": False, + "shuffle": True, + "drop_last": False, + } + loader_validation_params = loader_training_params.copy() + loader_validation_params["shuffle"] = False + + collate_fn = collate_factory(args) + + train_loader = DataLoader( + train_dataset, + batch_size=args.batch_size, + collate_fn=collate_fn, + **loader_training_params, + ) + val_loader = DataLoader( + val_dataset, + batch_size=args.batch_size, + collate_fn=collate_fn, + **loader_validation_params, + ) + + n_classes = 2 ** args.n_bits if args.loss == "crossentropy" else 30 + + model = WaveRNN( + upsample_scales=args.upsample_scales, + n_classes=n_classes, + hop_length=args.hop_length, + n_res_block=args.n_res_block, + n_rnn=args.n_rnn, + n_fc=args.n_fc, + kernel_size=args.kernel_size, + n_freq=args.n_freq, + n_hidden=args.n_hidden_melresnet, + n_output=args.n_output_melresnet, + ) + + if args.jit: + model = torch.jit.script(model) + + model = torch.nn.DataParallel(model) + model = model.to(devices[0], non_blocking=True) + + n = count_parameters(model) + logging.info(f"Number of parameters: {n}") + + # Optimizer + optimizer_params = { + "lr": args.learning_rate, + } + + optimizer = Adam(model.parameters(), **optimizer_params) + + criterion = LongCrossEntropyLoss() if args.loss == "crossentropy" else MoLLoss() + + best_loss = 10.0 + + if args.checkpoint and os.path.isfile(args.checkpoint): + logging.info(f"Checkpoint: loading '{args.checkpoint}'") + checkpoint = torch.load(args.checkpoint) + + args.start_epoch = checkpoint["epoch"] + best_loss = checkpoint["best_loss"] + + model.load_state_dict(checkpoint["state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + + logging.info( + f"Checkpoint: loaded '{args.checkpoint}' at epoch {checkpoint['epoch']}" + ) + else: + logging.info("Checkpoint: not found") + + save_checkpoint( + { + "epoch": args.start_epoch, + "state_dict": model.state_dict(), + "best_loss": best_loss, + "optimizer": optimizer.state_dict(), 
+ }, + False, + args.checkpoint, + ) + + for epoch in range(args.start_epoch, args.epochs): + + train_one_epoch( + model, criterion, optimizer, train_loader, devices[0], epoch, + ) + + if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1: + + sum_loss = validate(model, criterion, val_loader, devices[0], epoch) + + is_best = sum_loss < best_loss + best_loss = min(sum_loss, best_loss) + save_checkpoint( + { + "epoch": epoch + 1, + "state_dict": model.state_dict(), + "best_loss": best_loss, + "optimizer": optimizer.state_dict(), + }, + is_best, + args.checkpoint, + ) + + logging.info(f"End time: {datetime.now()}") + + +if __name__ == "__main__": + + logging.basicConfig(level=logging.INFO) + args = parse_args() + main(args) diff --git a/examples/pipeline_wavernn/processing.py b/examples/pipeline_wavernn/processing.py new file mode 100644 index 0000000000000000000000000000000000000000..db2b1ee9f7ad0bc1b0c766c14a3acdf891ff9c0d --- /dev/null +++ b/examples/pipeline_wavernn/processing.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn + + +class NormalizeDB(nn.Module): + r"""Normalize the spectrogram with a minimum db value + """ + + def __init__(self, min_level_db, normalization): + super().__init__() + self.min_level_db = min_level_db + self.normalization = normalization + + def forward(self, specgram): + specgram = torch.log10(torch.clamp(specgram.squeeze(0), min=1e-5)) + if self.normalization: + return torch.clamp( + (self.min_level_db - 20 * specgram) / self.min_level_db, min=0, max=1 + ) + return specgram + + +def normalized_waveform_to_bits(waveform: torch.Tensor, bits: int) -> torch.Tensor: + r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1] + """ + + assert abs(waveform).max() <= 1.0 + waveform = (waveform + 1.0) * (2 ** bits - 1) / 2 + return torch.clamp(waveform, 0, 2 ** bits - 1).int() + + +def bits_to_normalized_waveform(label: torch.Tensor, bits: int) -> torch.Tensor: + r"""Transform label [0, 2 ** bits - 1] to waveform [-1, 1] + """ + + return 2 * label / (2 ** bits - 1.0) - 1.0 diff --git a/examples/pipeline_wavernn/utils.py b/examples/pipeline_wavernn/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e924c9f51250509fb1b95fd8182b27ebf34680fb --- /dev/null +++ b/examples/pipeline_wavernn/utils.py @@ -0,0 +1,61 @@ +import logging +import os +import shutil +from collections import defaultdict, deque + +import torch + + +class MetricLogger: + r"""Logger for model metrics + """ + + def __init__(self, group, print_freq=1): + self.print_freq = print_freq + self._iter = 0 + self.data = defaultdict(lambda: deque(maxlen=self.print_freq)) + self.data["group"].append(group) + + def __setitem__(self, key, value): + self.data[key].append(value) + + def _get_last(self): + return {k: v[-1] for k, v in self.data.items()} + + def __str__(self): + return str(self._get_last()) + + def __call__(self): + self._iter = (self._iter + 1) % self.print_freq + if not self._iter: + print(self, flush=True) + + +def save_checkpoint(state, is_best, filename): + r"""Save the model to a temporary file first, + then copy it to filename, in case the signal interrupts + the torch.save() process. 
+ """ + + if filename == "": + return + + tempfile = filename + ".temp" + + # Remove tempfile in case interuption during the copying from tempfile to filename + if os.path.isfile(tempfile): + os.remove(tempfile) + + torch.save(state, tempfile) + if os.path.isfile(tempfile): + os.rename(tempfile, filename) + if is_best: + shutil.copyfile(filename, "model_best.pth.tar") + logging.info("Checkpoint: saved") + + +def count_parameters(model): + r"""Count the total number of parameters in the model + """ + + return sum(p.numel() for p in model.parameters() if p.requires_grad) diff --git a/examples/pipeline_wavernn/wavernn_inference_wrapper.py b/examples/pipeline_wavernn/wavernn_inference_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..5d5c4db71f8fe4488aaf01720566f2182ec0cdf2 --- /dev/null +++ b/examples/pipeline_wavernn/wavernn_inference_wrapper.py @@ -0,0 +1,181 @@ +# ***************************************************************************** +# Copyright (c) 2019 fatchord (https://github.com/fatchord) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ***************************************************************************** + + +from torchaudio.models.wavernn import WaveRNN +import torch +import torchaudio +from torch import Tensor + +from processing import normalized_waveform_to_bits + + +def _fold_with_overlap(x: Tensor, timesteps: int, overlap: int) -> Tensor: + r'''Fold the tensor with overlap for quick batched inference. + Overlap will be used for crossfading in xfade_and_unfold(). + + x = [[h1, h2, ... hn]] + Where each h is a vector of conditioning channels + Eg: timesteps=2, overlap=1 with x.size(1)=10 + folded = [[h1, h2, h3, h4], + [h4, h5, h6, h7], + [h7, h8, h9, h10]] + + Args: + x (tensor): Upsampled conditioning channels of size (1, timesteps, channel). + timesteps (int): Timesteps for each index of batch. + overlap (int): Timesteps for both xfade and rnn warmup. + + Return: + folded (tensor): folded tensor of size (n_folds, timesteps + 2 * overlap, channel). 
+ ''' + + _, channels, total_len = x.size() + + # Calculate variables needed + n_folds = (total_len - overlap) // (timesteps + overlap) + extended_len = n_folds * (overlap + timesteps) + overlap + remaining = total_len - extended_len + + # Pad if some time steps poking out + if remaining != 0: + n_folds += 1 + padding = timesteps + 2 * overlap - remaining + x = torch.nn.functional.pad(x, (0, padding)) + + folded = torch.zeros((n_folds, channels, timesteps + 2 * overlap), device=x.device) + + # Get the values for the folded tensor + for i in range(n_folds): + start = i * (timesteps + overlap) + end = start + timesteps + 2 * overlap + folded[i] = x[0, :, start:end] + + return folded + + +def _xfade_and_unfold(y: Tensor, overlap: int) -> Tensor: + r'''Applies a crossfade and unfolds into a 1d array. + + y = [[seq1], + [seq2], + [seq3]] + Apply a gain envelope at both ends of the sequences + y = [[seq1_in, seq1_timesteps, seq1_out], + [seq2_in, seq2_timesteps, seq2_out], + [seq3_in, seq3_timesteps, seq3_out]] + Stagger and add up the groups of samples: + [seq1_in, seq1_timesteps, (seq1_out + seq2_in), seq2_timesteps, ...] + + Args: + y (Tensor): Batched sequences of audio samples of size + (num_folds, channels, timesteps + 2 * overlap). + overlap (int): Timesteps for both xfade and rnn warmup. + + Returns: + unfolded waveform (Tensor) : waveform in a 1d tensor of size (channels, total_len). + ''' + + num_folds, channels, length = y.shape + timesteps = length - 2 * overlap + total_len = num_folds * (timesteps + overlap) + overlap + + # Need some silence for the rnn warmup + silence_len = overlap // 2 + fade_len = overlap - silence_len + silence = torch.zeros((silence_len), dtype=y.dtype, device=y.device) + linear = torch.ones((silence_len), dtype=y.dtype, device=y.device) + + # Equal power crossfade + t = torch.linspace(-1, 1, fade_len, dtype=y.dtype, device=y.device) + fade_in = torch.sqrt(0.5 * (1 + t)) + fade_out = torch.sqrt(0.5 * (1 - t)) + + # Concat the silence to the fades + fade_in = torch.cat([silence, fade_in]) + fade_out = torch.cat([linear, fade_out]) + + # Apply the gain to the overlap samples + y[:, :, :overlap] *= fade_in + y[:, :, -overlap:] *= fade_out + + unfolded = torch.zeros((channels, total_len), dtype=y.dtype, device=y.device) + + # Loop to add up all the samples + for i in range(num_folds): + start = i * (timesteps + overlap) + end = start + timesteps + 2 * overlap + unfolded[:, start:end] += y[i] + + return unfolded + + +class WaveRNNInferenceWrapper(torch.nn.Module): + + def __init__(self, wavernn: WaveRNN): + super().__init__() + self.wavernn_model = wavernn + + def forward(self, + specgram: Tensor, + mulaw: bool = True, + batched: bool = True, + timesteps: int = 100, + overlap: int = 5) -> Tensor: + r"""Inference function for WaveRNN. + + Based on the implementation from + https://github.com/fatchord/WaveRNN/blob/master/models/fatchord_version.py. + + + Currently only supports multinomial sampling. + + Args: + specgram (Tensor): spectrogram of size (n_mels, n_time) + mulaw (bool, optional): Whether to perform mulaw decoding (Default: ``True``). + batched (bool, optional): Whether to perform batch prediction. Using batch prediction + will significantly increase the inference speed (Default: ``True``). + timesteps (int, optional): The time steps for each batch. Only used when `batched` + is set to True (Default: ``100``). + overlap (int, optional): The overlapping time steps between batches. Only used when + `batched` is set to True (Default: ``5``). 
+
+        Returns:
+            waveform (Tensor): Reconstructed waveform of size (1, n_time).
+                1 represents single channel.
+        """
+        specgram = specgram.unsqueeze(0)
+        if batched:
+            specgram = _fold_with_overlap(specgram, timesteps, overlap)
+
+        output = self.wavernn_model.infer(specgram).cpu()
+
+        if mulaw:
+            output = normalized_waveform_to_bits(output, self.wavernn_model.n_bits)
+            output = torchaudio.functional.mu_law_decoding(output, self.wavernn_model.n_classes)
+
+        if batched:
+            output = _xfade_and_unfold(output, overlap)
+        else:
+            output = output[0]
+
+        return output
diff --git a/examples/source_separation/README.md b/examples/source_separation/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4f2487a5c6cc04502edb358675a41f1cb0e17091
--- /dev/null
+++ b/examples/source_separation/README.md
@@ -0,0 +1,76 @@
+# Source Separation Example
+
+This directory contains reference implementations for source separation. For the details of each model, please check out the following.
+
+- [Conv-TasNet](./conv_tasnet/README.md)
+
+## Usage
+
+### Overview
+
+To train a model, you can use [`lightning_train.py`](./lightning_train.py). This script takes the form of
+`lightning_train.py [parameters]`
+
+```
+python lightning_train.py \
+    [--root-dir ROOT_DIR] \
+    [--num-gpu NUM_GPU] \
+    [--num-workers NUM_WORKERS] \
+    ...
+
+# For the details of the parameter values, use:
+python lightning_train.py --help
+```
+
+This script runs training in the PyTorch-Lightning framework with the Distributed Data Parallel (DDP) backend. After training, the best checkpoint is exported to `<exp-dir>/best_model.pth`; a minimal inference sketch using it follows the SLURM examples below.
+
+### SLURM
+
+Example scripts for running the training on a SLURM cluster:
+
+- **launch_job.sh**
+
+```bash
+#!/bin/bash
+
+#SBATCH --job-name=source_separation
+
+#SBATCH --output=/checkpoint/%u/jobs/%x/%j.out
+
+#SBATCH --error=/checkpoint/%u/jobs/%x/%j.err
+
+#SBATCH --nodes=1
+
+#SBATCH --ntasks-per-node=2
+
+#SBATCH --cpus-per-task=8
+
+#SBATCH --mem-per-cpu=16G
+
+#SBATCH --gpus-per-node=2
+
+#srun env
+srun wrapper.sh $@
+```
+
+- **wrapper.sh**
+
+```bash
+#!/bin/bash
+num_speakers=2
+this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
+exp_dir="/checkpoint/${USER}/exp/"
+dataset_dir="/dataset/Libri${num_speakers}mix/wav8k/min"
+
+
+mkdir -p "${exp_dir}"
+
+python -u \
+    "${this_dir}/lightning_train.py" \
+    --num-speakers "${num_speakers}" \
+    --sample-rate 8000 \
+    --root-dir "${dataset_dir}" \
+    --exp-dir "${exp_dir}" \
+    --batch-size $((16 / SLURM_NTASKS))
+```
+
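+### Inference with a trained model
+
+The following is a minimal sketch (not part of the training scripts) of how the `best_model.pth` exported by `lightning_train.py` could be used for separation. It assumes the default 2-speaker model configuration from `lightning_train.py`/`eval.py`; `exp/best_model.pth` and `mixture.wav` (a single-channel mixture at the training sample rate, 8 kHz by default) are hypothetical paths.
+
+```python
+import torch
+import torchaudio
+
+# Instantiate ConvTasNet with the same defaults as _get_model() in lightning_train.py.
+model = torchaudio.models.ConvTasNet(
+    num_sources=2,
+    enc_kernel_size=16,
+    enc_num_feats=512,
+    msk_kernel_size=3,
+    msk_num_feats=128,
+    msk_num_hidden_feats=512,
+    msk_num_layers=8,
+    msk_num_stacks=3,
+    msk_activate="relu",
+)
+
+# best_model.pth holds the bare ConvTasNet state dict saved at the end of training.
+state_dict = torch.load("exp/best_model.pth", map_location="cpu")
+model.load_state_dict(state_dict)
+model.eval()
+
+# ConvTasNet expects input of shape (batch, 1, time) and returns (batch, num_sources, time).
+waveform, sample_rate = torchaudio.load("mixture.wav")
+with torch.no_grad():
+    sources = model(waveform.unsqueeze(0))
+
+# Save each estimated source as its own file.
+for i, src in enumerate(sources[0]):
+    torchaudio.save(f"source_{i}.wav", src.unsqueeze(0), sample_rate)
+```
+
+For a full evaluation with SDR/SIR/SAR improvements, see [`eval.py`](./eval.py).
+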
diff --git a/examples/source_separation/conv_tasnet/README.md b/examples/source_separation/conv_tasnet/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8c64ac6d98ee68f98741eecca380d8259f0a3f09 --- /dev/null +++ b/examples/source_separation/conv_tasnet/README.md @@ -0,0 +1,44 @@ +# Conv-TasNet + +This is a reference implementation of Conv-TasNet. + +> Luo, Yi, and Nima Mesgarani. "Conv-TasNet: Surpassing Ideal Time-Frequency Magnitude Masking for Speech Separation." IEEE/ACM Transactions on Audio, Speech, and Language Processing 27.8 (2019): 1256-1266. Crossref. Web. + +This implementation is based on [arXiv:1809.07454v3](https://arxiv.org/abs/1809.07454v3) and [the reference implementation](https://github.com/naplab/Conv-TasNet) provided by the authors. + +For the usage, please checkout the [source separation README](../README.md). + +## (Default) Training Configurations + +The default training/model configurations follow the non-causal implementation from [Asteroid](https://github.com/asteroid-team/asteroid/tree/master/egs/librimix/ConvTasNet). (causal configuration is not implemented.) + + - Sample rate: 8000 Hz + - Batch size: total 12 over distributed training workers + - Epochs: 200 + - Initial learning rate: 1e-3 + - Gradient clipping: maximum L2 norm of 5.0 + - Optimizer: Adam + - Learning rate scheduling: Halved after 5 epochs of no improvement in validation accuracy. + - Objective function: SI-SNR + - Reported metrics: SI-SNRi, SDRi + - Sample audio length: 3 seconds (randomized position) + - Encoder/Decoder feature dimension (N): 512 + - Encoder/Decoder convolution kernel size (L): 16 + - TCN bottleneck/output feature dimension (B): 128 + - TCN hidden feature dimension (H): 512 + - TCN skip connection feature dimension (Sc): 128 + - TCN convolution kernel size (P): 3 + - The number of TCN convolution block layers (X): 8 + - The number of TCN convolution blocks (R): 3 + - The mask activation function: ReLU + +## Evaluation + +The following is the evaluation result of training the model on Libri2Mix dataset. + +### LibirMix 2speakers + +| | Si-SNRi (dB) | SDRi (dB) | Epoch | +|:-------------------:|-------------:|----------:|------:| +| Reference (Asteroid)| 14.7 | 15.1 | 200 | +| torchaudio | 15.3 | 15.6 | 200 | diff --git a/examples/source_separation/conv_tasnet/__init__.py b/examples/source_separation/conv_tasnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9c31ed253fa0062f0c3ff77b4a7677f81ae41888 --- /dev/null +++ b/examples/source_separation/conv_tasnet/__init__.py @@ -0,0 +1,6 @@ +from . import ( + train, + trainer +) + +__all__ = ['train', 'trainer'] diff --git a/examples/source_separation/conv_tasnet/train.py b/examples/source_separation/conv_tasnet/train.py new file mode 100644 index 0000000000000000000000000000000000000000..691e7f3415ba64fa6dc45ab56cd46be720d85ea9 --- /dev/null +++ b/examples/source_separation/conv_tasnet/train.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python3 +"""Train Conv-TasNet""" +import time +import pathlib +import argparse + +import torch +import torchaudio +import torchaudio.models + +import conv_tasnet +from utils import dist_utils +from utils.dataset import utils as dataset_utils + +_LG = dist_utils.getLogger(__name__) + + +def _parse_args(args): + parser = argparse.ArgumentParser(description=__doc__,) + parser.add_argument( + "--debug", + action="store_true", + help="Enable debug behavior. 
Each epoch will end with just one batch.") + group = parser.add_argument_group("Model Options") + group.add_argument( + "--num-speakers", required=True, type=int, help="The number of speakers." + ) + group = parser.add_argument_group("Dataset Options") + group.add_argument( + "--sample-rate", + required=True, + type=int, + help="Sample rate of audio files in the given dataset.", + ) + group.add_argument( + "--dataset", + default="wsj0mix", + choices=["wsj0mix"], + help='Dataset type. (default: "wsj0mix")', + ) + group.add_argument( + "--dataset-dir", + required=True, + type=pathlib.Path, + help=( + "Directory where dataset is found. " + 'If the dataset type is "wsj9mix", then this is the directory where ' + '"cv", "tt" and "tr" subdirectories are found.' + ), + ) + group = parser.add_argument_group("Save Options") + group.add_argument( + "--save-dir", + required=True, + type=pathlib.Path, + help=( + "Directory where the checkpoints and logs are saved. " + "Though, only the worker 0 saves checkpoint data, " + "all the worker processes must have access to the directory." + ), + ) + group = parser.add_argument_group("Dataloader Options") + group.add_argument( + "--batch-size", + type=int, + help="Batch size. (default: 16 // world_size)", + ) + group = parser.add_argument_group("Training Options") + group.add_argument( + "--epochs", + metavar="NUM_EPOCHS", + default=100, + type=int, + help="The number of epochs to train. (default: 100)", + ) + group.add_argument( + "--learning-rate", + default=1e-3, + type=float, + help="Initial learning rate. (default: 1e-3)", + ) + group.add_argument( + "--grad-clip", + metavar="CLIP_VALUE", + default=5.0, + type=float, + help="Gradient clip value (l2 norm). (default: 5.0)", + ) + group.add_argument( + "--resume", + metavar="CHECKPOINT_PATH", + help="Previous checkpoint file from which the training is resumed.", + ) + + args = parser.parse_args(args) + + # Delaing the default value initialization until parse_args is done because + # if `--help` is given, distributed training is not enabled. 
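+    # (As a rough illustration: the default process group must already be initialized
+    # when this runs, so with e.g. 2 workers the per-worker default is 16 // 2 == 8.)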
+ if args.batch_size is None: + args.batch_size = 16 // torch.distributed.get_world_size() + + return args + + +def _get_model( + num_sources, + enc_kernel_size=16, + enc_num_feats=512, + msk_kernel_size=3, + msk_num_feats=128, + msk_num_hidden_feats=512, + msk_num_layers=8, + msk_num_stacks=3, +): + model = torchaudio.models.ConvTasNet( + num_sources=num_sources, + enc_kernel_size=enc_kernel_size, + enc_num_feats=enc_num_feats, + msk_kernel_size=msk_kernel_size, + msk_num_feats=msk_num_feats, + msk_num_hidden_feats=msk_num_hidden_feats, + msk_num_layers=msk_num_layers, + msk_num_stacks=msk_num_stacks, + ) + _LG.info_on_master("Model Configuration:") + _LG.info_on_master(" - N: %d", enc_num_feats) + _LG.info_on_master(" - L: %d", enc_kernel_size) + _LG.info_on_master(" - B: %d", msk_num_feats) + _LG.info_on_master(" - H: %d", msk_num_hidden_feats) + _LG.info_on_master(" - Sc: %d", msk_num_feats) + _LG.info_on_master(" - P: %d", msk_kernel_size) + _LG.info_on_master(" - X: %d", msk_num_layers) + _LG.info_on_master(" - R: %d", msk_num_stacks) + _LG.info_on_master( + " - Receptive Field: %s [samples]", model.mask_generator.receptive_field, + ) + return model + + +def _get_dataloader(dataset_type, dataset_dir, num_speakers, sample_rate, batch_size, task=None): + train_dataset, valid_dataset, eval_dataset = dataset_utils.get_dataset( + dataset_type, dataset_dir, num_speakers, sample_rate, task + ) + train_collate_fn = dataset_utils.get_collate_fn( + dataset_type, mode='train', sample_rate=sample_rate, duration=4 + ) + + test_collate_fn = dataset_utils.get_collate_fn(dataset_type, mode='test') + + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=batch_size, + sampler=torch.utils.data.distributed.DistributedSampler(train_dataset), + collate_fn=train_collate_fn, + pin_memory=True, + ) + valid_loader = torch.utils.data.DataLoader( + valid_dataset, + batch_size=batch_size, + sampler=torch.utils.data.distributed.DistributedSampler(valid_dataset), + collate_fn=test_collate_fn, + pin_memory=True, + ) + eval_loader = torch.utils.data.DataLoader( + eval_dataset, + batch_size=batch_size, + sampler=torch.utils.data.distributed.DistributedSampler(eval_dataset), + collate_fn=test_collate_fn, + pin_memory=True, + ) + return train_loader, valid_loader, eval_loader + + +def _write_header(log_path, args): + rows = [ + [f"# torch: {torch.__version__}", ], + [f"# torchaudio: {torchaudio.__version__}", ] + ] + rows.append(["# arguments"]) + for key, item in vars(args).items(): + rows.append([f"# {key}: {item}"]) + + dist_utils.write_csv_on_master(log_path, *rows) + + +def train(args): + args = _parse_args(args) + _LG.info("%s", args) + + args.save_dir.mkdir(parents=True, exist_ok=True) + if "sox_io" in torchaudio.list_audio_backends(): + torchaudio.set_audio_backend("sox_io") + + start_epoch = 1 + if args.resume: + checkpoint = torch.load(args.resume) + if args.sample_rate != checkpoint["sample_rate"]: + raise ValueError( + "The provided sample rate ({args.sample_rate}) does not match " + "the sample rate from the check point ({checkpoint['sample_rate']})." 
+ ) + if args.num_speakers != checkpoint["num_speakers"]: + raise ValueError( + "The provided #of speakers ({args.num_speakers}) does not match " + "the #of speakers from the check point ({checkpoint['num_speakers']}.)" + ) + start_epoch = checkpoint["epoch"] + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + _LG.info("Using: %s", device) + + model = _get_model(num_sources=args.num_speakers) + model.to(device) + + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[device] if torch.cuda.is_available() else None + ) + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + + if args.resume: + _LG.info("Loading parameters from the checkpoint...") + model.module.load_state_dict(checkpoint["model"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + else: + dist_utils.synchronize_params( + str(args.save_dir / "tmp.pt"), device, model, optimizer + ) + + lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="max", factor=0.5, patience=3 + ) + + train_loader, valid_loader, eval_loader = _get_dataloader( + args.dataset, + args.dataset_dir, + args.num_speakers, + args.sample_rate, + args.batch_size, + ) + + num_train_samples = len(train_loader.dataset) + num_valid_samples = len(valid_loader.dataset) + num_eval_samples = len(eval_loader.dataset) + + _LG.info_on_master("Datasets:") + _LG.info_on_master(" - Train: %s", num_train_samples) + _LG.info_on_master(" - Valid: %s", num_valid_samples) + _LG.info_on_master(" - Eval: %s", num_eval_samples) + + trainer = conv_tasnet.trainer.Trainer( + model, + optimizer, + train_loader, + valid_loader, + eval_loader, + args.grad_clip, + device, + debug=args.debug, + ) + + log_path = args.save_dir / "log.csv" + _write_header(log_path, args) + dist_utils.write_csv_on_master( + log_path, + [ + "epoch", + "learning_rate", + "valid_si_snri", + "valid_sdri", + "eval_si_snri", + "eval_sdri", + ], + ) + + _LG.info_on_master("Running %s epochs", args.epochs) + for epoch in range(start_epoch, start_epoch + args.epochs): + _LG.info_on_master("=" * 70) + _LG.info_on_master("Epoch: %s", epoch) + _LG.info_on_master("Learning rate: %s", optimizer.param_groups[0]["lr"]) + _LG.info_on_master("=" * 70) + + t0 = time.monotonic() + trainer.train_one_epoch() + train_sps = num_train_samples / (time.monotonic() - t0) + + _LG.info_on_master("-" * 70) + + t0 = time.monotonic() + valid_metric = trainer.validate() + valid_sps = num_valid_samples / (time.monotonic() - t0) + _LG.info_on_master("Valid: %s", valid_metric) + + _LG.info_on_master("-" * 70) + + t0 = time.monotonic() + eval_metric = trainer.evaluate() + eval_sps = num_eval_samples / (time.monotonic() - t0) + _LG.info_on_master(" Eval: %s", eval_metric) + + _LG.info_on_master("-" * 70) + + _LG.info_on_master("Train: Speed: %6.2f [samples/sec]", train_sps) + _LG.info_on_master("Valid: Speed: %6.2f [samples/sec]", valid_sps) + _LG.info_on_master(" Eval: Speed: %6.2f [samples/sec]", eval_sps) + + _LG.info_on_master("-" * 70) + + dist_utils.write_csv_on_master( + log_path, + [ + epoch, + optimizer.param_groups[0]["lr"], + valid_metric.si_snri, + valid_metric.sdri, + eval_metric.si_snri, + eval_metric.sdri, + ], + ) + + lr_scheduler.step(valid_metric.si_snri) + + save_path = args.save_dir / f"epoch_{epoch}.pt" + dist_utils.save_on_master( + save_path, + { + "model": model.module.state_dict(), + "optimizer": optimizer.state_dict(), + "num_speakers": args.num_speakers, + "sample_rate": args.sample_rate, + "epoch": epoch, + }, + ) diff --git 
a/examples/source_separation/conv_tasnet/trainer.py b/examples/source_separation/conv_tasnet/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..af4e715305b4ac2df238f41c20f8b12e0a2e1928 --- /dev/null +++ b/examples/source_separation/conv_tasnet/trainer.py @@ -0,0 +1,165 @@ +import time +from typing import Tuple +from collections import namedtuple + +import torch +import torch.distributed as dist + +from utils import dist_utils, metrics + +_LG = dist_utils.getLogger(__name__) + +Metric = namedtuple("SNR", ["si_snri", "sdri"]) +Metric.__str__ = ( + lambda self: f"SI-SNRi: {self.si_snri:10.3e}, SDRi: {self.sdri:10.3e}" +) + + +def si_sdr_improvement( + estimate: torch.Tensor, + reference: torch.Tensor, + mix: torch.Tensor, + mask: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute the improvement of scale-invariant SDR. (SI-SNRi) and bare SDR (SDRi). + + Args: + estimate (torch.Tensor): Estimated source signals. + Shape: [batch, speakers, time frame] + reference (torch.Tensor): Reference (original) source signals. + Shape: [batch, speakers, time frame] + mix (torch.Tensor): Mixed souce signals, from which the setimated signals were generated. + Shape: [batch, speakers == 1, time frame] + mask (torch.Tensor): Mask to indicate padded value (0) or valid value (1). + Shape: [batch, 1, time frame] + + + Returns: + torch.Tensor: Improved SI-SDR. Shape: [batch, ] + torch.Tensor: Absolute SI-SDR. Shape: [batch, ] + + References: + - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation + Luo, Yi and Mesgarani, Nima + https://arxiv.org/abs/1809.07454 + """ + with torch.no_grad(): + sdri = metrics.sdri(estimate, reference, mix, mask=mask) + + estimate = estimate - estimate.mean(axis=2, keepdim=True) + reference = reference - reference.mean(axis=2, keepdim=True) + mix = mix - mix.mean(axis=2, keepdim=True) + + si_sdri = metrics.sdri(estimate, reference, mix, mask=mask) + return si_sdri, sdri + + +class OccasionalLogger: + """Simple helper class to log once in a while or when progress is quick enough""" + + def __init__(self, time_interval=180, progress_interval=0.1): + self.time_interval = time_interval + self.progress_interval = progress_interval + + self.last_time = 0.0 + self.last_progress = 0.0 + + def log(self, metric, progress, force=False): + now = time.monotonic() + if ( + force + or now > self.last_time + self.time_interval + or progress > self.last_progress + self.progress_interval + ): + self.last_time = now + self.last_progress = progress + _LG.info_on_master("train: %s [%3d%%]", metric, 100 * progress) + + +class Trainer: + def __init__( + self, + model, + optimizer, + train_loader, + valid_loader, + eval_loader, + grad_clip, + device, + *, + debug, + ): + self.model = model + self.optimizer = optimizer + self.train_loader = train_loader + self.valid_loader = valid_loader + self.eval_loader = eval_loader + self.grad_clip = grad_clip + self.device = device + self.debug = debug + + def train_one_epoch(self): + self.model.train() + logger = OccasionalLogger() + + num_batches = len(self.train_loader) + for i, batch in enumerate(self.train_loader, start=1): + mix = batch.mix.to(self.device) + src = batch.src.to(self.device) + mask = batch.mask.to(self.device) + + estimate = self.model(mix) + + si_snri, sdri = si_sdr_improvement(estimate, src, mix, mask) + si_snri = si_snri.mean() + sdri = sdri.mean() + + loss = -si_snri + self.optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_( + 
self.model.parameters(), self.grad_clip, norm_type=2.0 + ) + self.optimizer.step() + + metric = Metric(si_snri.item(), sdri.item()) + logger.log(metric, progress=i / num_batches, force=i == num_batches) + + if self.debug: + break + + def evaluate(self): + with torch.no_grad(): + return self._test(self.eval_loader) + + def validate(self): + with torch.no_grad(): + return self._test(self.valid_loader) + + def _test(self, loader): + self.model.eval() + + total_si_snri = torch.zeros(1, dtype=torch.float32, device=self.device) + total_sdri = torch.zeros(1, dtype=torch.float32, device=self.device) + + for batch in loader: + mix = batch.mix.to(self.device) + src = batch.src.to(self.device) + mask = batch.mask.to(self.device) + + estimate = self.model(mix) + + si_snri, sdri = si_sdr_improvement(estimate, src, mix, mask) + + total_si_snri += si_snri.sum() + total_sdri += sdri.sum() + + if self.debug: + break + + dist.all_reduce(total_si_snri, dist.ReduceOp.SUM) + dist.all_reduce(total_sdri, dist.ReduceOp.SUM) + + num_samples = len(loader.dataset) + metric = Metric(total_si_snri.item() / num_samples, total_sdri.item() / num_samples) + return metric diff --git a/examples/source_separation/eval.py b/examples/source_separation/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..c8ab87ea118ea5ef9cdeba7b741f44a3c696d30d --- /dev/null +++ b/examples/source_separation/eval.py @@ -0,0 +1,106 @@ +from argparse import ArgumentParser +from pathlib import Path + +from lightning_train import _get_model, _get_dataloader, sisdri_metric +import mir_eval +import torch + + +def _eval(model, data_loader, device): + results = torch.zeros(4) + with torch.no_grad(): + for _, batch in enumerate(data_loader): + mix, src, mask = batch + mix, src, mask = mix.to(device), src.to(device), mask.to(device) + est = model(mix) + sisdri = sisdri_metric(est, src, mix, mask) + src = src.cpu().detach().numpy() + est = est.cpu().detach().numpy() + mix = mix.repeat(1, src.shape[1], 1).cpu().detach().numpy() + sdr, sir, sar, _ = mir_eval.separation.bss_eval_sources(src[0], est[0]) + sdr_mix, sir_mix, sar_mix, _ = mir_eval.separation.bss_eval_sources(src[0], mix[0]) + results += torch.tensor([ + sdr.mean() - sdr_mix.mean(), + sisdri, + sir.mean() - sir_mix.mean(), + sar.mean() - sar_mix.mean() + ]) + results /= len(data_loader) + print("SDR improvement: ", results[0].item()) + print("Si-SDR improvement: ", results[1].item()) + print("SIR improvement: ", results[2].item()) + print("SAR improvement: ", results[3].item()) + + +def cli_main(): + parser = ArgumentParser() + parser.add_argument("--dataset", default="librimix", type=str, choices=["wsj0-mix", "librimix"]) + parser.add_argument( + "--root-dir", + type=Path, + help="The path to the directory where the directory ``Libri2Mix`` or ``Libri3Mix`` is stored.", + ) + parser.add_argument( + "--librimix-tr-split", + default="train-360", + choices=["train-360", "train-100"], + help="The training partition of librimix dataset. (default: ``train-360``)", + ) + parser.add_argument( + "--librimix-task", + default="sep_clean", + type=str, + choices=["sep_clean", "sep_noisy", "enh_single", "enh_both"], + help="The task to perform (separation or enhancement, noisy or clean). (default: ``sep_clean``)", + ) + parser.add_argument( + "--num-speakers", default=2, type=int, help="The number of speakers in the mixture. (default: 2)" + ) + parser.add_argument( + "--sample-rate", + default=8000, + type=int, + help="Sample rate of audio files in the given dataset. 
(default: 8000)", + ) + parser.add_argument( + "--exp-dir", + default=Path("./exp"), + type=Path, + help="The directory to save checkpoints and logs." + ) + parser.add_argument( + "--gpu-device", + default=-1, + type=int, + help="The gpu device for model inference. (default: -1)" + ) + + args = parser.parse_args() + + model = _get_model(num_sources=2) + state_dict = torch.load(args.exp_dir / 'best_model.pth') + model.load_state_dict(state_dict) + + if args.gpu_device != -1: + device = torch.device('cuda:' + str(args.gpu_device)) + else: + device = torch.device('cpu') + + model = model.to(device) + + _, _, eval_loader = _get_dataloader( + args.dataset, + args.data_dir, + args.num_speakers, + args.sample_rate, + 1, # batch size is set to 1 to avoid masking + 0, # set num_workers to 0 + args.librimix_task, + args.librimix_tr_split, + ) + + _eval(model, eval_loader, device) + + +if __name__ == "__main__": + cli_main() diff --git a/examples/source_separation/lightning_train.py b/examples/source_separation/lightning_train.py new file mode 100644 index 0000000000000000000000000000000000000000..a0d61e8a5731d6e66c0eb2947916a09a99da26f3 --- /dev/null +++ b/examples/source_separation/lightning_train.py @@ -0,0 +1,465 @@ +#!/usr/bin/env python3 + +# pyre-strict +from pathlib import Path +from argparse import ArgumentParser +from typing import ( + Any, + Callable, + Dict, + Mapping, + List, + Optional, + Tuple, + TypedDict, + Union, +) + +import torch +import torchaudio +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping +from pytorch_lightning.plugins import DDPPlugin +from torch import nn +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader +from utils import metrics +from utils.dataset import utils as dataset_utils + + +class Batch(TypedDict): + mix: torch.Tensor # (batch, time) + src: torch.Tensor # (batch, source, time) + mask: torch.Tensor # (batch, source, time) + + +def sisdri_metric( + estimate: torch.Tensor, + reference: torch.Tensor, + mix: torch.Tensor, + mask: torch.Tensor +) -> torch.Tensor: + """Compute the improvement of scale-invariant SDR. (SI-SDRi). + + Args: + estimate (torch.Tensor): Estimated source signals. + Tensor of dimension (batch, speakers, time) + reference (torch.Tensor): Reference (original) source signals. + Tensor of dimension (batch, speakers, time) + mix (torch.Tensor): Mixed souce signals, from which the setimated signals were generated. + Tensor of dimension (batch, speakers == 1, time) + mask (torch.Tensor): Mask to indicate padded value (0) or valid value (1). + Tensor of dimension (batch, 1, time) + + Returns: + torch.Tensor: Improved SI-SDR. Tensor of dimension (batch, ) + + References: + - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation + Luo, Yi and Mesgarani, Nima + https://arxiv.org/abs/1809.07454 + """ + with torch.no_grad(): + estimate = estimate - estimate.mean(axis=2, keepdim=True) + reference = reference - reference.mean(axis=2, keepdim=True) + mix = mix - mix.mean(axis=2, keepdim=True) + + si_sdri = metrics.sdri(estimate, reference, mix, mask=mask) + + return si_sdri.mean().item() + + +def sdri_metric( + estimate: torch.Tensor, + reference: torch.Tensor, + mix: torch.Tensor, + mask: torch.Tensor, +) -> torch.Tensor: + """Compute the improvement of SDR. (SDRi). + + Args: + estimate (torch.Tensor): Estimated source signals. 
+ Tensor of dimension (batch, speakers, time) + reference (torch.Tensor): Reference (original) source signals. + Tensor of dimension (batch, speakers, time) + mix (torch.Tensor): Mixed souce signals, from which the setimated signals were generated. + Tensor of dimension (batch, speakers == 1, time) + mask (torch.Tensor): Mask to indicate padded value (0) or valid value (1). + Tensor of dimension (batch, 1, time) + + Returns: + torch.Tensor: Improved SDR. Tensor of dimension (batch, ) + + References: + - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation + Luo, Yi and Mesgarani, Nima + https://arxiv.org/abs/1809.07454 + """ + with torch.no_grad(): + sdri = metrics.sdri(estimate, reference, mix, mask=mask) + return sdri.mean().item() + + +def si_sdr_loss( + estimate: torch.Tensor, + reference: torch.Tensor, + mask: torch.Tensor +) -> torch.Tensor: + """Compute the Si-SDR loss. + + Args: + estimate (torch.Tensor): Estimated source signals. + Tensor of dimension (batch, speakers, time) + reference (torch.Tensor): Reference (original) source signals. + Tensor of dimension (batch, speakers, time) + mask (torch.Tensor): Mask to indicate padded value (0) or valid value (1). + Tensor of dimension (batch, 1, time) + + Returns: + torch.Tensor: Si-SDR loss. Tensor of dimension (batch, ) + """ + estimate = estimate - estimate.mean(axis=2, keepdim=True) + reference = reference - reference.mean(axis=2, keepdim=True) + + si_sdri = metrics.sdr_pit(estimate, reference, mask=mask) + return -si_sdri.mean() + + +class ConvTasNetModule(LightningModule): + """ + The Lightning Module for speech separation. + + Args: + model (Any): The model to use for the classification task. + train_loader (DataLoader): the training dataloader. + val_loader (DataLoader or None): the validation dataloader. + loss (Any): The loss function to use. + optim (Any): The optimizer to use. + metrics (List of methods): The metrics to track, which will be used for both train and validation. + lr_scheduler (Any or None): The LR Scheduler. + """ + + def __init__( + self, + model: Any, + train_loader: DataLoader, + val_loader: Optional[DataLoader], + loss: Any, + optim: Any, + metrics: List[Any], + lr_scheduler: Optional[Any] = None, + ) -> None: + super().__init__() + + self.model: nn.Module = model + self.loss: nn.Module = loss + self.optim: torch.optim.Optimizer = optim + self.lr_scheduler: Optional[_LRScheduler] = None + if lr_scheduler: + self.lr_scheduler = lr_scheduler + + self.metrics: Mapping[str, Callable] = metrics + + self.train_metrics: Dict = {} + self.val_metrics: Dict = {} + self.test_metrics: Dict = {} + + self.save_hyperparameters() + self.train_loader = train_loader + self.val_loader = val_loader + + def setup(self, stage: Optional[str] = None) -> None: + if stage == "fit": + self.train_metrics.update(self.metrics) + self.val_metrics.update(self.metrics) + else: + self.test_metrics.update(self.metrics) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward defines the prediction/inference actions. + """ + return self.model(x) + + def training_step( + self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any + ) -> Dict[str, Any]: + return self._step(batch, batch_idx, "train") + + def validation_step( + self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any + ) -> Dict[str, Any]: + """ + Operates on a single batch of data from the validation set. 
+ """ + return self._step(batch, batch_idx, "val") + + def test_step( + self, batch: Batch, batch_idx: int, *args: Any, **kwargs: Any + ) -> Optional[Dict[str, Any]]: + """ + Operates on a single batch of data from the test set. + """ + return self._step(batch, batch_idx, "test") + + def _step(self, batch: Batch, batch_idx: int, phase_type: str) -> Dict[str, Any]: + """ + Common step for training, validation, and testing. + """ + mix, src, mask = batch + pred = self.model(mix) + loss = self.loss(pred, src, mask) + self.log(f"Losses/{phase_type}_loss", loss.item(), on_step=True, on_epoch=True) + + metrics_result = self._compute_metrics(pred, src, mix, mask, phase_type) + self.log_dict(metrics_result, on_epoch=True) + + return loss + + def configure_optimizers( + self, + ) -> Tuple[Any]: + lr_scheduler = self.lr_scheduler + if not lr_scheduler: + return self.optim + epoch_schedulers = { + 'scheduler': lr_scheduler, + 'monitor': 'Losses/val_loss', + 'interval': 'epoch' + } + return [self.optim], [epoch_schedulers] + + def _compute_metrics( + self, + pred: torch.Tensor, + label: torch.Tensor, + inputs: torch.Tensor, + mask: torch.Tensor, + phase_type: str, + ) -> Dict[str, torch.Tensor]: + metrics_dict = getattr(self, f"{phase_type}_metrics") + metrics_result = {} + for name, metric in metrics_dict.items(): + metrics_result[f"Metrics/{phase_type}/{name}"] = metric(pred, label, inputs, mask) + return metrics_result + + def train_dataloader(self): + """Training dataloader""" + return self.train_loader + + def val_dataloader(self): + """Validation dataloader""" + return self.val_loader + + +def _get_model( + num_sources, + enc_kernel_size=16, + enc_num_feats=512, + msk_kernel_size=3, + msk_num_feats=128, + msk_num_hidden_feats=512, + msk_num_layers=8, + msk_num_stacks=3, + msk_activate="relu", +): + model = torchaudio.models.ConvTasNet( + num_sources=num_sources, + enc_kernel_size=enc_kernel_size, + enc_num_feats=enc_num_feats, + msk_kernel_size=msk_kernel_size, + msk_num_feats=msk_num_feats, + msk_num_hidden_feats=msk_num_hidden_feats, + msk_num_layers=msk_num_layers, + msk_num_stacks=msk_num_stacks, + msk_activate=msk_activate, + ) + return model + + +def _get_dataloader( + dataset_type: str, + root_dir: Union[str, Path], + num_speakers: int = 2, + sample_rate: int = 8000, + batch_size: int = 6, + num_workers: int = 4, + librimix_task: Optional[str] = None, + librimix_tr_split: Optional[str] = None, +) -> Tuple[DataLoader]: + """Get dataloaders for training, validation, and testing. + + Args: + dataset_type (str): the dataset to use. + root_dir (str or Path): the root directory of the dataset. + num_speakers (int, optional): the number of speakers in the mixture. (Default: 2) + sample_rate (int, optional): the sample rate of the audio. (Default: 8000) + batch_size (int, optional): the batch size of the dataset. (Default: 6) + num_workers (int, optional): the number of workers for each dataloader. (Default: 4) + librimix_task (str or None, optional): the task in LibriMix dataset. + librimix_tr_split (str or None, optional): the training split in LibriMix dataset. 
+ + Returns: + tuple: (train_loader, valid_loader, eval_loader) + """ + train_dataset, valid_dataset, eval_dataset = dataset_utils.get_dataset( + dataset_type, root_dir, num_speakers, sample_rate, librimix_task, librimix_tr_split + ) + train_collate_fn = dataset_utils.get_collate_fn( + dataset_type, mode='train', sample_rate=sample_rate, duration=3 + ) + + test_collate_fn = dataset_utils.get_collate_fn(dataset_type, mode='test', sample_rate=sample_rate) + + train_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + collate_fn=train_collate_fn, + num_workers=num_workers, + drop_last=True, + ) + valid_loader = DataLoader( + valid_dataset, + batch_size=batch_size, + collate_fn=test_collate_fn, + num_workers=num_workers, + drop_last=True, + ) + eval_loader = DataLoader( + eval_dataset, + batch_size=batch_size, + collate_fn=test_collate_fn, + num_workers=num_workers, + ) + return train_loader, valid_loader, eval_loader + + +def cli_main(): + parser = ArgumentParser() + parser.add_argument("--batch-size", default=6, type=int) + parser.add_argument("--dataset", default="librimix", type=str, choices=["wsj0-mix", "librimix"]) + parser.add_argument( + "--root-dir", + type=Path, + help="The path to the directory where the directory ``Libri2Mix`` or ``Libri3Mix`` is stored.", + ) + parser.add_argument( + "--librimix-tr-split", + default="train-360", + choices=["train-360", "train-100"], + help="The training partition of librimix dataset. (default: ``train-360``)", + ) + parser.add_argument( + "--librimix-task", + default="sep_clean", + type=str, + choices=["sep_clean", "sep_noisy", "enh_single", "enh_both"], + help="The task to perform (separation or enhancement, noisy or clean). (default: ``sep_clean``)", + ) + parser.add_argument( + "--num-speakers", default=2, type=int, help="The number of speakers in the mixture. (default: 2)" + ) + parser.add_argument( + "--sample-rate", + default=8000, + type=int, + help="Sample rate of audio files in the given dataset. (default: 8000)", + ) + parser.add_argument( + "--exp-dir", + default=Path("./exp"), + type=Path, + help="The directory to save checkpoints and logs." + ) + parser.add_argument( + "--epochs", + metavar="NUM_EPOCHS", + default=200, + type=int, + help="The number of epochs to train. (default: 200)", + ) + parser.add_argument( + "--learning-rate", + default=1e-3, + type=float, + help="Initial learning rate. (default: 1e-3)", + ) + parser.add_argument( + "--num-gpu", + default=1, + type=int, + help="The number of GPUs for training. (default: 1)", + ) + parser.add_argument( + "--num-workers", + default=4, + type=int, + help="The number of workers for dataloader. 
(default: 4)", + ) + + args = parser.parse_args() + + model = _get_model(num_sources=args.num_speakers) + + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=5 + ) + train_loader, valid_loader, eval_loader = _get_dataloader( + args.dataset, + args.root_dir, + args.num_speakers, + args.sample_rate, + args.batch_size, + args.num_workers, + args.librimix_task, + args.librimix_tr_split, + ) + loss = si_sdr_loss + metric_dict = { + "sdri": sdri_metric, + "sisdri": sisdri_metric, + } + model = ConvTasNetModule( + model=model, + train_loader=train_loader, + val_loader=valid_loader, + loss=loss, + optim=optimizer, + metrics=metric_dict, + lr_scheduler=lr_scheduler, + ) + checkpoint_dir = args.exp_dir / "checkpoints" + checkpoint = ModelCheckpoint( + checkpoint_dir, + monitor="Losses/val_loss", + mode="min", + save_top_k=5, + save_weights_only=True, + verbose=True + ) + callbacks = [ + checkpoint, + EarlyStopping(monitor="Losses/val_loss", mode="min", patience=30, verbose=True), + ] + trainer = Trainer( + default_root_dir=args.exp_dir, + max_epochs=args.epochs, + gpus=args.num_gpu, + accelerator="ddp", + plugins=DDPPlugin(find_unused_parameters=False), # make sure there is no unused params + limit_train_batches=1.0, # Useful for fast experiment + gradient_clip_val=5.0, + callbacks=callbacks, + ) + trainer.fit(model) + model.load_from_checkpoint(checkpoint.best_model_path) + state_dict = torch.load(checkpoint.best_model_path, map_location="cpu") + state_dict = {k.replace("model.", ""): v for k, v in state_dict["state_dict"].items()} + torch.save(state_dict, args.exp_dir / "best_model.pth") + trainer.test(model, eval_loader) + + +if __name__ == "__main__": + cli_main() diff --git a/examples/source_separation/train.py b/examples/source_separation/train.py new file mode 100644 index 0000000000000000000000000000000000000000..a969d33fb822975ac8f4ad00b4c82443adf0b2ca --- /dev/null +++ b/examples/source_separation/train.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +"""Launch souce separation training. + +This script runs training in Distributed Data Parallel (DDP) framework and has two major +operation modes. This behavior depends on if `--worker-id` argument is given or not. + +1. (`--worker-id` is not given) Launchs worker subprocesses that performs the actual training. +2. (`--worker-id` is given) Performs the training as a part of distributed training. + +When launching the script without any distributed trainig parameters (operation mode 1), +this script will check the number of GPUs available on the local system and spawns the same +number of training subprocesses (as operaiton mode 2). You can reduce the number of GPUs with +`--num-workers`. If there is no GPU available, only one subprocess is launched. + +When launching the script as a worker process of a distributed training, you need to configure +the coordination of the workers. 
+""" +import sys +import logging +import argparse +import subprocess + +import torch + +from utils import dist_utils + +_LG = dist_utils.getLogger(__name__) + + +def _parse_args(args=None): + max_world_size = torch.cuda.device_count() or 1 + + parser = argparse.ArgumentParser( + description=__doc__, + ) + parser.add_argument("--debug", action="store_true", help="Enable debug log") + group = parser.add_argument_group("Distributed Training") + group.add_argument( + "--worker-id", + type=int, + help=( + "If not provided, the launched process serves as a master process of " + "single-node, multi-worker training and spawns the worker subprocesses. " + "If provided, the launched process serves as a worker process, which " + "performs the actual training. The valid value is [0, --num-workers)." + ), + ) + group.add_argument( + "--device-id", + type=int, + help="The CUDA device ID. Allowed only when --worker-id is provided.", + ) + group.add_argument( + "--num-workers", + type=int, + default=max_world_size, + help=( + "The size of distributed trainig workers. " + "If launching a training as single-node, multi-worker training, " + "(i.e. --worker-id is not provided) then this value should not exceed " + "the number of available GPUs. " + "If launching the training process as a multi-node, multi-gpu training, " + "(i.e. --worker-id is provided) then the value has to match " + f"the number of workers across nodes. (default: {max_world_size})" + ), + ) + group.add_argument( + "--sync-protocol", + type=str, + default="env://", + help=( + "Synchronization protocol for distributed training. " + "This value is passed as `init_method` argument of " + "`torch.distributed.init_process_group` function." + 'If you are using `"env://"`, you can additionally configure ' + 'environment variables "MASTER_ADDR" and "MASTER_PORT". ' + 'If you are using `"file://..."`, then the process has to have ' + "the access to the designated file. " + "See the documentation for `torch.distributed` for the detail. " + 'If you are running the training in a single node, `"env://"` ' + "should do. If you are running the training in multiple nodes, " + "you need to provide the file location where all the nodes have " + 'access, using `"file://..."` protocol. (default: "env://")' + ), + ) + group.add_argument( + "--random-seed", + type=int, + help="Set random seed value. (default: None)", + ) + parser.add_argument( + "rest", nargs=argparse.REMAINDER, help="Model-specific arguments." + ) + namespace = parser.parse_args(args) + if namespace.worker_id is None: + if namespace.device_id is not None: + raise ValueError( + "`--device-id` cannot be provided when runing as master process." + ) + if namespace.num_workers > max_world_size: + raise ValueError( + "--num-workers ({num_workers}) cannot exceed {device_count}." 
+ ) + if namespace.rest[:1] == ["--"]: + namespace.rest = namespace.rest[1:] + return namespace + + +def _main(cli_args): + args = _parse_args(cli_args) + + if any(arg in ["--help", "-h"] for arg in args.rest): + _run_training(args.rest) + + _init_logger(args.worker_id, args.debug) + if args.worker_id is None: + _run_training_subprocesses(args.num_workers, cli_args) + else: + dist_utils.setup_distributed( + world_size=args.num_workers, + rank=args.worker_id, + local_rank=args.device_id, + backend='nccl' if torch.cuda.is_available() else 'gloo', + init_method=args.sync_protocol, + ) + if args.random_seed is not None: + torch.manual_seed(args.random_seed) + if torch.cuda.is_available(): + torch.cuda.set_device(args.device_id) + _LG.info("CUDA device set to %s", args.device_id) + _run_training(args.rest) + + +def _run_training_subprocesses(num_workers, original_args): + workers = [] + _LG.info("Spawning %s workers", num_workers) + for i in range(num_workers): + worker_arg = ["--worker-id", f"{i}", "--num-workers", f"{num_workers}"] + device_arg = ["--device-id", f"{i}"] if torch.cuda.is_available() else [] + command = ( + [sys.executable, "-u", sys.argv[0]] + + worker_arg + + device_arg + + original_args + ) + _LG.info("Launching worker %s: `%s`", i, " ".join(command)) + worker = subprocess.Popen(command) + workers.append(worker) + + num_failed = 0 + for worker in workers: + worker.wait() + if worker.returncode != 0: + num_failed += 1 + sys.exit(num_failed) + + +def _run_training(args): + import conv_tasnet.train + + conv_tasnet.train.train(args) + + +def _init_logger(rank=None, debug=False): + worker_fmt = "[master]" if rank is None else f"[worker {rank:2d}]" + message_fmt = ( + "%(levelname)5s: %(funcName)10s: %(message)s" if debug else "%(message)s" + ) + logging.basicConfig( + level=logging.DEBUG if debug else logging.INFO, + format=f"%(asctime)s: {worker_fmt} {message_fmt}", + ) + + +if __name__ == "__main__": + _main(sys.argv[1:]) diff --git a/examples/source_separation/utils/__init__.py b/examples/source_separation/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d9f0206e3f5eb8a91a7cc8653815282887262c8 --- /dev/null +++ b/examples/source_separation/utils/__init__.py @@ -0,0 +1,7 @@ +from . import ( + dataset, + dist_utils, + metrics, +) + +__all__ = ['dataset', 'dist_utils', 'metrics'] diff --git a/examples/source_separation/utils/dataset/__init__.py b/examples/source_separation/utils/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..45eb6a785ddc603629deca3fd46b57ade8c0dabc --- /dev/null +++ b/examples/source_separation/utils/dataset/__init__.py @@ -0,0 +1,3 @@ +from . import utils, wsj0mix + +__all__ = ['utils', 'wsj0mix'] diff --git a/examples/source_separation/utils/dataset/utils.py b/examples/source_separation/utils/dataset/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a53cc1231acf43537bec245e8cb6dea5856492d2 --- /dev/null +++ b/examples/source_separation/utils/dataset/utils.py @@ -0,0 +1,89 @@ +from typing import List +from functools import partial +from collections import namedtuple + +from torchaudio.datasets import LibriMix +import torch + +from . 
import wsj0mix + +Batch = namedtuple("Batch", ["mix", "src", "mask"]) + + +def get_dataset(dataset_type, root_dir, num_speakers, sample_rate, task=None, librimix_tr_split=None): + if dataset_type == "wsj0mix": + train = wsj0mix.WSJ0Mix(root_dir / "tr", num_speakers, sample_rate) + validation = wsj0mix.WSJ0Mix(root_dir / "cv", num_speakers, sample_rate) + evaluation = wsj0mix.WSJ0Mix(root_dir / "tt", num_speakers, sample_rate) + elif dataset_type == "librimix": + train = LibriMix(root_dir, librimix_tr_split, num_speakers, sample_rate, task) + validation = LibriMix(root_dir, "dev", num_speakers, sample_rate, task) + evaluation = LibriMix(root_dir, "test", num_speakers, sample_rate, task) + else: + raise ValueError(f"Unexpected dataset: {dataset_type}") + return train, validation, evaluation + + +def _fix_num_frames(sample: wsj0mix.SampleType, target_num_frames: int, sample_rate: int, random_start=False): + """Ensure waveform has exact number of frames by slicing or padding""" + mix = sample[1] # [1, time] + src = torch.cat(sample[2], 0) # [num_sources, time] + + num_channels, num_frames = src.shape + num_seconds = torch.div(num_frames, sample_rate, rounding_mode='floor') + target_seconds = torch.div(target_num_frames, sample_rate, rounding_mode='floor') + if num_frames >= target_num_frames: + if random_start and num_frames > target_num_frames: + start_frame = torch.randint(num_seconds - target_seconds + 1, [1]) * sample_rate + mix = mix[:, start_frame:] + src = src[:, start_frame:] + mix = mix[:, :target_num_frames] + src = src[:, :target_num_frames] + mask = torch.ones_like(mix) + else: + num_padding = target_num_frames - num_frames + pad = torch.zeros([1, num_padding], dtype=mix.dtype, device=mix.device) + mix = torch.cat([mix, pad], 1) + src = torch.cat([src, pad.expand(num_channels, -1)], 1) + mask = torch.ones_like(mix) + mask[..., num_frames:] = 0 + return mix, src, mask + + +def collate_fn_wsj0mix_train(samples: List[wsj0mix.SampleType], sample_rate, duration): + target_num_frames = int(duration * sample_rate) + + mixes, srcs, masks = [], [], [] + for sample in samples: + mix, src, mask = _fix_num_frames(sample, target_num_frames, sample_rate, random_start=True) + + mixes.append(mix) + srcs.append(src) + masks.append(mask) + + return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0), torch.stack(masks, 0)) + + +def collate_fn_wsj0mix_test(samples: List[wsj0mix.SampleType], sample_rate): + max_num_frames = max(s[1].shape[-1] for s in samples) + + mixes, srcs, masks = [], [], [] + for sample in samples: + mix, src, mask = _fix_num_frames(sample, max_num_frames, sample_rate, random_start=False) + + mixes.append(mix) + srcs.append(src) + masks.append(mask) + + return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0), torch.stack(masks, 0)) + + +def get_collate_fn(dataset_type, mode, sample_rate=None, duration=4): + assert mode in ["train", "test"] + if dataset_type in ["wsj0mix", "librimix"]: + if mode == 'train': + if sample_rate is None: + raise ValueError("sample_rate is not given.") + return partial(collate_fn_wsj0mix_train, sample_rate=sample_rate, duration=duration) + return partial(collate_fn_wsj0mix_test, sample_rate=sample_rate) + raise ValueError(f"Unexpected dataset: {dataset_type}") diff --git a/examples/source_separation/utils/dataset/wsj0mix.py b/examples/source_separation/utils/dataset/wsj0mix.py new file mode 100644 index 0000000000000000000000000000000000000000..89d59f3ac3d6855d698d32a3de203125f45bba04 --- /dev/null +++ 
b/examples/source_separation/utils/dataset/wsj0mix.py @@ -0,0 +1,70 @@ +from pathlib import Path +from typing import Union, Tuple, List + +import torch +from torch.utils.data import Dataset + +import torchaudio + +SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]] + + +class WSJ0Mix(Dataset): + """Create a Dataset for wsj0-mix. + + Args: + root (str or Path): Path to the directory where the dataset is found. + num_speakers (int): The number of speakers, which determines the directories + to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect + N source audios. + sample_rate (int): Expected sample rate of audio files. If any of the audio has a + different sample rate, raises ``ValueError``. + audio_ext (str, optional): The extension of audio files to find. (default: ".wav") + """ + def __init__( + self, + root: Union[str, Path], + num_speakers: int, + sample_rate: int, + audio_ext: str = ".wav", + ): + self.root = Path(root) + self.sample_rate = sample_rate + self.mix_dir = (self.root / "mix").resolve() + self.src_dirs = [(self.root / f"s{i+1}").resolve() for i in range(num_speakers)] + + self.files = [p.name for p in self.mix_dir.glob(f"*{audio_ext}")] + self.files.sort() + + def _load_audio(self, path) -> torch.Tensor: + waveform, sample_rate = torchaudio.load(path) + if sample_rate != self.sample_rate: + raise ValueError( + f"The dataset contains audio file of sample rate {sample_rate}, " + f"but the requested sample rate is {self.sample_rate}." + ) + return waveform + + def _load_sample(self, filename) -> SampleType: + mixed = self._load_audio(str(self.mix_dir / filename)) + srcs = [] + for i, dir_ in enumerate(self.src_dirs): + src = self._load_audio(str(dir_ / filename)) + if mixed.shape != src.shape: + raise ValueError( + f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}" + ) + srcs.append(src) + return self.sample_rate, mixed, srcs + + def __len__(self) -> int: + return len(self.files) + + def __getitem__(self, key: int) -> SampleType: + """Load the n-th sample from the dataset. 
+ Args: + key (int): The index of the sample to be loaded + Returns: + tuple: ``(sample_rate, mix_waveform, list_of_source_waveforms)`` + """ + return self._load_sample(self.files[key]) diff --git a/examples/source_separation/utils/dist_utils.py b/examples/source_separation/utils/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..380358b87b6bad885697ed0ea4e7e3629b47e9c2 --- /dev/null +++ b/examples/source_separation/utils/dist_utils.py @@ -0,0 +1,86 @@ +import os +import csv +import types +import logging + +import torch +import torch.distributed as dist + + +def _info_on_master(self, *args, **kwargs): + if dist.get_rank() == 0: + self.info(*args, **kwargs) + + +def getLogger(name): + """Get logging.Logger module with additional ``info_on_master`` method.""" + logger = logging.getLogger(name) + logger.info_on_master = types.MethodType(_info_on_master, logger) + return logger + + +_LG = getLogger(__name__) + + +def setup_distributed( + world_size, rank, local_rank, backend="nccl", init_method="env://" +): + """Perform env setup and initialization for distributed training""" + if init_method == "env://": + _set_env_vars(world_size, rank, local_rank) + if world_size > 1 and "OMP_NUM_THREADS" not in os.environ: + _LG.info("Setting OMP_NUM_THREADS == 1") + os.environ["OMP_NUM_THREADS"] = "1" + params = { + "backend": backend, + "init_method": init_method, + "world_size": world_size, + "rank": rank, + } + _LG.info("Initializing distributed process group with %s", params) + dist.init_process_group(**params) + _LG.info("Initialized distributed process group.") + + +def _set_env_vars(world_size, rank, local_rank): + for key, default in [("MASTER_ADDR", "127.0.0.1"), ("MASTER_PORT", "29500")]: + if key not in os.environ: + os.environ[key] = default + + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(local_rank) + + +def save_on_master(path, obj): + if dist.get_rank() == 0: + _LG.info("Saving %s", path) + torch.save(obj, path) + + +def write_csv_on_master(path, *rows): + if dist.get_rank() == 0: + with open(path, "a", newline="") as fileobj: + writer = csv.writer(fileobj) + for row in rows: + writer.writerow(row) + + +def synchronize_params(path, device, *modules): + if dist.get_world_size() < 2: + return + rank = dist.get_rank() + if rank == 0: + _LG.info("[Parameter Sync]: Saving parameters to a temp file...") + torch.save({f"{i}": m.state_dict() for i, m in enumerate(modules)}, path) + dist.barrier() + if rank != 0: + _LG.info("[Parameter Sync]: Loading parameters...") + data = torch.load(path, map_location=device) + for i, m in enumerate(modules): + m.load_state_dict(data[f"{i}"]) + dist.barrier() + if rank == 0: + _LG.info("[Parameter Sync]: Removing the temp file...") + os.remove(path) + _LG.info_on_master("[Parameter Sync]: Complete.") diff --git a/examples/source_separation/utils/metrics.py b/examples/source_separation/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..03859aef2ba13f7092b29cbc854f4c6dce483f9e --- /dev/null +++ b/examples/source_separation/utils/metrics.py @@ -0,0 +1,204 @@ +import math +from typing import Optional +from itertools import permutations + +import torch + + +def sdr( + estimate: torch.Tensor, + reference: torch.Tensor, + mask: Optional[torch.Tensor] = None, + epsilon: float = 1e-8 +) -> torch.Tensor: + """Computes source-to-distortion ratio. + + 1. scale the reference signal with power(s_est * s_ref) / powr(s_ref * s_ref) + 2. 
compute SNR between adjusted estimate and reference.
+
+    Args:
+        estimate (torch.Tensor): Estimated signal.
+            Shape: [batch, speakers (can be 1), time frame]
+        reference (torch.Tensor): Reference signal.
+            Shape: [batch, speakers, time frame]
+        mask (torch.Tensor or None, optional): Binary mask to indicate padded value (0) or valid value (1).
+            Shape: [batch, 1, time frame]
+        epsilon (float, optional): constant value used to stabilize division.
+
+    Returns:
+        torch.Tensor: scale-invariant source-to-distortion ratio.
+        Shape: [batch, speaker]
+
+    References:
+        - Single-channel multi-speaker separation using deep clustering
+          Y. Isik, J. Le Roux, Z. Chen, S. Watanabe, and J. R. Hershey,
+        - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
+          Luo, Yi and Mesgarani, Nima
+          https://arxiv.org/abs/1809.07454
+
+    Notes:
+        This function is tested to produce the exact same result as
+        https://github.com/naplab/Conv-TasNet/blob/e66d82a8f956a69749ec8a4ae382217faa097c5c/utility/sdr.py#L34-L56
+    """
+    reference_pow = reference.pow(2).mean(axis=2, keepdim=True)
+    mix_pow = (estimate * reference).mean(axis=2, keepdim=True)
+    scale = mix_pow / (reference_pow + epsilon)
+
+    reference = scale * reference
+    error = estimate - reference
+
+    reference_pow = reference.pow(2)
+    error_pow = error.pow(2)
+
+    if mask is None:
+        reference_pow = reference_pow.mean(axis=2)
+        error_pow = error_pow.mean(axis=2)
+    else:
+        denom = mask.sum(axis=2)
+        reference_pow = (mask * reference_pow).sum(axis=2) / denom
+        error_pow = (mask * error_pow).sum(axis=2) / denom
+
+    return 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)
+
+
+class PIT(torch.nn.Module):
+    """Applies utterance-level speaker permutation
+
+    Computes the maximum possible value of the given utility function
+    over the permutations of the speakers.
+
+    Args:
+        utility_func (function):
+            Function that computes the utility (opposite of loss) with signature of
+            (estimate: torch.Tensor, reference: torch.Tensor) -> torch.Tensor
+            where the input Tensors are of shape [batch, speakers, frame] and
+            the output Tensor is of shape [batch, speakers].
+
+    References:
+        - Multi-talker Speech Separation with Utterance-level Permutation Invariant Training of
+          Deep Recurrent Neural Networks
+          Morten Kolbæk, Dong Yu, Zheng-Hua Tan and Jesper Jensen
+          https://arxiv.org/abs/1703.06284
+    """
+
+    def __init__(self, utility_func):
+        super().__init__()
+        self.utility_func = utility_func
+
+    def forward(
+            self,
+            estimate: torch.Tensor,
+            reference: torch.Tensor,
+            mask: Optional[torch.Tensor] = None,
+            epsilon: float = 1e-8
+    ) -> torch.Tensor:
+        """Compute utterance-level PIT Loss
+
+        Args:
+            estimate (torch.Tensor): Estimated source signals.
+                Shape: [batch, speakers, time frame]
+            reference (torch.Tensor): Reference (original) source signals.
+                Shape: [batch, speakers, time frame]
+            mask (torch.Tensor or None, optional): Binary mask to indicate padded value (0) or valid value (1).
+                Shape: [batch, 1, time frame]
+            epsilon (float, optional): constant value used to stabilize division.
+
+        Returns:
+            torch.Tensor: Maximum criterion over the speaker permutation.
+                Shape: [batch, ]
+        """
+        assert estimate.shape == reference.shape
+
+        batch_size, num_speakers = reference.shape[:2]
+        num_permute = math.factorial(num_speakers)
+
+        util_mat = torch.zeros(
+            batch_size, num_permute, dtype=estimate.dtype, device=estimate.device
+        )
+        for i, idx in enumerate(permutations(range(num_speakers))):
+            util = self.utility_func(estimate, reference[:, idx, :], mask=mask, epsilon=epsilon)
+            util_mat[:, i] = util.mean(dim=1)  # take the average over speaker dimension
+        return util_mat.max(dim=1).values
+
+
+_sdr_pit = PIT(utility_func=sdr)
+
+
+def sdr_pit(
+        estimate: torch.Tensor,
+        reference: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        epsilon: float = 1e-8):
+    """Computes scale-invariant source-to-distortion ratio.
+
+    1. adjust both estimate and reference to have 0-mean
+    2. scale the reference signal with power(s_est * s_ref) / power(s_ref * s_ref)
+    3. compute SNR between adjusted estimate and reference.
+
+    Args:
+        estimate (torch.Tensor): Estimated signal.
+            Shape: [batch, speakers (can be 1), time frame]
+        reference (torch.Tensor): Reference signal.
+            Shape: [batch, speakers, time frame]
+        mask (torch.Tensor or None, optional): Binary mask to indicate padded value (0) or valid value (1).
+            Shape: [batch, 1, time frame]
+        epsilon (float, optional): constant value used to stabilize division.
+
+    Returns:
+        torch.Tensor: scale-invariant source-to-distortion ratio.
+        Shape: [batch, speaker]
+
+    References:
+        - Single-channel multi-speaker separation using deep clustering
+          Y. Isik, J. Le Roux, Z. Chen, S. Watanabe, and J. R. Hershey,
+        - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
+          Luo, Yi and Mesgarani, Nima
+          https://arxiv.org/abs/1809.07454
+
+    Notes:
+        This function is tested to produce the exact same result as the reference implementation,
+        *when the inputs have 0-mean*
+        https://github.com/naplab/Conv-TasNet/blob/e66d82a8f956a69749ec8a4ae382217faa097c5c/utility/sdr.py#L107-L153
+    """
+    return _sdr_pit(estimate, reference, mask, epsilon)
+
+
+def sdri(
+    estimate: torch.Tensor,
+    reference: torch.Tensor,
+    mix: torch.Tensor,
+    mask: Optional[torch.Tensor] = None,
+    epsilon: float = 1e-8,
+) -> torch.Tensor:
+    """Compute the improvement of SDR (SDRi).
+
+    This function computes how much SDR is improved if the estimation is changed from
+    the original mixture signal to the actual estimated source signals. That is,
+    ``SDR(estimate, reference) - SDR(mix, reference)``.
+
+    For computing ``SDR(estimate, reference)``, PIT (permutation invariant training) is applied,
+    so that the best combination of sources between the reference signals and the estimated signals
+    is picked.
+
+    Args:
+        estimate (torch.Tensor): Estimated source signals.
+            Shape: [batch, speakers, time frame]
+        reference (torch.Tensor): Reference (original) source signals.
+            Shape: [batch, speakers, time frame]
+        mix (torch.Tensor): Mixed source signals, from which the estimated signals were generated.
+            Shape: [batch, speakers == 1, time frame]
+        mask (torch.Tensor or None, optional): Binary mask to indicate padded value (0) or valid value (1).
+            Shape: [batch, 1, time frame]
+        epsilon (float, optional): constant value used to stabilize division.
+
+    Returns:
+        torch.Tensor: Improved SDR.
Shape: [batch, ] + + References: + - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation + Luo, Yi and Mesgarani, Nima + https://arxiv.org/abs/1809.07454 + """ + sdr_ = sdr_pit(estimate, reference, mask=mask, epsilon=epsilon) # [batch, ] + base_sdr = sdr(mix, reference, mask=mask, epsilon=epsilon) # [batch, speaker] + return sdr_ - base_sdr.mean(dim=1) diff --git a/examples/test/__init__.py b/examples/test/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/test/test_interactive_asr.py b/examples/test/test_interactive_asr.py new file mode 100644 index 0000000000000000000000000000000000000000..40867a0f96c429ad7997687805398474a736a4de --- /dev/null +++ b/examples/test/test_interactive_asr.py @@ -0,0 +1,105 @@ +import argparse +import logging +import os +import unittest + +from interactive_asr.utils import setup_asr, transcribe_file + + +class ASRTest(unittest.TestCase): + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + + arguments_dict = { + "path": "/scratch/jamarshon/downloads/model.pt", + "input_file": "/scratch/jamarshon/audio/examples/interactive_asr/data/sample.wav", + "data": "/scratch/jamarshon/downloads", + "user_dir": "/scratch/jamarshon/fairseq-py/examples/speech_recognition", + "no_progress_bar": False, + "log_interval": 1000, + "log_format": None, + "tensorboard_logdir": "", + "tbmf_wrapper": False, + "seed": 1, + "cpu": True, + "fp16": False, + "memory_efficient_fp16": False, + "fp16_init_scale": 128, + "fp16_scale_window": None, + "fp16_scale_tolerance": 0.0, + "min_loss_scale": 0.0001, + "threshold_loss_scale": None, + "criterion": "cross_entropy", + "tokenizer": None, + "bpe": None, + "optimizer": "nag", + "lr_scheduler": "fixed", + "task": "speech_recognition", + "num_workers": 0, + "skip_invalid_size_inputs_valid_test": False, + "max_tokens": 10000000, + "max_sentences": None, + "required_batch_size_multiple": 8, + "dataset_impl": None, + "gen_subset": "test", + "num_shards": 1, + "shard_id": 0, + "remove_bpe": None, + "quiet": False, + "model_overrides": "{}", + "results_path": None, + "beam": 40, + "nbest": 1, + "max_len_a": 0, + "max_len_b": 200, + "min_len": 1, + "match_source_len": False, + "no_early_stop": False, + "unnormalized": False, + "no_beamable_mm": False, + "lenpen": 1, + "unkpen": 0, + "replace_unk": None, + "sacrebleu": False, + "score_reference": False, + "prefix_size": 0, + "no_repeat_ngram_size": 0, + "sampling": False, + "sampling_topk": -1, + "sampling_topp": -1.0, + "temperature": 1.0, + "diverse_beam_groups": -1, + "diverse_beam_strength": 0.5, + "print_alignment": False, + "ctc": False, + "rnnt": False, + "kspmodel": None, + "wfstlm": None, + "rnnt_decoding_type": "greedy", + "lm_weight": 0.2, + "rnnt_len_penalty": -0.5, + "momentum": 0.99, + "weight_decay": 0.0, + "force_anneal": None, + "lr_shrink": 0.1, + "warmup_updates": 0, + } + + arguments_dict["path"] = os.environ.get("ASR_MODEL_PATH", None) + arguments_dict["input_file"] = os.environ.get("ASR_INPUT_FILE", None) + arguments_dict["data"] = os.environ.get("ASR_DATA_PATH", None) + arguments_dict["user_dir"] = os.environ.get("ASR_USER_DIR", None) + args = argparse.Namespace(**arguments_dict) + + def test_transcribe_file(self): + task, generator, models, sp, tgt_dict = setup_asr(self.args, self.logger) + _, transcription = transcribe_file( + self.args, task, generator, models, sp, tgt_dict + ) + + expected_transcription = [["THE QUICK BROWN FOX 
JUMPS OVER THE LAZY DOG"]] + self.assertEqual(transcription, expected_transcription, msg=str(transcription)) + + +if __name__ == "__main__": + unittest.main() diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000000000000000000000000000000000000..068f0918a38581d97f9c4f0230e1838d107ead4a --- /dev/null +++ b/mypy.ini @@ -0,0 +1,3 @@ +[mypy] +allow_redefinition = True +ignore_missing_imports = True diff --git a/packaging/README.md b/packaging/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f3ee62e62c35259f9c407bfa7a7090ffb743af74 --- /dev/null +++ b/packaging/README.md @@ -0,0 +1,95 @@ +# Building torchaudio packages for release + + ## Anaconda packages + + ### Linux + + ```bash +docker run -it --ipc=host --rm -v $(pwd):/remote soumith/conda-cuda bash +cd remote +PYTHON_VERSION=3.7 packaging/build_conda.sh +``` + +To install bz2, +```bash +cd /opt/conda/conda-bld/linux-64/ +# install dependencies +conda install pytorch-cpu=1.1.0 +conda install sox +# install torchaudio +conda install /opt/conda/conda-bld/linux-64/torchaudio-cpu-0.2.0-py27_1.tar.bz2 +``` + +To upload bz2, +```bash +anaconda upload -u pytorch /opt/conda/conda-bld/linux-64/torchaudio*.bz2 +``` + + ### OSX + + ```bash +# create a fresh anaconda environment / install and activate it +PYTHON_VERSION=3.7 packaging/build_conda.sh +``` + +To install bz2, +```bash +cd /Users/jamarshon/anaconda3/conda-bld/osx-64/ +# activate conda env (e.g +conda info --envs +conda activate /Users/jamarshon/minconda_wheel_env_tmp/envs/env2.7 +# install dependencies +conda install pytorch-cpu=1.1.0 +conda install sox +# install torchaudio +# and then try installing (e.g +conda install /Users/jamarshon/anaconda3/conda-bld/osx-64/torchaudio-0.2.0-py27_1.tar.bz2 +``` + +To upload bz2, +```bash +anaconda upload -u pytorch /Users/jamarshon/anaconda3/conda-bld/osx-64/torchaudio*.bz2 +``` + + ## Wheels + + ### Linux + + ```bash +nvidia-docker run -it --ipc=host --rm -v $(pwd):/remote soumith/manylinux-cuda90:latest bash +cd remote +PYTHON_VERSION=3.7 packaging/build_wheel.sh +``` + +To install wheels, +```bash +cd ../cpu +/opt/python/cp35-cp35m/bin/pip install torchaudio-0.2-cp35-cp35m-linux_x86_64.whl +``` + +To upload wheels, +```bash +cd ../cpu +/opt/python/cp35-cp35m/bin/pip install twine +/opt/python/cp35-cp35m/bin/twine upload *.whl +``` + + ### OSX + + ```bash +PYTHON_VERSION=3.7 packaging/build_wheel.sh +``` + +To install wheels, +```bash +cd ~/torchaudio_wheels +conda activate /Users/jamarshon/minconda_wheel_env_tmp/envs/env2.7 +pip install torchaudio-0.2-cp27-cp27m-macosx_10_6_x86_64.whl +``` + +To upload wheels, +```bash +pip install twine +cd ~/torchaudio_wheels +twine upload *.whl +``` diff --git a/packaging/build_conda.sh b/packaging/build_conda.sh new file mode 100644 index 0000000000000000000000000000000000000000..0e6aae53b88be674188997b794f80df734742557 --- /dev/null +++ b/packaging/build_conda.sh @@ -0,0 +1,14 @@ +#!/bin/bash +set -ex + +script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +. 
"$script_dir/pkg_helpers.bash" + +export BUILD_TYPE="conda" +setup_env 0.10.0 +export SOURCE_ROOT_DIR="$PWD" +setup_conda_pytorch_constraint +setup_conda_cudatoolkit_constraint +setup_visual_studio_constraint +# nvidia channel included for cudatoolkit >= 11 +conda build -c defaults -c nvidia $CONDA_CHANNEL_FLAGS --no-anaconda-upload --python "$PYTHON_VERSION" packaging/torchaudio diff --git a/packaging/build_wheel.sh b/packaging/build_wheel.sh new file mode 100644 index 0000000000000000000000000000000000000000..0ca814caf17d0b065f5d37fe5a8c617722f98511 --- /dev/null +++ b/packaging/build_wheel.sh @@ -0,0 +1,18 @@ +#!/bin/bash +set -ex + +script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +. "$script_dir/pkg_helpers.bash" + +export BUILD_TYPE="wheel" +setup_env 0.10.0 +setup_wheel_python +pip_install numpy future cmake ninja +setup_pip_pytorch_version +python setup.py clean +if [[ "$OSTYPE" == "msys" ]]; then + python_tag="$(echo "cp$PYTHON_VERSION" | tr -d '.')" + "$script_dir/vc_env_helper.bat" python setup.py bdist_wheel --plat-name win_amd64 --python-tag $python_tag +else + python setup.py bdist_wheel +fi diff --git a/packaging/pkg_helpers.bash b/packaging/pkg_helpers.bash new file mode 100644 index 0000000000000000000000000000000000000000..4babde63d02d75eb542fb1718bc6b8084245719b --- /dev/null +++ b/packaging/pkg_helpers.bash @@ -0,0 +1,302 @@ +# A set of useful bash functions for common functionality we need to do in +# many build scripts + + +# Setup CUDA environment variables, based on CU_VERSION +# +# Inputs: +# CU_VERSION (cpu, cu92, cu100) +# NO_CUDA_PACKAGE (bool) +# BUILD_TYPE (conda, wheel) +# +# Outputs: +# VERSION_SUFFIX (e.g., "") +# PYTORCH_VERSION_SUFFIX (e.g., +cpu) +# WHEEL_DIR (e.g., cu100/) +# CUDA_HOME (e.g., /usr/local/cuda-9.2, respected by torch.utils.cpp_extension) +# USE_CUDA (respected by torchaudio setup.py) +# NVCC_FLAGS (respected by torchaudio setup.py) +# +# Precondition: CUDA versions are installed in their conventional locations in +# /usr/local/cuda-* +# +# NOTE: Why VERSION_SUFFIX versus PYTORCH_VERSION_SUFFIX? If you're building +# a package with CUDA on a platform we support CUDA on, VERSION_SUFFIX == +# PYTORCH_VERSION_SUFFIX and everyone is happy. However, if you are building a +# package with only CPU bits (e.g., torchaudio), then VERSION_SUFFIX is always +# empty, but PYTORCH_VERSION_SUFFIX is +cpu (because that's how you get a CPU +# version of a Python package. But that doesn't apply if you're on OS X, +# since the default CU_VERSION on OS X is cpu. +setup_cuda() { + + # First, compute version suffixes. 
By default, assume no version suffixes + export VERSION_SUFFIX="" + export PYTORCH_VERSION_SUFFIX="" + export WHEEL_DIR="cpu/" + # Wheel builds need suffixes (but not if they're on OS X, which never has suffix) + if [[ "$BUILD_TYPE" == "wheel" ]] && [[ "$(uname)" != Darwin ]]; then + export PYTORCH_VERSION_SUFFIX="+$CU_VERSION" + # Match the suffix scheme of pytorch, unless this package does not have + # CUDA builds (in which case, use default) + if [[ -z "$NO_CUDA_PACKAGE" ]]; then + export VERSION_SUFFIX="$PYTORCH_VERSION_SUFFIX" + export WHEEL_DIR="$CU_VERSION/" + fi + fi + + # Now work out the CUDA settings + case "$CU_VERSION" in + cu113) + if [[ "$OSTYPE" == "msys" ]]; then + export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.3" + else + export CUDA_HOME=/usr/local/cuda-11.3/ + fi + export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" + ;; + cu112) + if [[ "$OSTYPE" == "msys" ]]; then + export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.2" + else + export CUDA_HOME=/usr/local/cuda-11.2/ + fi + export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" + ;; + cu111) + if [[ "$OSTYPE" == "msys" ]]; then + export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.1" + else + export CUDA_HOME=/usr/local/cuda-11.1/ + fi + export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" + ;; + cu110) + if [[ "$OSTYPE" == "msys" ]]; then + export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.0" + else + export CUDA_HOME=/usr/local/cuda-11.0/ + fi + export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0" + ;; + cu102) + if [[ "$OSTYPE" == "msys" ]]; then + export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.2" + else + export CUDA_HOME=/usr/local/cuda-10.2/ + fi + export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" + ;; + cu101) + if [[ "$OSTYPE" == "msys" ]]; then + export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.1" + else + export CUDA_HOME=/usr/local/cuda-10.1/ + fi + export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" + ;; + cu100) + export CUDA_HOME=/usr/local/cuda-10.0/ + export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" + ;; + cu92) + export CUDA_HOME=/usr/local/cuda-9.2/ + export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0" + ;; + rocm*) + export USE_ROCM=1 + ;; + cpu) + ;; + *) + echo "Unrecognized CU_VERSION=$CU_VERSION" + exit 1 + ;; + esac + if [[ -n "$CUDA_HOME" ]]; then + # Adds nvcc binary to the search path so that CMake's `find_package(CUDA)` will pick the right one + export PATH="$CUDA_HOME/bin:$PATH" + # TODO: Fix Windows CUDA builds + if [[ "$OSTYPE" != "msys" ]]; then + # Force GPU builds on CPU runner, when `torch.cuda.is_available()` returns false + export USE_CUDA=1 + fi + fi +} + +# Populate build version if necessary, and add version suffix +# +# Inputs: +# BUILD_VERSION (e.g., 0.2.0 or empty) +# VERSION_SUFFIX (e.g., +cpu) +# +# Outputs: +# BUILD_VERSION (e.g., 0.2.0.dev20190807+cpu) +# +# Fill BUILD_VERSION if it doesn't exist already with a nightly string +# Usage: setup_build_version 0.2.0 +setup_build_version() { + if [[ -z "$BUILD_VERSION" ]]; then + export BUILD_VERSION="$1.dev$(date "+%Y%m%d")$VERSION_SUFFIX" + else + export BUILD_VERSION="$BUILD_VERSION$VERSION_SUFFIX" + fi +} + +# Set some useful variables for OS X, if applicable +setup_macos() { + if [[ "$(uname)" == Darwin ]]; then + export MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ + fi +} + +# Top-level entry point for things every 
package will need to do +# +# Usage: setup_env 0.2.0 +setup_env() { + git submodule update --init --recursive + setup_cuda + setup_build_version "$1" + setup_macos +} + +# Function to retry functions that sometimes timeout or have flaky failures +retry () { + $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) +} + +# Inputs: +# PYTHON_VERSION (2.7, 3.5, 3.6, 3.7) +# UNICODE_ABI (bool) +# +# Outputs: +# PATH modified to put correct Python version in PATH +# +# Precondition: If Linux, you are in a soumith/manylinux-cuda* Docker image +setup_wheel_python() { + if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" ]]; then + eval "$(conda shell.bash hook)" + conda env remove -n "env$PYTHON_VERSION" || true + conda create -yn "env$PYTHON_VERSION" python="$PYTHON_VERSION" + conda activate "env$PYTHON_VERSION" + else + case "$PYTHON_VERSION" in + 2.7) + if [[ -n "$UNICODE_ABI" ]]; then + python_abi=cp27-cp27mu + else + python_abi=cp27-cp27m + fi + ;; + 3.5) python_abi=cp35-cp35m ;; + 3.6) python_abi=cp36-cp36m ;; + 3.7) python_abi=cp37-cp37m ;; + 3.8) python_abi=cp38-cp38 ;; + 3.9) python_abi=cp39-cp39 ;; + *) + echo "Unrecognized PYTHON_VERSION=$PYTHON_VERSION" + exit 1 + ;; + esac + export PATH="/opt/python/$python_abi/bin:$PATH" + fi +} + +# Install with pip a bit more robustly than the default +pip_install() { + retry pip install --progress-bar off "$@" +} + +# Install torch with pip, respecting PYTORCH_VERSION, and record the installed +# version into PYTORCH_VERSION, if applicable +setup_pip_pytorch_version() { + if [[ -z "$PYTORCH_VERSION" ]]; then + # Install latest prerelease version of torch, per our nightlies, consistent + # with the requested cuda version + pip_install --pre torch -f "https://download.pytorch.org/whl/nightly/${WHEEL_DIR}torch_nightly.html" + # CUDA and CPU are ABI compatible on the CPU-only parts, so strip in this case + export PYTORCH_VERSION="$(pip show torch | grep ^Version: | sed 's/Version: *//' | sed 's/+.\+//')" + else + pip_install "torch==$PYTORCH_VERSION$PYTORCH_VERSION_SUFFIX" \ + -f https://download.pytorch.org/whl/torch_stable.html \ + -f "https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/torch_${UPLOAD_CHANNEL}.html" + fi +} + +# Fill PYTORCH_VERSION with the latest conda nightly version, and +# CONDA_CHANNEL_FLAGS with appropriate flags to retrieve these versions +# +# You MUST have populated PYTORCH_VERSION_SUFFIX before hand. 
+setup_conda_pytorch_constraint() { + CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS}" + if [[ -z "$PYTORCH_VERSION" ]]; then + export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c pytorch-nightly" + export PYTORCH_VERSION="$(conda search --json 'pytorch[channel=pytorch-nightly]' | python -c "import sys, json, re; print(re.sub(r'\\+.*$', '', json.load(sys.stdin)['pytorch'][-1]['version']))")" + else + export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c pytorch -c pytorch-test -c pytorch-nightly" + fi + if [[ "$CU_VERSION" == cpu ]]; then + export CONDA_PYTORCH_BUILD_CONSTRAINT="- pytorch==$PYTORCH_VERSION${PYTORCH_VERSION_SUFFIX}" + export CONDA_PYTORCH_CONSTRAINT="- pytorch==$PYTORCH_VERSION" + else + export CONDA_PYTORCH_BUILD_CONSTRAINT="- pytorch==${PYTORCH_VERSION}${PYTORCH_VERSION_SUFFIX}" + export CONDA_PYTORCH_CONSTRAINT="- pytorch==${PYTORCH_VERSION}${PYTORCH_VERSION_SUFFIX}" + fi + # TODO: Remove me later, see https://github.com/pytorch/pytorch/issues/62424 for more details + if [[ "$(uname)" == Darwin ]]; then + # Use less than equal to avoid version conflict in python=3.6 environment + export CONDA_EXTRA_BUILD_CONSTRAINT="- mkl<=2021.2.0" + fi +} + +# Translate CUDA_VERSION into CUDA_CUDATOOLKIT_CONSTRAINT +setup_conda_cudatoolkit_constraint() { + export CONDA_CPUONLY_FEATURE="" + if [[ "$(uname)" == Darwin ]]; then + export CONDA_CUDATOOLKIT_CONSTRAINT="" + else + case "$CU_VERSION" in + cu113) + export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.3,<11.4 # [not osx]" + ;; + cu112) + export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.2,<11.3 # [not osx]" + ;; + cu111) + export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.1,<11.2 # [not osx]" + ;; + cu110) + export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.0,<11.1 # [not osx]" + ;; + cu102) + export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=10.2,<10.3 # [not osx]" + ;; + cu101) + export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=10.1,<10.2 # [not osx]" + ;; + cu100) + export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=10.0,<10.1 # [not osx]" + ;; + cu92) + export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=9.2,<9.3 # [not osx]" + ;; + cpu) + export CONDA_CUDATOOLKIT_CONSTRAINT="" + export CONDA_CPUONLY_FEATURE="- cpuonly" + ;; + *) + echo "Unrecognized CU_VERSION=$CU_VERSION" + exit 1 + ;; + esac + fi +} + +# Build the proper compiler package before building the final package +setup_visual_studio_constraint() { + if [[ "$OSTYPE" == "msys" ]]; then + export VSTOOLCHAIN_PACKAGE=vs2019 + export VSDEVCMD_ARGS='' + conda build $CONDA_CHANNEL_FLAGS --no-anaconda-upload packaging/$VSTOOLCHAIN_PACKAGE + cp packaging/$VSTOOLCHAIN_PACKAGE/conda_build_config.yaml packaging/torchaudio/conda_build_config.yaml + fi +} diff --git a/packaging/torchaudio/bld.bat b/packaging/torchaudio/bld.bat new file mode 100644 index 0000000000000000000000000000000000000000..6b31d4319c2b1d71bd8d3217f20074180876f96d --- /dev/null +++ b/packaging/torchaudio/bld.bat @@ -0,0 +1,5 @@ +@echo off + +set IS_CONDA=1 + +python setup.py install --single-version-externally-managed --record=record.txt diff --git a/packaging/torchaudio/build.sh b/packaging/torchaudio/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..e2b53ddcf495492eba2c5f030efa57a416338930 --- /dev/null +++ b/packaging/torchaudio/build.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -ex + +python setup.py install --single-version-externally-managed --record=record.txt diff --git a/packaging/torchaudio/meta.yaml b/packaging/torchaudio/meta.yaml new file 
mode 100644 index 0000000000000000000000000000000000000000..c485d3680ca421e24caeeecdd891ba2fdca6071f --- /dev/null +++ b/packaging/torchaudio/meta.yaml @@ -0,0 +1,61 @@ +package: + name: torchaudio + version: "{{ environ.get('BUILD_VERSION', '0.0.0') }}" + +source: + path: "{{ environ.get('SOURCE_ROOT_DIR', '../..') }}" + +requirements: + build: + - {{ compiler('c') }} # [win] + - {{ compiler('cxx') }} # [win] + + host: + - python + - setuptools + - cmake + - ninja + - defaults::numpy >=1.11 + {{ environ.get('CONDA_PYTORCH_BUILD_CONSTRAINT', 'pytorch') }} + {{ environ.get('CONDA_EXTRA_BUILD_CONSTRAINT', '') }} + {{ environ.get('CONDA_CPUONLY_FEATURE', '') }} + {{ environ.get('CONDA_CUDATOOLKIT_CONSTRAINT', '') }} + + run: + - python + - defaults::numpy >=1.11 + {{ environ.get('CONDA_PYTORCH_CONSTRAINT', 'pytorch') }} + {{ environ.get('CONDA_CUDATOOLKIT_CONSTRAINT', '') }} + +build: + string: py{{py}}_{{ environ.get('CU_VERSION', 'cpu') }} + script_env: + - BUILD_VERSION + - USE_CUDA # [not win] + - TORCH_CUDA_ARCH_LIST # [not win] + features: + {{ environ.get('CONDA_CPUONLY_FEATURE', '') }} + +test: + imports: + - torchaudio + - torchaudio.datasets + - torchaudio.kaldi_io + - torchaudio.sox_effects + - torchaudio.transforms + + source_files: + - test + + requires: + - pytest + # Ideally we would test this, but conda doesn't provide librosa + # - librosa >=0.4.3 + - scipy + {{ environ.get('CONDA_CPUONLY_FEATURE', '') }} + +about: + home: https://github.com/pytorch/audio + license: BSD + license_file: LICENSE + summary: 'simple audio I/O for pytorch' diff --git a/packaging/vc_env_helper.bat b/packaging/vc_env_helper.bat new file mode 100644 index 0000000000000000000000000000000000000000..9410135677a4fdc1113d96c5a422583992c688c3 --- /dev/null +++ b/packaging/vc_env_helper.bat @@ -0,0 +1,39 @@ +@echo on + +set VC_VERSION_LOWER=16 +set VC_VERSION_UPPER=17 + +for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -legacy -products * -version [%VC_VERSION_LOWER%^,%VC_VERSION_UPPER%^) -property installationPath`) do ( + if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" ( + set "VS15INSTALLDIR=%%i" + set "VS15VCVARSALL=%%i\VC\Auxiliary\Build\vcvarsall.bat" + goto vswhere + ) +) + +:vswhere +if "%VSDEVCMD_ARGS%" == "" ( + call "%VS15VCVARSALL%" x64 || exit /b 1 +) else ( + call "%VS15VCVARSALL%" x64 %VSDEVCMD_ARGS% || exit /b 1 +) + +@echo on + +set DISTUTILS_USE_SDK=1 + +set args=%1 +shift +:start +if [%1] == [] goto done +set args=%args% %1 +shift +goto start + +:done +if "%args%" == "" ( + echo Usage: vc_env_helper.bat [command] [args] + echo e.g. vc_env_helper.bat cl /c test.cpp +) + +%args% || exit /b 1 diff --git a/packaging/vs2019/activate.bat b/packaging/vs2019/activate.bat new file mode 100644 index 0000000000000000000000000000000000000000..6f607ba7518e2346e16489195fcdbd111320996c --- /dev/null +++ b/packaging/vs2019/activate.bat @@ -0,0 +1,44 @@ +:: Set env vars that tell distutils to use the compiler that we put on path +SET DISTUTILS_USE_SDK=1 +SET MSSdk=1 + +SET "VS_VERSION=16.0" +SET "VS_MAJOR=16" +SET "VS_YEAR=2019" + +set "MSYS2_ARG_CONV_EXCL=/AI;/AL;/OUT;/out" +set "MSYS2_ENV_CONV_EXCL=CL" + +:: For Python 3.5+, ensure that we link with the dynamic runtime. 
See +:: http://stevedower.id.au/blog/building-for-python-3-5-part-two/ for more info +set "PY_VCRUNTIME_REDIST=%PREFIX%\\bin\\vcruntime140.dll" + +for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -legacy -products * -version [16^,17^) -property installationPath`) do ( + if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" ( + set "VSINSTALLDIR=%%i\" + goto :vswhere + ) +) + +:vswhere + +:: Shorten PATH to avoid the `input line too long` error. +SET MyPath=%PATH% + +setlocal EnableDelayedExpansion + +SET TempPath="%MyPath:;=";"%" +SET var= +FOR %%a IN (%TempPath%) DO ( + IF EXIST %%~sa ( + SET "var=!var!;%%~sa" + ) +) + +set "TempPath=!var:~1!" +endlocal & set "PATH=%TempPath%" + +:: Shorten current directory too +FOR %%A IN (.) DO CD "%%~sA" + +:: other things added by install_activate.bat at package build time diff --git a/packaging/vs2019/conda_build_config.yaml b/packaging/vs2019/conda_build_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..358052ec012940bb56778d167bcd69302d255846 --- /dev/null +++ b/packaging/vs2019/conda_build_config.yaml @@ -0,0 +1,24 @@ +blas_impl: + - mkl # [x86_64] +c_compiler: + - vs2019 # [win] +cxx_compiler: + - vs2019 # [win] +python: + - 3.5 + - 3.6 +# This differs from target_platform in that it determines what subdir the compiler +# will target, not what subdir the compiler package will be itself. +# For example, we need a win-64 vs2008_win-32 package, so that we compile win-32 +# code on win-64 miniconda. +cross_compiler_target_platform: + - win-64 # [win] +target_platform: + - win-64 # [win] +vc: + - 14 +zip_keys: + - # [win] + - vc # [win] + - c_compiler # [win] + - cxx_compiler # [win] diff --git a/packaging/vs2019/install_activate.bat b/packaging/vs2019/install_activate.bat new file mode 100644 index 0000000000000000000000000000000000000000..3c38253aa5dea3bdfc9f8cf4027e721376512154 --- /dev/null +++ b/packaging/vs2019/install_activate.bat @@ -0,0 +1,30 @@ +set YEAR=2019 +set VER=16 + +mkdir "%PREFIX%\etc\conda\activate.d" +COPY "%RECIPE_DIR%\activate.bat" "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" + +IF "%cross_compiler_target_platform%" == "win-64" ( + set "target_platform=amd64" + echo SET "CMAKE_GENERATOR=Visual Studio %VER% %YEAR% Win64" >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" + echo pushd "%%VSINSTALLDIR%%" >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" + IF "%VSDEVCMD_ARGS%" == "" ( + echo CALL "VC\Auxiliary\Build\vcvarsall.bat" x64 >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" + echo popd >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" + echo pushd "%%VSINSTALLDIR%%" >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" + echo CALL "VC\Auxiliary\Build\vcvarsall.bat" x86_amd64 >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" + ) ELSE ( + echo CALL "VC\Auxiliary\Build\vcvarsall.bat" x64 %VSDEVCMD_ARGS% >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" + echo popd >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" + echo pushd "%%VSINSTALLDIR%%" >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" + echo CALL "VC\Auxiliary\Build\vcvarsall.bat" x86_amd64 %VSDEVCMD_ARGS% >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" + ) + echo popd >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" + ) else ( + set "target_platform=x86" + echo SET "CMAKE_GENERATOR=Visual Studio %VER% %YEAR%" >> 
"%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" + echo pushd "%%VSINSTALLDIR%%" >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" + echo CALL "VC\Auxiliary\Build\vcvars32.bat" >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" + echo popd + ) + diff --git a/packaging/vs2019/install_runtime.bat b/packaging/vs2019/install_runtime.bat new file mode 100644 index 0000000000000000000000000000000000000000..e09a5ccfb0f42cc6de2a2f960d31faf2511ae094 --- /dev/null +++ b/packaging/vs2019/install_runtime.bat @@ -0,0 +1,49 @@ +set VC_PATH=x86 +if "%ARCH%"=="64" ( + set VC_PATH=x64 +) + +set MSC_VER=2019 + +rem :: This should always be present for VC installed with VS. Not sure about VC installed with Visual C++ Build Tools 2015 +rem FOR /F "usebackq tokens=3*" %%A IN (`REG QUERY "HKEY_LOCAL_MACHINE\Software\Microsoft\DevDiv\VC\Servicing\14.0\IDE.x64" /v UpdateVersion`) DO ( +rem set SP=%%A +rem ) + +rem if not "%SP%" == "%PKG_VERSION%" ( +rem echo "Version detected from registry: %SP%" +rem echo "does not match version of package being built (%PKG_VERSION%)" +rem echo "Do you have current updates for VS 2015 installed?" +rem exit 1 +rem ) + + +REM ========== REQUIRES Win 10 SDK be installed, or files otherwise copied to location below! +robocopy "C:\Program Files (x86)\Windows Kits\10\Redist\ucrt\DLLs\%VC_PATH%" "%LIBRARY_BIN%" *.dll /E +robocopy "C:\Program Files (x86)\Windows Kits\10\Redist\ucrt\DLLs\%VC_PATH%" "%PREFIX%" *.dll /E +if %ERRORLEVEL% GEQ 8 exit 1 + +REM ========== This one comes from visual studio 2019 +set "VC_VER=142" + +for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -legacy -products * -version [16^,17^) -property installationPath`) do ( + if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" ( + set "VS15VCVARSALL=%%i\VC\Auxiliary\Build\vcvarsall.bat" + goto :eof + ) +) + +@setlocal +call "%VS15VARSALL%" x64 + +set "REDIST_ROOT=%VCToolsRedistDir%%VC_PATH%" + +robocopy "%REDIST_ROOT%\Microsoft.VC%VC_VER%.CRT" "%LIBRARY_BIN%" *.dll /E +if %ERRORLEVEL% LSS 8 exit 0 +robocopy "%REDIST_ROOT%\Microsoft.VC%VC_VER%.CRT" "%PREFIX%" *.dll /E +if %ERRORLEVEL% LSS 8 exit 0 +robocopy "%REDIST_ROOT%\Microsoft.VC%VC_VER%.OpenMP" "%LIBRARY_BIN%" *.dll /E +if %ERRORLEVEL% LSS 8 exit 0 +robocopy "%REDIST_ROOT%\Microsoft.VC%VC_VER%.OpenMP" "%PREFIX%" *.dll /E +if %ERRORLEVEL% LSS 8 exit 0 +@endlocal diff --git a/packaging/vs2019/meta.yaml b/packaging/vs2019/meta.yaml new file mode 100644 index 0000000000000000000000000000000000000000..94a0ed4db3eb4bdf2dc59b9144bcdf4ade0b75d5 --- /dev/null +++ b/packaging/vs2019/meta.yaml @@ -0,0 +1,24 @@ +{% set vcver="14.2" %} +{% set vcfeature="14" %} +{% set vsyear="2019" %} +{% set fullver="15.4.27004.2010" %} + +package: + name: vs{{ vsyear }} + version: {{ fullver }} + +build: + skip: True [not win] + script_env: + - VSDEVCMD_ARGS # [win] + +outputs: + - name: vs{{ vsyear }}_{{ cross_compiler_target_platform }} + script: install_activate.bat + track_features: + # VS 2019 is binary-compatible with VS 2017/vc 14.1 and 2015/vc14. Tools are "v142". 
+ strong: + - vc{{ vcfeature }} + about: + summary: Activation and version verification of MSVC {{ vcver }} (VS {{ vsyear }}) compiler + license: BSD 3-clause diff --git a/packaging/windows/internal/cuda_install.bat b/packaging/windows/internal/cuda_install.bat new file mode 100644 index 0000000000000000000000000000000000000000..fa4b97a2b5305001123877fe7124c64b03ba72b2 --- /dev/null +++ b/packaging/windows/internal/cuda_install.bat @@ -0,0 +1,235 @@ +@echo on + +if "%CU_VERSION%" == "cpu" ( + echo Skipping for CPU builds + exit /b 0 +) + +set SRC_DIR=%~dp0\.. + +if not exist "%SRC_DIR%\temp_build" mkdir "%SRC_DIR%\temp_build" + +rem in unit test workflow, we get CUDA_VERSION, for example 11.1 +if defined CUDA_VERSION ( + set CUDA_VER=%CUDA_VERSION:.=% +) else ( + set CUDA_VER=%CU_VERSION:cu=% +) + +set /a CUDA_VER=%CU_VERSION:cu=% +set CUDA_VER_MAJOR=%CUDA_VER:~0,-1% +set CUDA_VER_MINOR=%CUDA_VER:~-1,1% +set CUDA_VERSION_STR=%CUDA_VER_MAJOR%.%CUDA_VER_MINOR% + +if %CUDA_VER% EQU 92 goto cuda92 +if %CUDA_VER% EQU 100 goto cuda100 +if %CUDA_VER% EQU 101 goto cuda101 +if %CUDA_VER% EQU 102 goto cuda102 +if %CUDA_VER% EQU 110 goto cuda110 +if %CUDA_VER% EQU 111 goto cuda111 +if %CUDA_VER% EQU 112 goto cuda112 +if %CUDA_VER% EQU 113 goto cuda113 + +echo CUDA %CUDA_VERSION_STR% is not supported +exit /b 1 + +:cuda92 +if not exist "%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cuda_9.2.148_win10.exe --output "%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" + if errorlevel 1 exit /b 1 + set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" + set "ARGS=nvcc_9.2 cuobjdump_9.2 nvprune_9.2 cupti_9.2 cublas_9.2 cublas_dev_9.2 cudart_9.2 cufft_9.2 cufft_dev_9.2 curand_9.2 curand_dev_9.2 cusolver_9.2 cusolver_dev_9.2 cusparse_9.2 cusparse_dev_9.2 nvgraph_9.2 nvgraph_dev_9.2 npp_9.2 npp_dev_9.2 nvrtc_9.2 nvrtc_dev_9.2 nvml_dev_9.2" +) + +if not exist "%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cudnn-9.2-windows10-x64-v7.2.1.38.zip --output "%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" + if errorlevel 1 exit /b 1 + set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" +) + +goto cuda_common + +:cuda100 + +if not exist "%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cuda_10.0.130_411.31_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" + if errorlevel 1 exit /b 1 + set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" + set "ARGS=nvcc_10.0 cuobjdump_10.0 nvprune_10.0 cupti_10.0 cublas_10.0 cublas_dev_10.0 cudart_10.0 cufft_10.0 cufft_dev_10.0 curand_10.0 curand_dev_10.0 cusolver_10.0 cusolver_dev_10.0 cusparse_10.0 cusparse_dev_10.0 nvgraph_10.0 nvgraph_dev_10.0 npp_10.0 npp_dev_10.0 nvrtc_10.0 nvrtc_dev_10.0 nvml_dev_10.0" +) + +if not exist "%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cudnn-10.0-windows10-x64-v7.4.1.5.zip --output "%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" + if errorlevel 1 exit /b 1 + set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" +) + +goto cuda_common + +:cuda101 + +if not exist "%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_10.1.243_426.00_win10.exe --output 
"%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" + if errorlevel 1 exit /b 1 + set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" + set "ARGS=nvcc_10.1 cuobjdump_10.1 nvprune_10.1 cupti_10.1 cublas_10.1 cublas_dev_10.1 cudart_10.1 cufft_10.1 cufft_dev_10.1 curand_10.1 curand_dev_10.1 cusolver_10.1 cusolver_dev_10.1 cusparse_10.1 cusparse_dev_10.1 nvgraph_10.1 nvgraph_dev_10.1 npp_10.1 npp_dev_10.1 nvrtc_10.1 nvrtc_dev_10.1 nvml_dev_10.1" +) + +if not exist "%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.1-windows10-x64-v7.6.4.38.zip --output "%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" + if errorlevel 1 exit /b 1 + set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" +) + +goto cuda_common + +:cuda102 + +if not exist "%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_10.2.89_441.22_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" + if errorlevel 1 exit /b 1 + set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" + set "ARGS=nvcc_10.2 cuobjdump_10.2 nvprune_10.2 cupti_10.2 cublas_10.2 cublas_dev_10.2 cudart_10.2 cufft_10.2 cufft_dev_10.2 curand_10.2 curand_dev_10.2 cusolver_10.2 cusolver_dev_10.2 cusparse_10.2 cusparse_dev_10.2 nvgraph_10.2 nvgraph_dev_10.2 npp_10.2 npp_dev_10.2 nvrtc_10.2 nvrtc_dev_10.2 nvml_dev_10.2" +) + +if not exist "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.2-windows10-x64-v7.6.5.32.zip --output "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" + if errorlevel 1 exit /b 1 + set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" +) + +if not exist "%SRC_DIR%\temp_build\gpu_driver_dlls.7z" ( + curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "%SRC_DIR%\temp_build\gpu_driver_dlls.zip" + if errorlevel 1 exit /b 1 +) + +echo Installing GPU driver DLLs +7z x %SRC_DIR%\temp_build\gpu_driver_dlls.zip -aoa -o"C:\Windows\System32" + +goto cuda_common + +:cuda110 + +if not exist "%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.0.2_451.48_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" + if errorlevel 1 exit /b 1 + set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" + set "ARGS=nvcc_11.0 cuobjdump_11.0 nvprune_11.0 nvprof_11.0 cupti_11.0 cublas_11.0 cublas_dev_11.0 cudart_11.0 cufft_11.0 cufft_dev_11.0 curand_11.0 curand_dev_11.0 cusolver_11.0 cusolver_dev_11.0 cusparse_11.0 cusparse_dev_11.0 npp_11.0 npp_dev_11.0 nvrtc_11.0 nvrtc_dev_11.0 nvml_dev_11.0" +) + +if not exist "%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-11.0-windows-x64-v8.0.4.30.zip --output "%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" + if errorlevel 1 exit /b 1 + set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" +) + +goto cuda_common + +:cuda111 + +if not exist "%SRC_DIR%\temp_build\cuda_11.1.0_456.43_win10.exe" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.1.0_456.43_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.1.0_456.43_win10.exe" + if errorlevel 1 exit /b 1 + set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.1.0_456.43_win10.exe" + set 
"ARGS=nvcc_11.1 cuobjdump_11.1 nvprune_11.1 nvprof_11.1 cupti_11.1 cublas_11.1 cublas_dev_11.1 cudart_11.1 cufft_11.1 cufft_dev_11.1 curand_11.1 curand_dev_11.1 cusolver_11.1 cusolver_dev_11.1 cusparse_11.1 cusparse_dev_11.1 npp_11.1 npp_dev_11.1 nvrtc_11.1 nvrtc_dev_11.1 nvml_dev_11.1" +) + +@REM There is no downloadable driver for Tesla on CUDA 11.1 yet. We will use +@REM the driver inside CUDA +if "%JOB_EXECUTOR%" == "windows-with-nvidia-gpu" set "ARGS=%ARGS% Display.Driver" + +if not exist "%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-11.1-windows-x64-v8.0.5.39.zip --output "%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" + if errorlevel 1 exit /b 1 + set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" +) + +goto cuda_common + +:cuda112 + +if not exist "%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" ( + curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.2.0_460.89_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" + if errorlevel 1 exit /b 1 + set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" + set "ARGS=nvcc_11.2 cuobjdump_11.2 nvprune_11.2 nvprof_11.2 cupti_11.2 cublas_11.2 cublas_dev_11.2 cudart_11.2 cufft_11.2 cufft_dev_11.2 curand_11.2 curand_dev_11.2 cusolver_11.2 cusolver_dev_11.2 cusparse_11.2 cusparse_dev_11.2 npp_11.2 npp_dev_11.2 nvrtc_11.2 nvrtc_dev_11.2 nvml_dev_11.2" +) + +if not exist "%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" ( + curl -k -L http://s3.amazonaws.com/ossci-windows/cudnn-11.2-windows-x64-v8.1.0.77.zip --output "%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" + if errorlevel 1 exit /b 1 + set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" +) + +goto cuda_common + +:cuda113 + +set CUDA_INSTALL_EXE=cuda_11.3.0_465.89_win10.exe +if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" ( + curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" + if errorlevel 1 exit /b 1 + set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" + set "ARGS=thrust_11.3 nvcc_11.3 cuobjdump_11.3 nvprune_11.3 nvprof_11.3 cupti_11.3 cublas_11.3 cublas_dev_11.3 cudart_11.3 cufft_11.3 cufft_dev_11.3 curand_11.3 curand_dev_11.3 cusolver_11.3 cusolver_dev_11.3 cusparse_11.3 cusparse_dev_11.3 npp_11.3 npp_dev_11.3 nvrtc_11.3 nvrtc_dev_11.3 nvml_dev_11.3" + +) + +set CUDNN_INSTALL_ZIP=cudnn-11.3-windows-x64-v8.2.0.53.zip +if not exist "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" ( + curl -k -L "http://s3.amazonaws.com/ossci-windows/%CUDNN_INSTALL_ZIP%" --output "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" + if errorlevel 1 exit /b 1 + set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" +) + +goto cuda_common + +:cuda_common + +if not exist "%SRC_DIR%\temp_build\NvToolsExt.7z" ( + curl -k -L https://www.dropbox.com/s/9mcolalfdj4n979/NvToolsExt.7z?dl=1 --output "%SRC_DIR%\temp_build\NvToolsExt.7z" + if errorlevel 1 exit /b 1 +) + +echo Installing CUDA toolkit... +7z x %CUDA_SETUP_FILE% -o"%SRC_DIR%\temp_build\cuda" +pushd "%SRC_DIR%\temp_build\cuda" +start /wait setup.exe -s %ARGS% +popd + +echo Installing VS integration... 
+rem It's for VS 2019 +if "%CUDA_VER_MAJOR%" == "10" ( + xcopy /Y "%SRC_DIR%\temp_build\cuda\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions\*.*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations" +) +if "%CUDA_VER_MAJOR%" == "11" ( + xcopy /Y "%SRC_DIR%\temp_build\cuda\visual_studio_integration\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions\*.*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations" +) + +echo Installing NvToolsExt... +7z x %SRC_DIR%\temp_build\NvToolsExt.7z -o"%SRC_DIR%\temp_build\NvToolsExt" +mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" +mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\include" +mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\lib\x64" +xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\bin\x64\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" +xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\include\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\include" +xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\lib\x64\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\lib\x64" + +echo Setting up environment... +set "PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin;%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\libnvvp;%PATH%" +set "CUDA_PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%" +set "CUDA_PATH_V%CUDA_VER_MAJOR%_%CUDA_VER_MINOR%=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%" +set "NVTOOLSEXT_PATH=%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" + +if not exist "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin\nvcc.exe" ( + echo CUDA %CUDA_VERSION_STR% installed failed. + exit /b 1 +) + +echo Installing cuDNN... +7z x %CUDNN_SETUP_FILE% -o"%SRC_DIR%\temp_build\cudnn" +xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\bin\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin" +xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\lib\x64\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\lib\x64" +xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\include\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\include" + +echo Cleaning temp files +rd /s /q "%SRC_DIR%\temp_build" || ver > nul diff --git a/packaging/windows/internal/driver_update.bat b/packaging/windows/internal/driver_update.bat new file mode 100644 index 0000000000000000000000000000000000000000..00b43affc01cc302a3d6c527be197f1adcc0ba2f --- /dev/null +++ b/packaging/windows/internal/driver_update.bat @@ -0,0 +1,25 @@ +set "DRIVER_DOWNLOAD_LINK=https://ossci-windows.s3.amazonaws.com/461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe" +curl --retry 3 -kL %DRIVER_DOWNLOAD_LINK% --output 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe +if errorlevel 1 exit /b 1 + +start /wait 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe -s -noreboot +if errorlevel 1 exit /b 1 + +del 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe || ver > NUL + +setlocal EnableDelayedExpansion +set NVIDIA_GPU_EXISTS=0 +for /F "delims=" %%i in ('wmic path win32_VideoController get name') do ( + set GPUS=%%i + if not "x!GPUS:NVIDIA=!" == "x!GPUS!" 
( + SET NVIDIA_GPU_EXISTS=1 + goto gpu_check_end + ) +) +:gpu_check_end +endlocal & set NVIDIA_GPU_EXISTS=%NVIDIA_GPU_EXISTS% + +if "%NVIDIA_GPU_EXISTS%" == "0" ( + echo "CUDA Driver installation Failed" + exit /b 1 +) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..316bd5c7fd404e1f5ae55bf9c33a95a40c9817e7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,20 @@ +torch>=1.4.0 + +# Required for Windows because it's the only available backend +SoundFile; sys_platform == 'win32' + +# Optional for torchaudio.kaldi_io +numpy +kaldi_io + +# Required for tests only: + +# Style-checking for PEP8 +flake8 + +# Used for comparison of outputs in tests +librosa>=0.4.3 +scipy + +# Unit tests with pytest +pytest diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..182c9f7e3f89f5bb58c03504780e961e0ca3aded --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[pydocstyle] +select = D417 # Missing argument descriptions in the docstring diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..c09961cb0fbc39d4bb5b69484343c862334dda5e --- /dev/null +++ b/setup.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python +import os +import re +import shutil +import subprocess +from pathlib import Path +from setuptools import setup, find_packages +import distutils.command.clean + +from build_tools import setup_helpers + +ROOT_DIR = Path(__file__).parent.resolve() + + +def _run_cmd(cmd, default): + try: + return subprocess.check_output(cmd, cwd=ROOT_DIR).decode('ascii').strip() + except Exception: + return default + + +# Creating the version file +version = '0.10.0' +sha = _run_cmd(['git', 'rev-parse', 'HEAD'], default='Unknown') + +if os.getenv('BUILD_VERSION'): + version = os.getenv('BUILD_VERSION') +elif sha != 'Unknown': + version += '+' + sha[:7] +print('-- Building version ' + version) + +version_path = ROOT_DIR / 'torchaudio' / 'version.py' +with open(version_path, 'w') as f: + f.write("__version__ = '{}'\n".format(version)) + f.write("git_version = {}\n".format(repr(sha))) + +pytorch_package_version = os.getenv('PYTORCH_VERSION') + +pytorch_package_dep = 'torch' +if pytorch_package_version is not None: + pytorch_package_dep += "==" + pytorch_package_version + + +class clean(distutils.command.clean.clean): + def run(self): + # Run default behavior first + distutils.command.clean.clean.run(self) + + # Remove torchaudio extension + for path in (ROOT_DIR / 'torchaudio').glob('**/*.so'): + print(f'removing \'{path}\'') + path.unlink() + # Remove build directory + build_dirs = [ + ROOT_DIR / 'build', + ] + for path in build_dirs: + if path.exists(): + print(f'removing \'{path}\' (and everything under it)') + shutil.rmtree(str(path), ignore_errors=True) + + +def _get_packages(): + exclude = [ + "build*", + "test*", + "torchaudio.csrc*", + "third_party*", + "build_tools*", + ] + exclude_prototype = False + branch_name = _run_cmd(['git', 'rev-parse', '--abbrev-ref', 'HEAD'], default=None) + is_on_tag = _run_cmd(['git', 'describe', '--tags', '--exact-match', '@'], default=None) + + if branch_name is not None and branch_name.startswith('release/'): + print('On release branch') + exclude_prototype = True + if is_on_tag is not None and re.match(r'v[\d.]+(-rc\d+)?', is_on_tag): + print('On release tag') + exclude_prototype = True + if exclude_prototype: + print('Excluding torchaudio.prototype from the package.') + exclude.append("torchaudio.prototype") + return 
find_packages(exclude=exclude)
+
+
+setup(
+    name="torchaudio",
+    version=version,
+    description="An audio package for PyTorch",
+    url="https://github.com/pytorch/audio",
+    author="Soumith Chintala, David Pollack, Sean Naren, Peter Goldsborough",
+    author_email="soumith@pytorch.org",
+    classifiers=[
+        "Environment :: Plugins",
+        "Intended Audience :: Developers",
+        "Intended Audience :: Science/Research",
+        "License :: OSI Approved :: BSD License",
+        "Operating System :: MacOS :: MacOS X",
+        "Operating System :: Microsoft :: Windows",
+        "Operating System :: POSIX",
+        "Programming Language :: C++",
+        "Programming Language :: Python :: 3.6",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: Implementation :: CPython",
+        "Topic :: Multimedia :: Sound/Audio",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence"
+    ],
+    packages=_get_packages(),
+    ext_modules=setup_helpers.get_ext_modules(),
+    cmdclass={
+        'build_ext': setup_helpers.CMakeBuild,
+        'clean': clean,
+    },
+    install_requires=[pytorch_package_dep],
+    zip_safe=False,
+)
diff --git a/test/integration_tests/__init__.py b/test/integration_tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/test/integration_tests/conftest.py b/test/integration_tests/conftest.py
new file mode 100644
index 0000000000000000000000000000000000000000..66adda5a5d3484e6e5f74c8d12462e7718cc2c12
--- /dev/null
+++ b/test/integration_tests/conftest.py
@@ -0,0 +1,37 @@
+import torch
+from torchaudio_unittest.common_utils import get_asset_path
+import pytest
+
+
+class GreedyCTCDecoder(torch.nn.Module):
+    def __init__(self, labels):
+        super().__init__()
+        self.labels = labels
+
+    def forward(self, logits: torch.Tensor) -> str:
+        """Given a sequence of logits over labels, get the best path string
+
+        Args:
+            logits (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
+
+        Returns:
+            str: The resulting transcript
+        """
+        best_path = torch.argmax(logits, dim=-1)  # [num_seq,]
+        best_path = torch.unique_consecutive(best_path, dim=-1)
+        hypothesis = []
+        for i in best_path:
+            char = self.labels[i]
+            if char not in ['<s>', '<pad>']:
+                hypothesis.append(char)
+        return ''.join(hypothesis)
+
+
+@pytest.fixture
+def ctc_decoder():
+    return GreedyCTCDecoder
+
+
+@pytest.fixture
+def sample_speech_16000_en():
+    return get_asset_path('Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac')
diff --git a/test/integration_tests/tacotron2_pipeline_test.py b/test/integration_tests/tacotron2_pipeline_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b81b1847d6f234950bbb2f46728c8ea62594812
--- /dev/null
+++ b/test/integration_tests/tacotron2_pipeline_test.py
@@ -0,0 +1,28 @@
+from torchaudio.pipelines import (
+    TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
+    TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
+    TACOTRON2_WAVERNN_CHAR_LJSPEECH,
+    TACOTRON2_WAVERNN_PHONE_LJSPEECH,
+)
+import pytest
+
+
+@pytest.mark.parametrize(
+    'bundle',
+    [
+        TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
+        TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
+        TACOTRON2_WAVERNN_CHAR_LJSPEECH,
+        TACOTRON2_WAVERNN_PHONE_LJSPEECH,
+    ]
+)
+def test_tts_models(bundle):
+    """Smoke test of TTS pipeline"""
+    text = "Hello world! Text to Speech!"
+ + processor = bundle.get_text_processor() + tacotron2 = bundle.get_tacotron2() + vocoder = bundle.get_vocoder() + processed, lengths = processor(text) + mel_spec, lengths, _ = tacotron2.infer(processed, lengths) + waveforms, lengths = vocoder(mel_spec, lengths) diff --git a/test/integration_tests/wav2vec2_pipeline_test.py b/test/integration_tests/wav2vec2_pipeline_test.py new file mode 100644 index 0000000000000000000000000000000000000000..012f960ac46d6dcfde7d131a9c6a61a3953210be --- /dev/null +++ b/test/integration_tests/wav2vec2_pipeline_test.py @@ -0,0 +1,70 @@ +import torchaudio +from torchaudio.pipelines import ( + WAV2VEC2_BASE, + WAV2VEC2_LARGE, + WAV2VEC2_LARGE_LV60K, + WAV2VEC2_ASR_BASE_10M, + WAV2VEC2_ASR_BASE_100H, + WAV2VEC2_ASR_BASE_960H, + WAV2VEC2_ASR_LARGE_10M, + WAV2VEC2_ASR_LARGE_100H, + WAV2VEC2_ASR_LARGE_960H, + WAV2VEC2_ASR_LARGE_LV60K_10M, + WAV2VEC2_ASR_LARGE_LV60K_100H, + WAV2VEC2_ASR_LARGE_LV60K_960H, + WAV2VEC2_XLSR53, + HUBERT_BASE, + HUBERT_LARGE, + HUBERT_XLARGE, + HUBERT_ASR_LARGE, + HUBERT_ASR_XLARGE, +) +import pytest + + +@pytest.mark.parametrize( + "bundle", + [ + WAV2VEC2_BASE, + WAV2VEC2_LARGE, + WAV2VEC2_LARGE_LV60K, + WAV2VEC2_XLSR53, + HUBERT_BASE, + HUBERT_LARGE, + HUBERT_XLARGE, + ] +) +def test_pretraining_models(bundle): + """Smoke test of downloading weights for pretraining models""" + bundle.get_model() + + +@pytest.mark.parametrize( + "bundle,expected", + [ + (WAV2VEC2_ASR_BASE_10M, 'I|HAD|THAT|CURIYOSSITY|BESID|ME|AT|THIS|MOMENT|'), + (WAV2VEC2_ASR_BASE_100H, 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), + (WAV2VEC2_ASR_BASE_960H, 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), + (WAV2VEC2_ASR_LARGE_10M, 'I|HAD|THAT|CURIOUSITY|BESIDE|ME|AT|THIS|MOMENT|'), + (WAV2VEC2_ASR_LARGE_100H, 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), + (WAV2VEC2_ASR_LARGE_960H, 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), + (WAV2VEC2_ASR_LARGE_LV60K_10M, 'I|HAD|THAT|CURIOUSSITY|BESID|ME|AT|THISS|MOMENT|'), + (WAV2VEC2_ASR_LARGE_LV60K_100H, 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), + (WAV2VEC2_ASR_LARGE_LV60K_960H, 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), + (HUBERT_ASR_LARGE, 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|'), + (HUBERT_ASR_XLARGE, 'I|HAVE|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|') + ] +) +def test_finetune_asr_model( + bundle, + expected, + sample_speech_16000_en, + ctc_decoder, +): + """Smoke test of downloading weights for fine-tuning models and simple transcription""" + model = bundle.get_model().eval() + waveform, sample_rate = torchaudio.load(sample_speech_16000_en) + emission, _ = model(waveform) + decoder = ctc_decoder(bundle.get_labels()) + result = decoder(emission[0]) + assert result == expected diff --git a/test/torchaudio_unittest/README.md b/test/torchaudio_unittest/README.md new file mode 100644 index 0000000000000000000000000000000000000000..dd6249fb4abdf94857dbec0a9d3ec14ee4f206e4 --- /dev/null +++ b/test/torchaudio_unittest/README.md @@ -0,0 +1,148 @@ +# Torchaudio Unit Test Suite + +## How to run test + +You can use `pytest` to run `torchaudio`'s test suites. See +https://docs.pytest.org/ for the detail of how to use `pytest` command. + +For testing, please refer to [contributing guide](../../CONTRIBUTING.md) for +the installation of the required and optional packages. 
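+
+If you want to quickly confirm that the optional packages are importable before running the suite, a small check such as the following can help. This is only a sketch; the package list is an assumption based on `requirements.txt` and the `parameterized` package mentioned below, and may not match your environment.
+
+```python
+# Sketch: report which optional test dependencies are importable.
+# The package names below are assumptions; adjust them to your setup.
+import importlib.util
+
+for name in ("librosa", "scipy", "kaldi_io", "soundfile", "parameterized"):
+    found = importlib.util.find_spec(name) is not None
+    print(f"{name}: {'available' if found else 'missing'}")
+```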
+
+For running `kaldi`-related tests:
+
+```bash
+export PATH="${PATH}:<path-to-kaldi>/src/featbin/"
+```
+
+Some useful pytest commands:
+
+```bash
+# List all the tests
+pytest test --collect-only
+# Run all the test suites
+pytest test
+# Run tests on the sox_effects module
+pytest test/torchaudio_unittest/sox_effect
+# Use -k to apply a filter
+pytest test/torchaudio_unittest/sox_io_backend -k load  # only runs tests whose names contain "load"
+# Some other useful options:
+#  Stop on the first failure: -x
+#  Run the previous failures first: --ff
+#  Only rerun the previous failures: --lf
+```
+
+**Note**
+We use PyTorch's test utilities instead of the `pytest` framework when writing tests to avoid reinventing the wheel for Tensor comparison.
+Also, while we recommend using `pytest` for *running* the tests, we cannot
+make `pytest` a testing dependency of `torchaudio`. As a result, you should
+not import `pytest` or its submodules in the test files; use the Python
+`unittest` builtin module instead, or the `parameterized` package to
+parametrize tests.
+
+## Structure of tests
+
+The following is an overview of the tests and related modules for `torchaudio`.
+
+### Purpose specific test suites
+
+#### Numerical compatibility against existing software
+- [Librosa compatibility test](./transforms/librosa_compatibility_test.py)
+  Test suite for numerical compatibility against librosa.
+- [SoX compatibility test](./transforms/sox_compatibility_test.py)
+  Test suite for numerical compatibility against SoX.
+- [Kaldi compatibility test](./transforms/kaldi_compatibility_impl.py)
+  Test suite for numerical compatibility against Kaldi.
+
+#### Result consistency with PyTorch framework
+- [TorchScript consistency test](./transforms/torchscript_consistency_impl.py)
+  Test suite to check 1. whether an API is TorchScript-able, and 2. whether the results from Python and TorchScript match.
+- [Batch consistency test](./transforms/batch_consistency_test.py)
+  Test suite to check that functionals/transforms return the same results for single-sample input and batched input.
+
+### Module specific test suites
+
+The following test modules are defined for the corresponding `torchaudio` modules/functions.
+
+- [`torchaudio.datasets`](./datasets)
+- [`torchaudio.functional`](./functional)
+- [`torchaudio.transforms`](./transforms/transforms_test.py)
+- [`torchaudio.compliance.kaldi`](./compliance_kaldi_test.py)
+- [`torchaudio.kaldi_io`](./kaldi_io_test.py)
+- [`torchaudio.sox_effects`](./sox_effect)
+- [`torchaudio.backend`](./backend)
+
+### Test modules that do not fall into the above categories
+- [test_dataloader.py](./dataloader_test.py)
+  Simple test for loading data and applying preprocessing.
+
+### Support files
+- [assets](./assets): Contains sample audio files.
+- [assets/kaldi](./assets/kaldi): Contains Kaldi-format matrix files used in [./test_compliance_kaldi.py](./test_compliance_kaldi.py).
+- [compliance](./compliance): Scripts used to generate the above Kaldi matrix files.
+
+### Waveforms for Testing Purposes
+
+When testing transforms, we often need waveforms of a specific type (e.g., pure tone, noise, or voice), with a specific sample rate (e.g., 8 or 16 kHz) and number of channels (e.g., mono, stereo). Below are some tips on how to construct waveforms and guidance on existing audio files.
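+
+As a concrete illustration of the conventions above, a small test module might combine a synthetic waveform with the `parameterized` package and the `common_utils.TorchaudioTestCase` base class described under "Adding test" below. This is a hypothetical sketch, not an existing test module; the only helpers it assumes are the `common_utils` ones documented in this README.
+
+```python
+# Sketch of a test module written according to this README's conventions.
+# The test class and its assertion are hypothetical; the common_utils helpers
+# (TorchaudioTestCase, get_sinusoid) are the ones documented in this README.
+from parameterized import parameterized
+
+import torchaudio
+from torchaudio_unittest.common_utils import TorchaudioTestCase, get_sinusoid
+
+
+class SpectrogramSmokeTest(TorchaudioTestCase):
+    backend = "default"  # let an available I/O backend be picked automatically
+
+    @parameterized.expand([(400,), (1024,)])
+    def test_output_shape(self, n_fft):
+        # Synthetic input keeps the test independent of asset files.
+        waveform = get_sinusoid(frequency=300, sample_rate=16000, duration=1, n_channels=1)
+        spec = torchaudio.transforms.Spectrogram(n_fft=n_fft)(waveform)
+        # A spectrogram computed with n_fft bins has n_fft // 2 + 1 frequency rows.
+        self.assertEqual(spec.shape[-2], n_fft // 2 + 1)
+```
+
+The subsections below cover the helpers (`get_asset_path`/`load_wav`, `get_sinusoid`, `get_whitenoise`) that such tests typically rely on.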
+
+#### Load a Waveform from a File
+
+```python
+filepath = common_utils.get_asset_path('filename.wav')
+waveform, sample_rate = common_utils.load_wav(filepath)
+```
+
+*Note: Should you choose to contribute an audio file, please leave a comment in the issue or pull request, mentioning the content source and licensing information. WAV files are preferred. Other formats should be used only when there is no alternative (e.g., a dataset implementation comes with a hardcoded non-wav extension).*
+
+#### Pure Tone
+
+Code:
+
+```python
+waveform = common_utils.get_sinusoid(
+    frequency=300,
+    sample_rate=16000,
+    duration=1,  # seconds
+    n_channels=1,
+    dtype="float32",
+    device="cpu",
+)
+```
+
+#### Noise
+
+Code:
+
+```python
+tensor = common_utils.get_whitenoise()
+```
+
+Files:
+
+* `steam-train-whistle-daniel_simon.wav`
+
+#### Voice
+
+Files:
+
+* `CommonVoice/cv-corpus-4-2019-12-10/tt/clips/common_voice_tt_00000000.wav`
+* `VCTK-Corpus/wav48/p224/p224_002.wav`
+* `vad-go-stereo-44100.wav`
+* `vad-go-mono-32000.wav`
+
+## Adding test
+
+The following is the current practice of the torchaudio test suite.
+
+1. Unless the tests are related to I/O, use synthetic data. [`common_utils`](./common_utils) has some data generator functions.
+1. When you add a new test case, use `common_utils.TorchaudioTestCase` as the base class unless you are writing tests that are common to CPU / CUDA.
+   - Set the class members `dtype`, `device` and `backend` for the desired behavior.
+   - If you do not set the `backend` value in your test suite, then the I/O functions will be unassigned and attempts to load/save files will fail.
+   - For the `backend` value, in addition to the available backends, you can also provide the value "default", and a backend will be picked automatically based on availability.
+1. If you are writing tests that should pass on different dtypes/devices, write a common class inheriting `common_utils.TestBaseMixin`, then inherit `common_utils.PytorchTestCase` and define the class attributes (`dtype` / `device` / `backend`) there. See the [TorchScript consistency test implementation](./transforms/torchscript_consistency_impl.py) and the test definitions for [CPU](./transforms/torchscript_consistency_cpu_test.py) and [CUDA](./transforms/torchscript_consistency_cuda_test.py) devices.
+1. For numerically comparing Tensors, use the `assertEqual` method from the `torchaudio_unittest.common_utils.PytorchTestCase` class. This method has better support for a wide variety of Tensor types.
+
+When you add a new feature (functional/transform), consider the following:
+
+1. When you add a new feature, please make it TorchScript-able and batch-consistent unless doing so degrades the performance. Please add tests to check that the new feature meets these requirements.
+1. If the feature should be numerically compatible with existing software (SoX, Librosa, Kaldi, etc.), add a corresponding test.
+1. If the new feature is unique to `torchaudio` (not a PyTorch implementation of existing software functionality), consider adding correctness tests (whether the expected output is produced for a given set of inputs) under the corresponding test module (`test_functional.py`, `test_transforms.py`).
diff --git a/test/torchaudio_unittest/__init__.py b/test/torchaudio_unittest/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..22ac65a3ee70db89a4d3f859dc360ecdabf087ad
--- /dev/null
+++ b/test/torchaudio_unittest/__init__.py
@@ -0,0 +1,4 @@
+try:
+    from . 
import fb # noqa +except Exception: + pass diff --git a/test/torchaudio_unittest/assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac b/test/torchaudio_unittest/assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac new file mode 100644 index 0000000000000000000000000000000000000000..8ef93ecc90a213e011464b63538a1d3d9ff28c33 Binary files /dev/null and b/test/torchaudio_unittest/assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac differ diff --git a/test/torchaudio_unittest/assets/VCTK-Corpus/txt/p224/p224_002.txt b/test/torchaudio_unittest/assets/VCTK-Corpus/txt/p224/p224_002.txt new file mode 100644 index 0000000000000000000000000000000000000000..a3d5413ceba6ca081231e0d1ac348f5243a2b71c --- /dev/null +++ b/test/torchaudio_unittest/assets/VCTK-Corpus/txt/p224/p224_002.txt @@ -0,0 +1 @@ +VCTK Test. diff --git a/test/torchaudio_unittest/assets/VCTK-Corpus/wav48/p224/p224_002.wav b/test/torchaudio_unittest/assets/VCTK-Corpus/wav48/p224/p224_002.wav new file mode 100644 index 0000000000000000000000000000000000000000..66ee46737e6fb4a50e41780fea3c0b01b45830b6 Binary files /dev/null and b/test/torchaudio_unittest/assets/VCTK-Corpus/wav48/p224/p224_002.wav differ diff --git a/test/torchaudio_unittest/assets/io/96k_0_1ch.opus b/test/torchaudio_unittest/assets/io/96k_0_1ch.opus new file mode 100644 index 0000000000000000000000000000000000000000..df95474ddb67499f9223537f0081c689b0be654a Binary files /dev/null and b/test/torchaudio_unittest/assets/io/96k_0_1ch.opus differ diff --git a/test/torchaudio_unittest/assets/io/96k_0_2ch.opus b/test/torchaudio_unittest/assets/io/96k_0_2ch.opus new file mode 100644 index 0000000000000000000000000000000000000000..b8837e81e26b579d984f2b871523322d4e8b5399 Binary files /dev/null and b/test/torchaudio_unittest/assets/io/96k_0_2ch.opus differ diff --git a/test/torchaudio_unittest/assets/io/96k_10_1ch.opus b/test/torchaudio_unittest/assets/io/96k_10_1ch.opus new file mode 100644 index 0000000000000000000000000000000000000000..56b170d380376604f308a28cb2ca92ed1f1481cc Binary files /dev/null and b/test/torchaudio_unittest/assets/io/96k_10_1ch.opus differ diff --git a/test/torchaudio_unittest/assets/io/96k_10_2ch.opus b/test/torchaudio_unittest/assets/io/96k_10_2ch.opus new file mode 100644 index 0000000000000000000000000000000000000000..e2b147fc7f94f3a2ab10f6fc8c71af8a3b521b1c Binary files /dev/null and b/test/torchaudio_unittest/assets/io/96k_10_2ch.opus differ diff --git a/test/torchaudio_unittest/assets/io/96k_5_1ch.opus b/test/torchaudio_unittest/assets/io/96k_5_1ch.opus new file mode 100644 index 0000000000000000000000000000000000000000..a1f5214d3ab1c1fd301fee99f13da6f923bff0f7 Binary files /dev/null and b/test/torchaudio_unittest/assets/io/96k_5_1ch.opus differ diff --git a/test/torchaudio_unittest/assets/io/96k_5_2ch.opus b/test/torchaudio_unittest/assets/io/96k_5_2ch.opus new file mode 100644 index 0000000000000000000000000000000000000000..007bc813cea92c6555cde19b2ff0f8e779a59614 Binary files /dev/null and b/test/torchaudio_unittest/assets/io/96k_5_2ch.opus differ diff --git a/test/torchaudio_unittest/assets/io/generate_opus.py b/test/torchaudio_unittest/assets/io/generate_opus.py new file mode 100644 index 0000000000000000000000000000000000000000..e6b99c471c54fb580628f2b2e2730aa4855c369b --- /dev/null +++ b/test/torchaudio_unittest/assets/io/generate_opus.py @@ -0,0 +1,50 @@ +"""Generate opus file for testing load functions""" + +import argparse +import subprocess + +import scipy.io.wavfile +import torch + + +def _parse_args(): + parser = 
argparse.ArgumentParser( + description='Generate opus files for test' + ) + parser.add_argument('--num-channels', required=True, type=int) + parser.add_argument('--compression-level', required=True, type=int, choices=list(range(11))) + parser.add_argument('--bitrate', default='96k') + return parser.parse_args() + + +def convert_to_opus( + src_path, dst_path, + *, bitrate, compression_level): + """Convert audio file with `ffmpeg` command.""" + command = ['ffmpeg', '-y', '-i', src_path, '-c:a', 'libopus', '-b:a', bitrate] + if compression_level is not None: + command += ['-compression_level', str(compression_level)] + command += [dst_path] + print(' '.join(command)) + subprocess.run(command, check=True) + + +def _generate(num_channels, compression_level, bitrate): + org_path = 'original.wav' + ops_path = f'{bitrate}_{compression_level}_{num_channels}ch.opus' + + # Note: ffmpeg forces sample rate 48k Hz for opus https://stackoverflow.com/a/39186779 + # 1. generate original wav + data = torch.linspace(-32768, 32767, 32768, dtype=torch.int16).repeat([num_channels, 1]).t() + scipy.io.wavfile.write(org_path, 48000, data.numpy()) + # 2. convert to opus + convert_to_opus(org_path, ops_path, bitrate=bitrate, compression_level=compression_level) + + +def _main(): + args = _parse_args() + _generate(args.num_channels, args.compression_level, args.bitrate) + + +if __name__ == '__main__': + _main() diff --git a/test/torchaudio_unittest/assets/kaldi_file.wav b/test/torchaudio_unittest/assets/kaldi_file.wav new file mode 100644 index 0000000000000000000000000000000000000000..66ee46737e6fb4a50e41780fea3c0b01b45830b6 Binary files /dev/null and b/test/torchaudio_unittest/assets/kaldi_file.wav differ diff --git a/test/torchaudio_unittest/assets/kaldi_file_8000.wav b/test/torchaudio_unittest/assets/kaldi_file_8000.wav new file mode 100644 index 0000000000000000000000000000000000000000..01a5755dbadf129f6e8af650174a4b765c8d5234 Binary files /dev/null and b/test/torchaudio_unittest/assets/kaldi_file_8000.wav differ diff --git a/test/torchaudio_unittest/assets/kaldi_test_fbank_args.jsonl b/test/torchaudio_unittest/assets/kaldi_test_fbank_args.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..91d1b57f71574e0bad68e668540ed2d55809ade1 --- /dev/null +++ b/test/torchaudio_unittest/assets/kaldi_test_fbank_args.jsonl @@ -0,0 +1,88 @@ +{"blackman_coeff": 0.0939, "energy_floor": 4.5062, "frame_length": 1.0625, "frame_shift": 0.6875, "high_freq": 1841, "htk_compat": true, "low_freq": 479, "num_mel_bins": 5, "preemphasis_coefficient": 0.84, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "use_log_fbank": true, "use_power": false, "vtln_high": 1832, "vtln_low": 1824, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 0.166, "energy_floor": 1.7875, "frame_length": 1.125, "frame_shift": 0.5, "high_freq": 4999, "htk_compat": true, "low_freq": 1740, "num_mel_bins": 6, "preemphasis_coefficient": 0.29, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "use_log_fbank": false, "use_power": true, "vtln_high": 4587, "vtln_low": 2289, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 0.2215, "energy_floor": 1.3444, "frame_length": 1.125, "frame_shift": 0.75, "high_freq": 7468, "htk_compat": true, "low_freq": 87, "num_mel_bins": 5, "preemphasis_coefficient": 0.17, "raw_energy": 
true, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "use_energy": true, "use_log_fbank": false, "use_power": false, "vtln_high": 1700, "vtln_low": 870, "vtln_warp": 0.3104, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 0.2512, "energy_floor": 0.2607, "frame_length": 0.875, "frame_shift": 0.875, "high_freq": 7380, "htk_compat": true, "low_freq": 4471, "num_mel_bins": 5, "preemphasis_coefficient": 0.76, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "use_log_fbank": false, "use_power": true, "vtln_high": 7138, "vtln_low": 5172, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 0.2834, "energy_floor": 1.7885, "frame_length": 1.1875, "frame_shift": 0.9375, "high_freq": 5385, "htk_compat": false, "low_freq": 2579, "num_mel_bins": 6, "preemphasis_coefficient": 0.82, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "use_log_fbank": true, "use_power": true, "vtln_high": 4782, "vtln_low": 4492, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 0.3188, "energy_floor": 1.6288, "frame_length": 1.0, "frame_shift": 0.5, "high_freq": 6258, "htk_compat": true, "low_freq": 2043, "num_mel_bins": 4, "preemphasis_coefficient": 0.57, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": false, "use_energy": false, "use_log_fbank": true, "use_power": true, "vtln_high": 5274, "vtln_low": 3268, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 0.3637, "energy_floor": 4.7928, "frame_length": 1.0, "frame_shift": 0.5625, "high_freq": 7671, "htk_compat": false, "low_freq": 2385, "num_mel_bins": 5, "preemphasis_coefficient": 0.81, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "use_log_fbank": false, "use_power": true, "vtln_high": 6881, "vtln_low": 4659, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 0.4702, "energy_floor": 2.668, "frame_length": 1.0, "frame_shift": 1.0, "high_freq": 7231, "htk_compat": true, "low_freq": 1515, "num_mel_bins": 4, "preemphasis_coefficient": 0.92, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": true, "use_energy": false, "use_log_fbank": false, "use_power": true, "vtln_high": 6506, "vtln_low": 2549, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 0.5988, "energy_floor": 0.8014, "frame_length": 1.125, "frame_shift": 0.875, "high_freq": 3663, "htk_compat": true, "low_freq": 1941, "num_mel_bins": 6, "preemphasis_coefficient": 0.59, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "use_log_fbank": false, "use_power": false, "vtln_high": 3373, "vtln_low": 3354, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 0.6421, "energy_floor": 2.1404, "frame_length": 0.75, "frame_shift": 1.0625, "high_freq": 6031, "htk_compat": false, "low_freq": 57, "num_mel_bins": 4, "preemphasis_coefficient": 0.03, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, 
"use_log_fbank": false, "use_power": true, "vtln_high": 5417, "vtln_low": 1170, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 0.674, "energy_floor": 1.3778, "frame_length": 1.0, "frame_shift": 0.875, "high_freq": 6623, "htk_compat": true, "low_freq": 2402, "num_mel_bins": 7, "preemphasis_coefficient": 0.5, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "use_log_fbank": false, "use_power": true, "vtln_high": 6491, "vtln_low": 6262, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 0.7979, "energy_floor": 1.4223, "frame_length": 1.125, "frame_shift": 0.3125, "high_freq": 2534, "htk_compat": false, "low_freq": 810, "num_mel_bins": 4, "preemphasis_coefficient": 0.77, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "use_log_fbank": false, "use_power": true, "vtln_high": 2494, "vtln_low": 2015, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 0.8161, "energy_floor": 1.2937, "frame_length": 0.9375, "frame_shift": 0.125, "high_freq": 5030, "htk_compat": false, "low_freq": 966, "num_mel_bins": 6, "preemphasis_coefficient": 0.03, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "use_log_fbank": true, "use_power": true, "vtln_high": 4652, "vtln_low": 2559, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 0.8873, "energy_floor": 1.2866, "frame_length": 1.125, "frame_shift": 0.4375, "high_freq": 5558, "htk_compat": true, "low_freq": 1464, "num_mel_bins": 8, "preemphasis_coefficient": 0.77, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "use_log_fbank": false, "use_power": false, "vtln_high": 4613, "vtln_low": 4001, "vtln_warp": 1.4073, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 0.8997, "energy_floor": 2.8795, "frame_length": 0.875, "frame_shift": 0.5, "high_freq": 3383, "htk_compat": false, "low_freq": 259, "num_mel_bins": 4, "preemphasis_coefficient": 0.08, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "use_log_fbank": true, "use_power": false, "vtln_high": 1175, "vtln_low": 1038, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 0.9113, "energy_floor": 0.9909, "frame_length": 0.6875, "frame_shift": 0.375, "high_freq": 7562, "htk_compat": true, "low_freq": 3978, "num_mel_bins": 4, "preemphasis_coefficient": 0.03, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": false, "use_log_fbank": true, "use_power": true, "vtln_high": 6483, "vtln_low": 5671, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 0.9312, "energy_floor": 3.3768, "frame_length": 0.8125, "frame_shift": 1.125, "high_freq": 5824, "htk_compat": false, "low_freq": 1366, "num_mel_bins": 6, "preemphasis_coefficient": 0.28, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": false, "use_log_fbank": true, "use_power": true, "vtln_high": 3917, "vtln_low": 1620, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} 
+{"blackman_coeff": 0.9472, "energy_floor": 2.4134, "frame_length": 0.75, "frame_shift": 1.0, "high_freq": 7959, "htk_compat": false, "low_freq": 1770, "num_mel_bins": 6, "preemphasis_coefficient": 0.12, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": false, "use_log_fbank": false, "use_power": true, "vtln_high": 6874, "vtln_low": 5861, "vtln_warp": 1.1718, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 0.9483, "energy_floor": 4.0177, "frame_length": 1.125, "frame_shift": 1.125, "high_freq": 7854, "htk_compat": false, "low_freq": 4793, "num_mel_bins": 5, "preemphasis_coefficient": 0.47, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": false, "use_energy": true, "use_log_fbank": false, "use_power": false, "vtln_high": 5868, "vtln_low": 5848, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 0.9631, "energy_floor": 3.3222, "frame_length": 1.0, "frame_shift": 0.5, "high_freq": 7662, "htk_compat": true, "low_freq": 1833, "num_mel_bins": 6, "preemphasis_coefficient": 0.71, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": true, "use_energy": false, "use_log_fbank": true, "use_power": true, "vtln_high": 6204, "vtln_low": 5887, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 0.982, "energy_floor": 1.668, "frame_length": 0.75, "frame_shift": 0.1875, "high_freq": 6788, "htk_compat": false, "low_freq": 1968, "num_mel_bins": 4, "preemphasis_coefficient": 0.14, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "use_log_fbank": true, "use_power": false, "vtln_high": 2114, "vtln_low": 2024, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 1.0183, "energy_floor": 2.1572, "frame_length": 1.0625, "frame_shift": 0.375, "high_freq": 2018, "htk_compat": true, "low_freq": 317, "num_mel_bins": 6, "preemphasis_coefficient": 0.14, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "use_log_fbank": false, "use_power": false, "vtln_high": 1118, "vtln_low": 947, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 1.0269, "energy_floor": 0.3681, "frame_length": 1.125, "frame_shift": 0.5625, "high_freq": 4897, "htk_compat": true, "low_freq": 543, "num_mel_bins": 4, "preemphasis_coefficient": 0.57, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": false, "use_log_fbank": false, "use_power": false, "vtln_high": 1226, "vtln_low": 960, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 1.0298, "energy_floor": 2.249, "frame_length": 1.125, "frame_shift": 0.25, "high_freq": 2031, "htk_compat": true, "low_freq": 257, "num_mel_bins": 5, "preemphasis_coefficient": 0.65, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "use_log_fbank": false, "use_power": true, "vtln_high": 1731, "vtln_low": 1582, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 1.2222, "energy_floor": 0.0582, "frame_length": 1.1875, "frame_shift": 0.8125, "high_freq": 6633, "htk_compat": 
false, "low_freq": 1117, "num_mel_bins": 8, "preemphasis_coefficient": 0.96, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": false, "use_log_fbank": true, "use_power": false, "vtln_high": 4191, "vtln_low": 3264, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 1.2251, "energy_floor": 0.4403, "frame_length": 0.5625, "frame_shift": 0.6875, "high_freq": 3192, "htk_compat": false, "low_freq": 599, "num_mel_bins": 4, "preemphasis_coefficient": 0.75, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": false, "use_log_fbank": false, "use_power": true, "vtln_high": 3183, "vtln_low": 2975, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 1.2278, "energy_floor": 1.4848, "frame_length": 0.6875, "frame_shift": 0.625, "high_freq": 5785, "htk_compat": true, "low_freq": 289, "num_mel_bins": 4, "preemphasis_coefficient": 0.58, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "use_log_fbank": false, "use_power": true, "vtln_high": 4062, "vtln_low": 3715, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 1.3199, "energy_floor": 1.1137, "frame_length": 1.125, "frame_shift": 0.6875, "high_freq": 6702, "htk_compat": true, "low_freq": 390, "num_mel_bins": 6, "preemphasis_coefficient": 0.54, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": true, "use_energy": false, "use_log_fbank": true, "use_power": false, "vtln_high": 4426, "vtln_low": 2811, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 1.3325, "energy_floor": 2.6552, "frame_length": 1.0, "frame_shift": 1.0625, "high_freq": 6444, "htk_compat": true, "low_freq": 759, "num_mel_bins": 4, "preemphasis_coefficient": 0.67, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "use_log_fbank": false, "use_power": false, "vtln_high": 6065, "vtln_low": 4599, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 1.3426, "energy_floor": 1.5712, "frame_length": 1.1875, "frame_shift": 1.0, "high_freq": 7444, "htk_compat": false, "low_freq": 1986, "num_mel_bins": 6, "preemphasis_coefficient": 0.46, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": false, "use_log_fbank": true, "use_power": true, "vtln_high": 4787, "vtln_low": 3163, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 1.4359, "energy_floor": 1.2709, "frame_length": 1.0, "frame_shift": 0.5625, "high_freq": 7657, "htk_compat": true, "low_freq": 1017, "num_mel_bins": 5, "preemphasis_coefficient": 0.93, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": false, "use_log_fbank": false, "use_power": true, "vtln_high": 4228, "vtln_low": 2903, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 1.4621, "energy_floor": 1.2891, "frame_length": 0.875, "frame_shift": 1.0625, "high_freq": 6324, "htk_compat": true, "low_freq": 408, "num_mel_bins": 7, "preemphasis_coefficient": 0.09, "raw_energy": false, "remove_dc_offset": true, 
"round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": false, "use_log_fbank": false, "use_power": false, "vtln_high": 6163, "vtln_low": 5973, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 1.4751, "energy_floor": 2.3567, "frame_length": 1.1875, "frame_shift": 1.0, "high_freq": 7115, "htk_compat": false, "low_freq": 4236, "num_mel_bins": 5, "preemphasis_coefficient": 0.65, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "use_log_fbank": true, "use_power": false, "vtln_high": 6523, "vtln_low": 5708, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 1.4883, "energy_floor": 4.1237, "frame_length": 0.75, "frame_shift": 0.25, "high_freq": 5670, "htk_compat": true, "low_freq": 766, "num_mel_bins": 6, "preemphasis_coefficient": 0.29, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "use_log_fbank": false, "use_power": false, "vtln_high": 5479, "vtln_low": 4173, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 1.493, "energy_floor": 1.4719, "frame_length": 1.125, "frame_shift": 0.25, "high_freq": 7805, "htk_compat": true, "low_freq": 5052, "num_mel_bins": 4, "preemphasis_coefficient": 0.9, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "use_energy": true, "use_log_fbank": false, "use_power": false, "vtln_high": 7300, "vtln_low": 5299, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 1.5718, "energy_floor": 3.5447, "frame_length": 0.625, "frame_shift": 0.1875, "high_freq": 6777, "htk_compat": true, "low_freq": 938, "num_mel_bins": 4, "preemphasis_coefficient": 0.69, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": false, "use_energy": true, "use_log_fbank": true, "use_power": false, "vtln_high": 4540, "vtln_low": 3168, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 1.5786, "energy_floor": 1.9016, "frame_length": 1.1875, "frame_shift": 0.75, "high_freq": 5812, "htk_compat": true, "low_freq": 3000, "num_mel_bins": 4, "preemphasis_coefficient": 0.14, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "use_log_fbank": true, "use_power": false, "vtln_high": 4930, "vtln_low": 4316, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 1.6134, "energy_floor": 0.6389, "frame_length": 1.0625, "frame_shift": 0.8125, "high_freq": 7384, "htk_compat": false, "low_freq": 184, "num_mel_bins": 7, "preemphasis_coefficient": 0.08, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": false, "use_log_fbank": false, "use_power": true, "vtln_high": 2759, "vtln_low": 306, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 1.6312, "energy_floor": 2.6556, "frame_length": 0.625, "frame_shift": 0.4375, "high_freq": 5589, "htk_compat": false, "low_freq": 1049, "num_mel_bins": 5, "preemphasis_coefficient": 0.8, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "use_energy": false, "use_log_fbank": false, 
"use_power": false, "vtln_high": 3816, "vtln_low": 1550, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 1.7515, "energy_floor": 0.5964, "frame_length": 1.0625, "frame_shift": 1.0, "high_freq": 4349, "htk_compat": true, "low_freq": 702, "num_mel_bins": 5, "preemphasis_coefficient": 0.36, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": true, "use_log_fbank": false, "use_power": false, "vtln_high": 4168, "vtln_low": 1531, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 1.8179, "energy_floor": 1.3295, "frame_length": 0.5625, "frame_shift": 0.6875, "high_freq": 4510, "htk_compat": false, "low_freq": 122, "num_mel_bins": 4, "preemphasis_coefficient": 0.56, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "use_log_fbank": false, "use_power": false, "vtln_high": 4365, "vtln_low": 3721, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 1.9387, "energy_floor": 4.7991, "frame_length": 1.0, "frame_shift": 0.375, "high_freq": 6123, "htk_compat": true, "low_freq": 740, "num_mel_bins": 6, "preemphasis_coefficient": 0.21, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "use_energy": true, "use_log_fbank": true, "use_power": true, "vtln_high": 3970, "vtln_low": 3355, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 2.0479, "energy_floor": 1.4296, "frame_length": 1.0625, "frame_shift": 0.6875, "high_freq": 7818, "htk_compat": true, "low_freq": 1628, "num_mel_bins": 8, "preemphasis_coefficient": 0.27, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "use_log_fbank": false, "use_power": true, "vtln_high": 7749, "vtln_low": 7478, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 2.0809, "energy_floor": 1.9752, "frame_length": 0.75, "frame_shift": 1.1875, "high_freq": 5933, "htk_compat": false, "low_freq": 666, "num_mel_bins": 5, "preemphasis_coefficient": 0.72, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": true, "use_energy": true, "use_log_fbank": true, "use_power": true, "vtln_high": 5348, "vtln_low": 4645, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 2.1098, "energy_floor": 2.1356, "frame_length": 1.0625, "frame_shift": 0.9375, "high_freq": 7825, "htk_compat": true, "low_freq": 408, "num_mel_bins": 4, "preemphasis_coefficient": 0.37, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "use_log_fbank": false, "use_power": false, "vtln_high": 5297, "vtln_low": 2747, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 2.1463, "energy_floor": 0.3422, "frame_length": 0.8125, "frame_shift": 0.5, "high_freq": 6892, "htk_compat": true, "low_freq": 65, "num_mel_bins": 4, "preemphasis_coefficient": 0.47, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "use_log_fbank": false, "use_power": true, "vtln_high": 4178, "vtln_low": 2891, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 2.1768, 
"energy_floor": 3.782, "frame_length": 0.75, "frame_shift": 0.8125, "high_freq": 7063, "htk_compat": false, "low_freq": 2703, "num_mel_bins": 4, "preemphasis_coefficient": 0.99, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": true, "use_energy": false, "use_log_fbank": false, "use_power": true, "vtln_high": 6819, "vtln_low": 3764, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 2.1902, "energy_floor": 4.9973, "frame_length": 1.125, "frame_shift": 0.5, "high_freq": 7066, "htk_compat": false, "low_freq": 1699, "num_mel_bins": 4, "preemphasis_coefficient": 0.95, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": false, "use_log_fbank": false, "use_power": false, "vtln_high": 5452, "vtln_low": 5271, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 2.239, "energy_floor": 2.9557, "frame_length": 1.0625, "frame_shift": 0.875, "high_freq": 7615, "htk_compat": true, "low_freq": 4707, "num_mel_bins": 7, "preemphasis_coefficient": 1.0, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "use_log_fbank": true, "use_power": true, "vtln_high": 6790, "vtln_low": 6501, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 2.3199, "energy_floor": 0.8311, "frame_length": 0.9375, "frame_shift": 0.3125, "high_freq": 6738, "htk_compat": true, "low_freq": 1787, "num_mel_bins": 5, "preemphasis_coefficient": 0.83, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "use_log_fbank": false, "use_power": true, "vtln_high": 6635, "vtln_low": 6360, "vtln_warp": 0.7856, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 2.4071, "energy_floor": 2.7889, "frame_length": 0.9375, "frame_shift": 0.8125, "high_freq": 6598, "htk_compat": true, "low_freq": 2373, "num_mel_bins": 5, "preemphasis_coefficient": 0.2, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": false, "use_log_fbank": true, "use_power": true, "vtln_high": 4565, "vtln_low": 3464, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 2.4586, "energy_floor": 3.3176, "frame_length": 0.625, "frame_shift": 0.75, "high_freq": 7380, "htk_compat": false, "low_freq": 4248, "num_mel_bins": 4, "preemphasis_coefficient": 0.59, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": false, "use_log_fbank": false, "use_power": true, "vtln_high": 7263, "vtln_low": 6361, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 2.6039, "energy_floor": 1.1619, "frame_length": 0.75, "frame_shift": 0.5625, "high_freq": 6578, "htk_compat": true, "low_freq": 551, "num_mel_bins": 7, "preemphasis_coefficient": 0.16, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": false, "use_log_fbank": true, "use_power": true, "vtln_high": 4974, "vtln_low": 3139, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 2.6068, "energy_floor": 3.6411, "frame_length": 1.125, "frame_shift": 1.125, "high_freq": 3078, "htk_compat": false, "low_freq": 1003, 
"num_mel_bins": 5, "preemphasis_coefficient": 0.12, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "use_log_fbank": true, "use_power": false, "vtln_high": 2920, "vtln_low": 1121, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 2.6192, "energy_floor": 1.7209, "frame_length": 1.0625, "frame_shift": 0.625, "high_freq": 2275, "htk_compat": false, "low_freq": 367, "num_mel_bins": 5, "preemphasis_coefficient": 0.27, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "use_log_fbank": false, "use_power": true, "vtln_high": 1293, "vtln_low": 771, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 2.6578, "energy_floor": 0.9137, "frame_length": 1.1875, "frame_shift": 1.0, "high_freq": 4898, "htk_compat": true, "low_freq": 886, "num_mel_bins": 7, "preemphasis_coefficient": 0.57, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": true, "use_log_fbank": false, "use_power": false, "vtln_high": 3704, "vtln_low": 1013, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 2.7083, "energy_floor": 2.8806, "frame_length": 0.75, "frame_shift": 0.9375, "high_freq": 6605, "htk_compat": false, "low_freq": 3759, "num_mel_bins": 4, "preemphasis_coefficient": 0.9, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": false, "use_log_fbank": false, "use_power": true, "vtln_high": 6542, "vtln_low": 5821, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 2.7704, "energy_floor": 4.5251, "frame_length": 1.125, "frame_shift": 0.875, "high_freq": 3819, "htk_compat": true, "low_freq": 787, "num_mel_bins": 5, "preemphasis_coefficient": 0.23, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "use_log_fbank": false, "use_power": false, "vtln_high": 3368, "vtln_low": 3286, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 2.9255, "energy_floor": 3.4363, "frame_length": 1.125, "frame_shift": 1.0, "high_freq": 7660, "htk_compat": false, "low_freq": 5020, "num_mel_bins": 5, "preemphasis_coefficient": 0.09, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "use_log_fbank": false, "use_power": false, "vtln_high": 7470, "vtln_low": 6783, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 3.0009, "energy_floor": 1.845, "frame_length": 1.0625, "frame_shift": 0.75, "high_freq": 5812, "htk_compat": true, "low_freq": 1287, "num_mel_bins": 6, "preemphasis_coefficient": 0.22, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "use_log_fbank": false, "use_power": false, "vtln_high": 5573, "vtln_low": 4642, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 3.2195, "energy_floor": 2.9858, "frame_length": 1.0625, "frame_shift": 0.0625, "high_freq": 6899, "htk_compat": true, "low_freq": 4117, "num_mel_bins": 6, "preemphasis_coefficient": 0.85, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": 
true, "subtract_mean": false, "use_energy": false, "use_log_fbank": true, "use_power": true, "vtln_high": 5077, "vtln_low": 4977, "vtln_warp": 0.8739, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 3.3208, "energy_floor": 1.5569, "frame_length": 1.0, "frame_shift": 0.3125, "high_freq": 4556, "htk_compat": false, "low_freq": 334, "num_mel_bins": 5, "preemphasis_coefficient": 0.02, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": false, "use_log_fbank": true, "use_power": false, "vtln_high": 2831, "vtln_low": 696, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 3.3976, "energy_floor": 3.9462, "frame_length": 1.1875, "frame_shift": 0.5625, "high_freq": 6513, "htk_compat": false, "low_freq": 3398, "num_mel_bins": 8, "preemphasis_coefficient": 0.38, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": false, "use_log_fbank": true, "use_power": false, "vtln_high": 5827, "vtln_low": 5388, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 3.5842, "energy_floor": 1.2264, "frame_length": 0.9375, "frame_shift": 1.0, "high_freq": 7744, "htk_compat": false, "low_freq": 195, "num_mel_bins": 5, "preemphasis_coefficient": 0.62, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "use_log_fbank": true, "use_power": false, "vtln_high": 7667, "vtln_low": 2993, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 3.5889, "energy_floor": 3.3559, "frame_length": 1.0, "frame_shift": 1.1875, "high_freq": 7354, "htk_compat": true, "low_freq": 997, "num_mel_bins": 5, "preemphasis_coefficient": 0.98, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": false, "use_log_fbank": false, "use_power": false, "vtln_high": 7088, "vtln_low": 6494, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 3.5936, "energy_floor": 2.1701, "frame_length": 1.0625, "frame_shift": 1.0625, "high_freq": 7407, "htk_compat": true, "low_freq": 3649, "num_mel_bins": 5, "preemphasis_coefficient": 0.65, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "use_log_fbank": true, "use_power": true, "vtln_high": 6878, "vtln_low": 6036, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 3.7002, "energy_floor": 3.567, "frame_length": 1.1875, "frame_shift": 0.625, "high_freq": 4479, "htk_compat": true, "low_freq": 2240, "num_mel_bins": 6, "preemphasis_coefficient": 0.73, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "use_log_fbank": false, "use_power": true, "vtln_high": 4084, "vtln_low": 3955, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 3.7078, "energy_floor": 0.3892, "frame_length": 0.8125, "frame_shift": 0.3125, "high_freq": 7876, "htk_compat": true, "low_freq": 2830, "num_mel_bins": 7, "preemphasis_coefficient": 0.46, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": false, "use_log_fbank": false, "use_power": true, "vtln_high": 4726, "vtln_low": 2918, 
"vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 3.7585, "energy_floor": 2.9425, "frame_length": 1.1875, "frame_shift": 1.0, "high_freq": 3277, "htk_compat": true, "low_freq": 2244, "num_mel_bins": 4, "preemphasis_coefficient": 0.76, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "use_log_fbank": false, "use_power": false, "vtln_high": 3158, "vtln_low": 2865, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 3.7772, "energy_floor": 2.8211, "frame_length": 1.0, "frame_shift": 0.1875, "high_freq": 3747, "htk_compat": false, "low_freq": 1244, "num_mel_bins": 4, "preemphasis_coefficient": 0.64, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "use_log_fbank": false, "use_power": true, "vtln_high": 3640, "vtln_low": 2770, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 3.8514, "energy_floor": 3.7933, "frame_length": 1.0625, "frame_shift": 0.5, "high_freq": 4136, "htk_compat": true, "low_freq": 1010, "num_mel_bins": 6, "preemphasis_coefficient": 0.12, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "use_log_fbank": true, "use_power": false, "vtln_high": 2408, "vtln_low": 1892, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 3.8677, "energy_floor": 2.5418, "frame_length": 1.0625, "frame_shift": 0.0625, "high_freq": 3496, "htk_compat": true, "low_freq": 309, "num_mel_bins": 4, "preemphasis_coefficient": 0.47, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": false, "use_log_fbank": true, "use_power": false, "vtln_high": 1490, "vtln_low": 645, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 3.9033, "energy_floor": 2.677, "frame_length": 1.125, "frame_shift": 0.875, "high_freq": 5699, "htk_compat": false, "low_freq": 2960, "num_mel_bins": 7, "preemphasis_coefficient": 0.52, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "use_log_fbank": false, "use_power": false, "vtln_high": 5458, "vtln_low": 5400, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 4.0371, "energy_floor": 3.7559, "frame_length": 1.0625, "frame_shift": 0.8125, "high_freq": 4280, "htk_compat": false, "low_freq": 1207, "num_mel_bins": 4, "preemphasis_coefficient": 0.12, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "use_log_fbank": false, "use_power": true, "vtln_high": 3686, "vtln_low": 2010, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 4.0757, "energy_floor": 4.7442, "frame_length": 0.875, "frame_shift": 1.125, "high_freq": 6363, "htk_compat": true, "low_freq": 1524, "num_mel_bins": 4, "preemphasis_coefficient": 0.32, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "use_energy": true, "use_log_fbank": true, "use_power": false, "vtln_high": 5178, "vtln_low": 4628, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 4.1248, "energy_floor": 2.5255, 
"frame_length": 0.6875, "frame_shift": 0.6875, "high_freq": 3527, "htk_compat": true, "low_freq": 1701, "num_mel_bins": 4, "preemphasis_coefficient": 0.43, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "use_log_fbank": false, "use_power": true, "vtln_high": 2884, "vtln_low": 1773, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 4.18, "energy_floor": 4.6907, "frame_length": 1.1875, "frame_shift": 0.5625, "high_freq": 7316, "htk_compat": true, "low_freq": 3483, "num_mel_bins": 8, "preemphasis_coefficient": 0.61, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": true, "use_log_fbank": true, "use_power": true, "vtln_high": 5820, "vtln_low": 4635, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 4.2251, "energy_floor": 0.5, "frame_length": 0.875, "frame_shift": 0.625, "high_freq": 7515, "htk_compat": false, "low_freq": 1751, "num_mel_bins": 5, "preemphasis_coefficient": 0.64, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": false, "use_log_fbank": true, "use_power": false, "vtln_high": 7486, "vtln_low": 4238, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 4.3011, "energy_floor": 1.4663, "frame_length": 1.125, "frame_shift": 0.9375, "high_freq": 7804, "htk_compat": false, "low_freq": 1208, "num_mel_bins": 6, "preemphasis_coefficient": 0.18, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "use_log_fbank": true, "use_power": false, "vtln_high": 7421, "vtln_low": 3707, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 4.3252, "energy_floor": 0.7732, "frame_length": 0.625, "frame_shift": 0.6875, "high_freq": 7389, "htk_compat": false, "low_freq": 2071, "num_mel_bins": 4, "preemphasis_coefficient": 0.08, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "use_log_fbank": false, "use_power": false, "vtln_high": 6900, "vtln_low": 2344, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 4.3693, "energy_floor": 3.9073, "frame_length": 0.875, "frame_shift": 0.9375, "high_freq": 6107, "htk_compat": true, "low_freq": 3905, "num_mel_bins": 4, "preemphasis_coefficient": 0.86, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": true, "use_log_fbank": false, "use_power": true, "vtln_high": 5001, "vtln_low": 4046, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 4.3926, "energy_floor": 2.0617, "frame_length": 0.5625, "frame_shift": 0.0625, "high_freq": 4253, "htk_compat": true, "low_freq": 1367, "num_mel_bins": 5, "preemphasis_coefficient": 0.84, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "use_log_fbank": true, "use_power": false, "vtln_high": 2112, "vtln_low": 1445, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 4.4706, "energy_floor": 1.7516, "frame_length": 1.125, "frame_shift": 1.125, "high_freq": 7645, "htk_compat": false, "low_freq": 225, "num_mel_bins": 6, 
"preemphasis_coefficient": 0.8, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "use_energy": false, "use_log_fbank": false, "use_power": false, "vtln_high": 3717, "vtln_low": 304, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 4.5385, "energy_floor": 2.1519, "frame_length": 1.125, "frame_shift": 0.0625, "high_freq": 5610, "htk_compat": false, "low_freq": 1239, "num_mel_bins": 7, "preemphasis_coefficient": 0.87, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": true, "use_energy": true, "use_log_fbank": true, "use_power": false, "vtln_high": 2231, "vtln_low": 1432, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 4.6337, "energy_floor": 2.902, "frame_length": 0.875, "frame_shift": 1.125, "high_freq": 5072, "htk_compat": true, "low_freq": 826, "num_mel_bins": 4, "preemphasis_coefficient": 0.37, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": false, "use_log_fbank": false, "use_power": false, "vtln_high": 4253, "vtln_low": 2427, "vtln_warp": 0.7049, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 4.7468, "energy_floor": 2.1835, "frame_length": 0.6875, "frame_shift": 1.0, "high_freq": 5153, "htk_compat": true, "low_freq": 943, "num_mel_bins": 5, "preemphasis_coefficient": 0.94, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": false, "use_log_fbank": false, "use_power": false, "vtln_high": 3287, "vtln_low": 1478, "vtln_warp": 0.9406, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 4.7855, "energy_floor": 2.1377, "frame_length": 0.9375, "frame_shift": 0.8125, "high_freq": 4123, "htk_compat": false, "low_freq": 587, "num_mel_bins": 4, "preemphasis_coefficient": 0.92, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "use_log_fbank": true, "use_power": false, "vtln_high": 2588, "vtln_low": 2346, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 4.964, "energy_floor": 3.4931, "frame_length": 1.1875, "frame_shift": 1.0, "high_freq": 4235, "htk_compat": true, "low_freq": 1036, "num_mel_bins": 4, "preemphasis_coefficient": 0.43, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": true, "use_log_fbank": false, "use_power": false, "vtln_high": 3706, "vtln_low": 2840, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} diff --git a/test/torchaudio_unittest/assets/kaldi_test_mfcc_args.jsonl b/test/torchaudio_unittest/assets/kaldi_test_mfcc_args.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..7808a2442e48a7c11c2680ccdb37e2a64cf25509 --- /dev/null +++ b/test/torchaudio_unittest/assets/kaldi_test_mfcc_args.jsonl @@ -0,0 +1,114 @@ +{"blackman_coeff": 0.013, "energy_floor": 1.8509, "frame_length": 1.1875, "frame_shift": 0.625, "high_freq": 7999, "htk_compat": false, "low_freq": 4330, "num_mel_bins": 5, "preemphasis_coefficient": 0.38, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "num_ceps": 3, "cepstral_lifter": 8.1048, "vtln_high": 7497, "vtln_low": 7397, "vtln_warp": 1.0, 
"window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 0.0487, "energy_floor": 1.3641, "frame_length": 1.0, "frame_shift": 0.8125, "high_freq": 7892, "htk_compat": true, "low_freq": 1904, "num_mel_bins": 8, "preemphasis_coefficient": 0.26, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": false, "num_ceps": 5, "cepstral_lifter": 34.0918, "vtln_high": 4400, "vtln_low": 2737, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 0.0577, "energy_floor": 2.4313, "frame_length": 1.0625, "frame_shift": 0.875, "high_freq": 2922, "htk_compat": true, "low_freq": 274, "num_mel_bins": 6, "preemphasis_coefficient": 0.48, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "num_ceps": 5, "cepstral_lifter": 21.3007, "vtln_high": 1352, "vtln_low": 280, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 0.0718, "energy_floor": 1.3071, "frame_length": 1.1875, "frame_shift": 0.5, "high_freq": 3159, "htk_compat": true, "low_freq": 759, "num_mel_bins": 8, "preemphasis_coefficient": 0.04, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "num_ceps": 6, "cepstral_lifter": 2.6493, "vtln_high": 3145, "vtln_low": 3119, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 0.083, "energy_floor": 2.1607, "frame_length": 0.75, "frame_shift": 0.75, "high_freq": 5872, "htk_compat": true, "low_freq": 708, "num_mel_bins": 5, "preemphasis_coefficient": 0.95, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": false, "use_energy": true, "num_ceps": 4, "cepstral_lifter": 24.5097, "vtln_high": 5231, "vtln_low": 3888, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 0.0933, "energy_floor": 1.577, "frame_length": 1.1875, "frame_shift": 1.1875, "high_freq": 7519, "htk_compat": false, "low_freq": 357, "num_mel_bins": 4, "preemphasis_coefficient": 0.5, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": false, "num_ceps": 2, "cepstral_lifter": 88.0941, "vtln_high": 7042, "vtln_low": 5298, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 0.1048, "energy_floor": 0.5013, "frame_length": 0.75, "frame_shift": 1.0625, "high_freq": 6426, "htk_compat": true, "low_freq": 3613, "num_mel_bins": 5, "preemphasis_coefficient": 0.96, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "num_ceps": 4, "cepstral_lifter": 73.5838, "vtln_high": 5816, "vtln_low": 3997, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 0.1447, "energy_floor": 4.7142, "frame_length": 1.0625, "frame_shift": 0.4375, "high_freq": 7629, "htk_compat": true, "low_freq": 3498, "num_mel_bins": 7, "preemphasis_coefficient": 0.39, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": false, "num_ceps": 3, "cepstral_lifter": 19.4145, "vtln_high": 7169, "vtln_low": 6751, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 0.1473, "energy_floor": 4.9154, "frame_length": 1.125, "frame_shift": 0.875, 
"high_freq": 3631, "htk_compat": false, "low_freq": 1229, "num_mel_bins": 8, "preemphasis_coefficient": 0.04, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": true, "num_ceps": 3, "cepstral_lifter": 34.479, "vtln_high": 3390, "vtln_low": 1536, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 0.1485, "energy_floor": 4.275, "frame_length": 0.6875, "frame_shift": 0.75, "high_freq": 5222, "htk_compat": true, "low_freq": 311, "num_mel_bins": 5, "preemphasis_coefficient": 0.31, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "num_ceps": 3, "cepstral_lifter": 48.7991, "vtln_high": 4833, "vtln_low": 513, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 0.1829, "energy_floor": 4.5358, "frame_length": 0.9375, "frame_shift": 0.9375, "high_freq": 6148, "htk_compat": true, "low_freq": 455, "num_mel_bins": 5, "preemphasis_coefficient": 0.59, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": false, "num_ceps": 3, "cepstral_lifter": 48.6724, "vtln_high": 3138, "vtln_low": 2247, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 0.2002, "energy_floor": 1.4805, "frame_length": 0.9375, "frame_shift": 0.875, "high_freq": 7621, "htk_compat": false, "low_freq": 2232, "num_mel_bins": 5, "preemphasis_coefficient": 0.03, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "num_ceps": 5, "cepstral_lifter": 69.3653, "vtln_high": 7087, "vtln_low": 2800, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 0.2107, "energy_floor": 3.475, "frame_length": 1.125, "frame_shift": 0.8125, "high_freq": 5701, "htk_compat": true, "low_freq": 1629, "num_mel_bins": 4, "preemphasis_coefficient": 0.09, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": true, "num_ceps": 2, "cepstral_lifter": 77.7066, "vtln_high": 5622, "vtln_low": 5544, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 0.2523, "energy_floor": 0.148, "frame_length": 1.0, "frame_shift": 1.125, "high_freq": 5833, "htk_compat": false, "low_freq": 556, "num_mel_bins": 4, "preemphasis_coefficient": 0.66, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": false, "num_ceps": 4, "cepstral_lifter": 57.0398, "vtln_high": 4519, "vtln_low": 3600, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 0.2959, "energy_floor": 2.3729, "frame_length": 0.625, "frame_shift": 0.5, "high_freq": 6757, "htk_compat": false, "low_freq": 1744, "num_mel_bins": 6, "preemphasis_coefficient": 0.2, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": false, "num_ceps": 2, "cepstral_lifter": 99.871, "vtln_high": 4957, "vtln_low": 3549, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 0.3642, "energy_floor": 3.5246, "frame_length": 0.4375, "frame_shift": 0.9375, "high_freq": 7942, "htk_compat": false, "low_freq": 3282, "num_mel_bins": 4, "preemphasis_coefficient": 0.52, "raw_energy": false, 
"remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": false, "num_ceps": 3, "cepstral_lifter": 74.4735, "vtln_high": 5601, "vtln_low": 4966, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 0.4891, "energy_floor": 1.989, "frame_length": 1.125, "frame_shift": 0.875, "high_freq": 3219, "htk_compat": true, "low_freq": 973, "num_mel_bins": 5, "preemphasis_coefficient": 0.91, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "num_ceps": 4, "cepstral_lifter": 85.357, "vtln_high": 3181, "vtln_low": 3129, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 0.5428, "energy_floor": 1.2368, "frame_length": 0.5625, "frame_shift": 1.0, "high_freq": 6700, "htk_compat": false, "low_freq": 749, "num_mel_bins": 4, "preemphasis_coefficient": 0.86, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "num_ceps": 2, "cepstral_lifter": 16.2782, "vtln_high": 5573, "vtln_low": 4988, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 0.5495, "energy_floor": 2.9502, "frame_length": 1.1875, "frame_shift": 1.1875, "high_freq": 3873, "htk_compat": true, "low_freq": 1564, "num_mel_bins": 5, "preemphasis_coefficient": 0.05, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "num_ceps": 5, "cepstral_lifter": 59.0075, "vtln_high": 3870, "vtln_low": 3750, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 0.6457, "energy_floor": 2.0199, "frame_length": 0.875, "frame_shift": 0.8125, "high_freq": 6510, "htk_compat": false, "low_freq": 1482, "num_mel_bins": 4, "preemphasis_coefficient": 0.26, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "use_energy": true, "num_ceps": 2, "cepstral_lifter": 49.7663, "vtln_high": 5461, "vtln_low": 4039, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 0.7031, "energy_floor": 4.038, "frame_length": 1.125, "frame_shift": 0.6875, "high_freq": 6433, "htk_compat": true, "low_freq": 2336, "num_mel_bins": 8, "preemphasis_coefficient": 0.7, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": true, "use_energy": false, "num_ceps": 3, "cepstral_lifter": 18.0061, "vtln_high": 5902, "vtln_low": 3191, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 0.7197, "energy_floor": 3.2075, "frame_length": 1.0625, "frame_shift": 0.25, "high_freq": 4448, "htk_compat": true, "low_freq": 378, "num_mel_bins": 4, "preemphasis_coefficient": 0.31, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": false, "num_ceps": 2, "cepstral_lifter": 71.591, "vtln_high": 3497, "vtln_low": 3331, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 0.7238, "energy_floor": 1.9087, "frame_length": 1.1875, "frame_shift": 0.75, "high_freq": 5457, "htk_compat": true, "low_freq": 1775, "num_mel_bins": 7, "preemphasis_coefficient": 0.48, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "num_ceps": 
3, "cepstral_lifter": 29.2858, "vtln_high": 5349, "vtln_low": 3987, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 0.7507, "energy_floor": 0.2754, "frame_length": 0.875, "frame_shift": 0.8125, "high_freq": 6405, "htk_compat": false, "low_freq": 1972, "num_mel_bins": 5, "preemphasis_coefficient": 0.83, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": true, "use_energy": true, "num_ceps": 2, "cepstral_lifter": 52.5962, "vtln_high": 4597, "vtln_low": 4417, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 0.7954, "energy_floor": 0.3451, "frame_length": 1.1875, "frame_shift": 0.625, "high_freq": 4078, "htk_compat": false, "low_freq": 796, "num_mel_bins": 8, "preemphasis_coefficient": 0.47, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "num_ceps": 8, "cepstral_lifter": 42.3128, "vtln_high": 2299, "vtln_low": 1094, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 0.8432, "energy_floor": 1.3765, "frame_length": 0.9375, "frame_shift": 1.1875, "high_freq": 6004, "htk_compat": true, "low_freq": 2302, "num_mel_bins": 4, "preemphasis_coefficient": 0.6, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": true, "num_ceps": 4, "cepstral_lifter": 74.1116, "vtln_high": 4129, "vtln_low": 2898, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 0.8549, "energy_floor": 4.8924, "frame_length": 0.3125, "frame_shift": 1.125, "high_freq": 5643, "htk_compat": true, "low_freq": 956, "num_mel_bins": 4, "preemphasis_coefficient": 0.32, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "num_ceps": 2, "cepstral_lifter": 76.0384, "vtln_high": 2672, "vtln_low": 1762, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 0.9502, "energy_floor": 1.9738, "frame_length": 0.75, "frame_shift": 0.25, "high_freq": 7773, "htk_compat": true, "low_freq": 1205, "num_mel_bins": 7, "preemphasis_coefficient": 0.5, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "num_ceps": 3, "cepstral_lifter": 62.9038, "vtln_high": 7460, "vtln_low": 7174, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 1.0759, "energy_floor": 1.9132, "frame_length": 1.1875, "frame_shift": 0.625, "high_freq": 3529, "htk_compat": false, "low_freq": 227, "num_mel_bins": 8, "preemphasis_coefficient": 0.26, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "num_ceps": 4, "cepstral_lifter": 4.1559, "vtln_high": 1976, "vtln_low": 972, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 1.147, "energy_floor": 4.3972, "frame_length": 0.9375, "frame_shift": 0.75, "high_freq": 6393, "htk_compat": true, "low_freq": 2451, "num_mel_bins": 4, "preemphasis_coefficient": 0.06, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "num_ceps": 4, "cepstral_lifter": 67.4571, "vtln_high": 5460, "vtln_low": 3654, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} 
+{"blackman_coeff": 1.2069, "energy_floor": 0.4607, "frame_length": 1.125, "frame_shift": 0.5625, "high_freq": 5864, "htk_compat": true, "low_freq": 1512, "num_mel_bins": 7, "preemphasis_coefficient": 0.17, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "num_ceps": 3, "cepstral_lifter": 33.0194, "vtln_high": 5438, "vtln_low": 3920, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 1.3369, "energy_floor": 4.2619, "frame_length": 0.6875, "frame_shift": 0.375, "high_freq": 5307, "htk_compat": true, "low_freq": 666, "num_mel_bins": 6, "preemphasis_coefficient": 0.19, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "num_ceps": 4, "cepstral_lifter": 88.3186, "vtln_high": 4677, "vtln_low": 2590, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 1.3699, "energy_floor": 3.0236, "frame_length": 1.0625, "frame_shift": 0.75, "high_freq": 3720, "htk_compat": true, "low_freq": 1980, "num_mel_bins": 4, "preemphasis_coefficient": 0.13, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "num_ceps": 3, "cepstral_lifter": 57.6793, "vtln_high": 3441, "vtln_low": 3396, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 1.3763, "energy_floor": 1.6574, "frame_length": 1.125, "frame_shift": 0.8125, "high_freq": 2816, "htk_compat": false, "low_freq": 1021, "num_mel_bins": 4, "preemphasis_coefficient": 0.91, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "num_ceps": 2, "cepstral_lifter": 45.2819, "vtln_high": 2547, "vtln_low": 1123, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 1.4097, "energy_floor": 4.1523, "frame_length": 0.875, "frame_shift": 0.375, "high_freq": 6164, "htk_compat": false, "low_freq": 987, "num_mel_bins": 4, "preemphasis_coefficient": 0.06, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": false, "num_ceps": 3, "cepstral_lifter": 75.0589, "vtln_high": 5873, "vtln_low": 5807, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 1.4295, "energy_floor": 3.7938, "frame_length": 1.125, "frame_shift": 0.75, "high_freq": 3382, "htk_compat": false, "low_freq": 471, "num_mel_bins": 4, "preemphasis_coefficient": 0.42, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": true, "use_energy": true, "num_ceps": 2, "cepstral_lifter": 78.588, "vtln_high": 3299, "vtln_low": 2540, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 1.4677, "energy_floor": 4.0728, "frame_length": 1.125, "frame_shift": 1.0625, "high_freq": 7698, "htk_compat": true, "low_freq": 569, "num_mel_bins": 6, "preemphasis_coefficient": 0.5, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": true, "use_energy": false, "num_ceps": 3, "cepstral_lifter": 76.8484, "vtln_high": 7453, "vtln_low": 7251, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 1.4979, "energy_floor": 1.1705, "frame_length": 1.1875, "frame_shift": 0.375, "high_freq": 4474, "htk_compat": 
true, "low_freq": 1123, "num_mel_bins": 7, "preemphasis_coefficient": 0.09, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "num_ceps": 4, "cepstral_lifter": 38.6407, "vtln_high": 3043, "vtln_low": 2934, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 1.555, "energy_floor": 1.8728, "frame_length": 0.875, "frame_shift": 0.9375, "high_freq": 5191, "htk_compat": true, "low_freq": 2262, "num_mel_bins": 4, "preemphasis_coefficient": 0.24, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "num_ceps": 2, "cepstral_lifter": 11.982, "vtln_high": 4607, "vtln_low": 4483, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 1.5626, "energy_floor": 3.7117, "frame_length": 1.125, "frame_shift": 0.125, "high_freq": 3008, "htk_compat": true, "low_freq": 534, "num_mel_bins": 5, "preemphasis_coefficient": 0.65, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": true, "use_energy": true, "num_ceps": 2, "cepstral_lifter": 75.6661, "vtln_high": 2592, "vtln_low": 1621, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 1.5707, "energy_floor": 3.0409, "frame_length": 0.75, "frame_shift": 0.625, "high_freq": 7441, "htk_compat": true, "low_freq": 1554, "num_mel_bins": 6, "preemphasis_coefficient": 0.95, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "num_ceps": 4, "cepstral_lifter": 80.1563, "vtln_high": 7152, "vtln_low": 6151, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 1.5966, "energy_floor": 2.3442, "frame_length": 0.5625, "frame_shift": 0.5, "high_freq": 7944, "htk_compat": true, "low_freq": 1616, "num_mel_bins": 5, "preemphasis_coefficient": 0.49, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "num_ceps": 2, "cepstral_lifter": 80.8779, "vtln_high": 5720, "vtln_low": 4080, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 1.6229, "energy_floor": 2.0519, "frame_length": 1.1875, "frame_shift": 0.3125, "high_freq": 4871, "htk_compat": true, "low_freq": 1567, "num_mel_bins": 4, "preemphasis_coefficient": 0.79, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "num_ceps": 2, "cepstral_lifter": 42.9569, "vtln_high": 3483, "vtln_low": 3287, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 1.736, "energy_floor": 0.4063, "frame_length": 0.6875, "frame_shift": 0.0625, "high_freq": 6475, "htk_compat": true, "low_freq": 4439, "num_mel_bins": 4, "preemphasis_coefficient": 0.23, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": false, "num_ceps": 2, "cepstral_lifter": 30.0984, "vtln_high": 5450, "vtln_low": 4909, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 1.7411, "energy_floor": 2.0918, "frame_length": 1.0625, "frame_shift": 0.8125, "high_freq": 6107, "htk_compat": true, "low_freq": 2523, "num_mel_bins": 4, "preemphasis_coefficient": 0.69, "raw_energy": true, "remove_dc_offset": false, 
"round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": false, "num_ceps": 3, "cepstral_lifter": 92.6839, "vtln_high": 5085, "vtln_low": 4771, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 1.7439, "energy_floor": 2.3782, "frame_length": 0.875, "frame_shift": 1.1875, "high_freq": 7669, "htk_compat": false, "low_freq": 4499, "num_mel_bins": 4, "preemphasis_coefficient": 0.81, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "use_energy": false, "num_ceps": 2, "cepstral_lifter": 95.7035, "vtln_high": 7521, "vtln_low": 7417, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 1.7611, "energy_floor": 4.2965, "frame_length": 0.8125, "frame_shift": 0.6875, "high_freq": 6607, "htk_compat": false, "low_freq": 454, "num_mel_bins": 7, "preemphasis_coefficient": 0.35, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "num_ceps": 2, "cepstral_lifter": 17.8265, "vtln_high": 6387, "vtln_low": 6105, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 1.7893, "energy_floor": 1.8005, "frame_length": 0.625, "frame_shift": 0.375, "high_freq": 2791, "htk_compat": true, "low_freq": 617, "num_mel_bins": 4, "preemphasis_coefficient": 0.96, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "num_ceps": 4, "cepstral_lifter": 93.405, "vtln_high": 1751, "vtln_low": 1690, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 1.8392, "energy_floor": 4.3711, "frame_length": 0.9375, "frame_shift": 0.75, "high_freq": 6978, "htk_compat": true, "low_freq": 453, "num_mel_bins": 4, "preemphasis_coefficient": 0.48, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": false, "num_ceps": 3, "cepstral_lifter": 42.9768, "vtln_high": 6315, "vtln_low": 3995, "vtln_warp": 1.1059, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 1.9561, "energy_floor": 0.8419, "frame_length": 0.8125, "frame_shift": 1.125, "high_freq": 5308, "htk_compat": false, "low_freq": 1471, "num_mel_bins": 5, "preemphasis_coefficient": 0.62, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "num_ceps": 3, "cepstral_lifter": 51.0514, "vtln_high": 5221, "vtln_low": 5071, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 1.998, "energy_floor": 1.6949, "frame_length": 1.125, "frame_shift": 0.8125, "high_freq": 4678, "htk_compat": true, "low_freq": 2340, "num_mel_bins": 5, "preemphasis_coefficient": 0.44, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "use_energy": true, "num_ceps": 3, "cepstral_lifter": 70.1988, "vtln_high": 4041, "vtln_low": 2424, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 2.0106, "energy_floor": 4.5392, "frame_length": 0.6875, "frame_shift": 0.1875, "high_freq": 4776, "htk_compat": true, "low_freq": 1297, "num_mel_bins": 5, "preemphasis_coefficient": 0.14, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "num_ceps": 2, 
"cepstral_lifter": 37.5837, "vtln_high": 3995, "vtln_low": 2991, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 2.0835, "energy_floor": 2.8454, "frame_length": 0.9375, "frame_shift": 0.3125, "high_freq": 7496, "htk_compat": false, "low_freq": 1207, "num_mel_bins": 6, "preemphasis_coefficient": 0.76, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": false, "num_ceps": 6, "cepstral_lifter": 4.5734, "vtln_high": 3935, "vtln_low": 3932, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 2.2801, "energy_floor": 1.7051, "frame_length": 1.1875, "frame_shift": 0.125, "high_freq": 7958, "htk_compat": true, "low_freq": 561, "num_mel_bins": 4, "preemphasis_coefficient": 0.9, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "num_ceps": 2, "cepstral_lifter": 60.4649, "vtln_high": 7218, "vtln_low": 5709, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 2.315, "energy_floor": 0.4964, "frame_length": 0.5, "frame_shift": 0.4375, "high_freq": 6582, "htk_compat": false, "low_freq": 1010, "num_mel_bins": 4, "preemphasis_coefficient": 0.98, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": false, "num_ceps": 3, "cepstral_lifter": 71.516, "vtln_high": 6157, "vtln_low": 4430, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 2.3236, "energy_floor": 0.7825, "frame_length": 0.8125, "frame_shift": 0.25, "high_freq": 7488, "htk_compat": true, "low_freq": 1363, "num_mel_bins": 4, "preemphasis_coefficient": 0.3, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "num_ceps": 4, "cepstral_lifter": 68.9678, "vtln_high": 3555, "vtln_low": 1851, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 2.3805, "energy_floor": 2.934, "frame_length": 0.75, "frame_shift": 0.25, "high_freq": 6076, "htk_compat": true, "low_freq": 80, "num_mel_bins": 4, "preemphasis_coefficient": 0.85, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": true, "num_ceps": 3, "cepstral_lifter": 31.0805, "vtln_high": 2257, "vtln_low": 1533, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 2.4091, "energy_floor": 2.8812, "frame_length": 1.125, "frame_shift": 0.9375, "high_freq": 6086, "htk_compat": false, "low_freq": 1210, "num_mel_bins": 5, "preemphasis_coefficient": 0.59, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": false, "use_energy": true, "num_ceps": 3, "cepstral_lifter": 96.1612, "vtln_high": 4840, "vtln_low": 1905, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 2.4134, "energy_floor": 2.6379, "frame_length": 1.1875, "frame_shift": 0.375, "high_freq": 3318, "htk_compat": false, "low_freq": 770, "num_mel_bins": 5, "preemphasis_coefficient": 0.6, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": false, "num_ceps": 3, "cepstral_lifter": 73.9427, "vtln_high": 2044, "vtln_low": 1481, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} 
+{"blackman_coeff": 2.5228, "energy_floor": 3.1056, "frame_length": 1.125, "frame_shift": 1.1875, "high_freq": 5422, "htk_compat": false, "low_freq": 2825, "num_mel_bins": 7, "preemphasis_coefficient": 0.88, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": false, "num_ceps": 7, "cepstral_lifter": 4.6719, "vtln_high": 5337, "vtln_low": 5243, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 2.5577, "energy_floor": 0.7393, "frame_length": 0.8125, "frame_shift": 0.5, "high_freq": 5291, "htk_compat": true, "low_freq": 1445, "num_mel_bins": 5, "preemphasis_coefficient": 0.01, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "num_ceps": 2, "cepstral_lifter": 5.8944, "vtln_high": 4338, "vtln_low": 4330, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 2.5854, "energy_floor": 3.2219, "frame_length": 0.875, "frame_shift": 0.4375, "high_freq": 6924, "htk_compat": false, "low_freq": 4024, "num_mel_bins": 4, "preemphasis_coefficient": 1.0, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "num_ceps": 2, "cepstral_lifter": 0.0578, "vtln_high": 5707, "vtln_low": 5025, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 2.6674, "energy_floor": 2.777, "frame_length": 1.0625, "frame_shift": 0.3125, "high_freq": 3129, "htk_compat": true, "low_freq": 1706, "num_mel_bins": 4, "preemphasis_coefficient": 0.91, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": true, "num_ceps": 3, "cepstral_lifter": 50.9241, "vtln_high": 2593, "vtln_low": 2198, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 2.6816, "energy_floor": 4.0548, "frame_length": 1.1875, "frame_shift": 0.625, "high_freq": 3182, "htk_compat": false, "low_freq": 157, "num_mel_bins": 6, "preemphasis_coefficient": 0.04, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": false, "num_ceps": 5, "cepstral_lifter": 31.3652, "vtln_high": 1203, "vtln_low": 1174, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 2.7879, "energy_floor": 3.3482, "frame_length": 0.6875, "frame_shift": 0.375, "high_freq": 4262, "htk_compat": true, "low_freq": 150, "num_mel_bins": 4, "preemphasis_coefficient": 0.68, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "num_ceps": 2, "cepstral_lifter": 92.794, "vtln_high": 3276, "vtln_low": 1685, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 2.8683, "energy_floor": 3.8162, "frame_length": 1.125, "frame_shift": 0.375, "high_freq": 6620, "htk_compat": false, "low_freq": 3389, "num_mel_bins": 7, "preemphasis_coefficient": 0.83, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "num_ceps": 5, "cepstral_lifter": 82.2365, "vtln_high": 5365, "vtln_low": 4579, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 2.869, "energy_floor": 3.2618, "frame_length": 1.1875, "frame_shift": 0.9375, "high_freq": 5646, "htk_compat": 
true, "low_freq": 491, "num_mel_bins": 8, "preemphasis_coefficient": 0.89, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "num_ceps": 2, "cepstral_lifter": 59.9812, "vtln_high": 5397, "vtln_low": 2639, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 2.9211, "energy_floor": 4.144, "frame_length": 0.75, "frame_shift": 0.375, "high_freq": 7210, "htk_compat": true, "low_freq": 3666, "num_mel_bins": 4, "preemphasis_coefficient": 0.93, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": false, "num_ceps": 4, "cepstral_lifter": 94.5907, "vtln_high": 6682, "vtln_low": 4979, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 2.9464, "energy_floor": 0.6798, "frame_length": 1.125, "frame_shift": 0.0625, "high_freq": 4445, "htk_compat": true, "low_freq": 323, "num_mel_bins": 6, "preemphasis_coefficient": 0.46, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "num_ceps": 6, "cepstral_lifter": 7.8133, "vtln_high": 3755, "vtln_low": 1137, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 2.9633, "energy_floor": 1.9565, "frame_length": 0.875, "frame_shift": 0.0625, "high_freq": 6835, "htk_compat": false, "low_freq": 649, "num_mel_bins": 5, "preemphasis_coefficient": 0.77, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "use_energy": false, "num_ceps": 2, "cepstral_lifter": 80.8871, "vtln_high": 6691, "vtln_low": 6581, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 2.9697, "energy_floor": 2.0241, "frame_length": 1.125, "frame_shift": 0.6875, "high_freq": 2170, "htk_compat": false, "low_freq": 180, "num_mel_bins": 5, "preemphasis_coefficient": 0.28, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "num_ceps": 3, "cepstral_lifter": 95.8111, "vtln_high": 1266, "vtln_low": 521, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 3.0358, "energy_floor": 1.7295, "frame_length": 1.1875, "frame_shift": 1.0, "high_freq": 7222, "htk_compat": true, "low_freq": 858, "num_mel_bins": 4, "preemphasis_coefficient": 0.16, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "num_ceps": 4, "cepstral_lifter": 64.7537, "vtln_high": 6220, "vtln_low": 5229, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 3.0421, "energy_floor": 3.3343, "frame_length": 1.0, "frame_shift": 0.9375, "high_freq": 6477, "htk_compat": false, "low_freq": 1402, "num_mel_bins": 5, "preemphasis_coefficient": 0.99, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "num_ceps": 4, "cepstral_lifter": 26.1743, "vtln_high": 6381, "vtln_low": 5017, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 3.0919, "energy_floor": 4.5103, "frame_length": 0.625, "frame_shift": 1.0, "high_freq": 5323, "htk_compat": true, "low_freq": 937, "num_mel_bins": 5, "preemphasis_coefficient": 0.95, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": 
true, "snip_edges": true, "subtract_mean": true, "use_energy": false, "num_ceps": 5, "cepstral_lifter": 82.2405, "vtln_high": 5130, "vtln_low": 5086, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 3.1463, "energy_floor": 4.5068, "frame_length": 0.6875, "frame_shift": 0.3125, "high_freq": 7587, "htk_compat": true, "low_freq": 3542, "num_mel_bins": 7, "preemphasis_coefficient": 0.78, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "num_ceps": 4, "cepstral_lifter": 28.5808, "vtln_high": 7478, "vtln_low": 7326, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 3.2416, "energy_floor": 1.0604, "frame_length": 0.875, "frame_shift": 0.5, "high_freq": 4730, "htk_compat": false, "low_freq": 968, "num_mel_bins": 4, "preemphasis_coefficient": 0.11, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": true, "num_ceps": 3, "cepstral_lifter": 60.5751, "vtln_high": 3542, "vtln_low": 1943, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 3.2698, "energy_floor": 3.0361, "frame_length": 1.0625, "frame_shift": 0.75, "high_freq": 4870, "htk_compat": true, "low_freq": 1281, "num_mel_bins": 7, "preemphasis_coefficient": 0.64, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": true, "num_ceps": 3, "cepstral_lifter": 28.536, "vtln_high": 4401, "vtln_low": 3315, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 3.3078, "energy_floor": 4.9217, "frame_length": 1.0, "frame_shift": 0.3125, "high_freq": 6758, "htk_compat": true, "low_freq": 760, "num_mel_bins": 5, "preemphasis_coefficient": 0.98, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "num_ceps": 3, "cepstral_lifter": 97.4694, "vtln_high": 6022, "vtln_low": 5650, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 3.3206, "energy_floor": 0.023, "frame_length": 1.0625, "frame_shift": 0.5625, "high_freq": 5744, "htk_compat": true, "low_freq": 3901, "num_mel_bins": 5, "preemphasis_coefficient": 0.94, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": false, "num_ceps": 2, "cepstral_lifter": 47.6031, "vtln_high": 5741, "vtln_low": 5524, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 3.4022, "energy_floor": 1.2172, "frame_length": 0.875, "frame_shift": 0.375, "high_freq": 7737, "htk_compat": false, "low_freq": 612, "num_mel_bins": 5, "preemphasis_coefficient": 0.35, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": false, "num_ceps": 2, "cepstral_lifter": 65.1166, "vtln_high": 6852, "vtln_low": 5820, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 3.4339, "energy_floor": 2.6197, "frame_length": 1.125, "frame_shift": 0.1875, "high_freq": 3341, "htk_compat": true, "low_freq": 1275, "num_mel_bins": 7, "preemphasis_coefficient": 0.41, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": true, "num_ceps": 2, "cepstral_lifter": 76.6062, "vtln_high": 3005, 
"vtln_low": 1680, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 3.4369, "energy_floor": 3.9198, "frame_length": 0.9375, "frame_shift": 0.125, "high_freq": 6218, "htk_compat": true, "low_freq": 904, "num_mel_bins": 5, "preemphasis_coefficient": 0.47, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": false, "num_ceps": 5, "cepstral_lifter": 92.8036, "vtln_high": 4870, "vtln_low": 1901, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 3.4557, "energy_floor": 1.5553, "frame_length": 0.75, "frame_shift": 0.375, "high_freq": 6642, "htk_compat": true, "low_freq": 1530, "num_mel_bins": 4, "preemphasis_coefficient": 0.72, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": false, "use_energy": true, "num_ceps": 2, "cepstral_lifter": 28.8828, "vtln_high": 4490, "vtln_low": 2980, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 3.4753, "energy_floor": 4.7166, "frame_length": 0.75, "frame_shift": 0.3125, "high_freq": 7637, "htk_compat": true, "low_freq": 4992, "num_mel_bins": 4, "preemphasis_coefficient": 0.92, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": true, "num_ceps": 3, "cepstral_lifter": 2.5456, "vtln_high": 6925, "vtln_low": 5486, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 3.5134, "energy_floor": 2.0285, "frame_length": 0.625, "frame_shift": 0.1875, "high_freq": 5229, "htk_compat": false, "low_freq": 595, "num_mel_bins": 4, "preemphasis_coefficient": 0.65, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": false, "num_ceps": 3, "cepstral_lifter": 30.6069, "vtln_high": 5090, "vtln_low": 3467, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 3.5212, "energy_floor": 3.8251, "frame_length": 0.875, "frame_shift": 0.3125, "high_freq": 4092, "htk_compat": true, "low_freq": 545, "num_mel_bins": 5, "preemphasis_coefficient": 0.09, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "num_ceps": 2, "cepstral_lifter": 2.9422, "vtln_high": 1634, "vtln_low": 1000, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 3.5261, "energy_floor": 2.0251, "frame_length": 0.875, "frame_shift": 0.625, "high_freq": 7926, "htk_compat": false, "low_freq": 3916, "num_mel_bins": 7, "preemphasis_coefficient": 0.45, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": false, "num_ceps": 2, "cepstral_lifter": 48.8818, "vtln_high": 7889, "vtln_low": 7527, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 3.539, "energy_floor": 1.6456, "frame_length": 1.125, "frame_shift": 0.1875, "high_freq": 5425, "htk_compat": true, "low_freq": 2326, "num_mel_bins": 7, "preemphasis_coefficient": 0.72, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "num_ceps": 4, "cepstral_lifter": 39.4555, "vtln_high": 4290, "vtln_low": 2715, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 3.5643, "energy_floor": 2.2424, 
"frame_length": 1.1875, "frame_shift": 0.875, "high_freq": 2140, "htk_compat": true, "low_freq": 59, "num_mel_bins": 4, "preemphasis_coefficient": 0.98, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": false, "num_ceps": 2, "cepstral_lifter": 36.7118, "vtln_high": 1463, "vtln_low": 1358, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 3.5959, "energy_floor": 4.8866, "frame_length": 1.125, "frame_shift": 1.0625, "high_freq": 5150, "htk_compat": false, "low_freq": 3697, "num_mel_bins": 4, "preemphasis_coefficient": 0.46, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": true, "num_ceps": 3, "cepstral_lifter": 74.0966, "vtln_high": 4277, "vtln_low": 3777, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 3.7223, "energy_floor": 3.4282, "frame_length": 1.0, "frame_shift": 0.125, "high_freq": 6601, "htk_compat": true, "low_freq": 1923, "num_mel_bins": 6, "preemphasis_coefficient": 0.05, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "num_ceps": 4, "cepstral_lifter": 19.2839, "vtln_high": 6596, "vtln_low": 6594, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 3.7376, "energy_floor": 0.2093, "frame_length": 1.1875, "frame_shift": 0.75, "high_freq": 7830, "htk_compat": true, "low_freq": 4448, "num_mel_bins": 5, "preemphasis_coefficient": 0.27, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "num_ceps": 5, "cepstral_lifter": 5.5865, "vtln_high": 5459, "vtln_low": 5056, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 3.812, "energy_floor": 0.4393, "frame_length": 0.75, "frame_shift": 1.0625, "high_freq": 5917, "htk_compat": false, "low_freq": 1272, "num_mel_bins": 4, "preemphasis_coefficient": 0.97, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": true, "use_energy": true, "num_ceps": 3, "cepstral_lifter": 91.4723, "vtln_high": 3532, "vtln_low": 3056, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 3.8613, "energy_floor": 4.6574, "frame_length": 1.125, "frame_shift": 1.0, "high_freq": 3399, "htk_compat": true, "low_freq": 1576, "num_mel_bins": 5, "preemphasis_coefficient": 0.71, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "num_ceps": 3, "cepstral_lifter": 29.1497, "vtln_high": 2440, "vtln_low": 1852, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 3.9117, "energy_floor": 4.6803, "frame_length": 0.5625, "frame_shift": 0.625, "high_freq": 5009, "htk_compat": false, "low_freq": 2542, "num_mel_bins": 4, "preemphasis_coefficient": 0.25, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": false, "num_ceps": 4, "cepstral_lifter": 65.8362, "vtln_high": 4734, "vtln_low": 3050, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 4.1851, "energy_floor": 3.5211, "frame_length": 1.125, "frame_shift": 0.875, "high_freq": 4768, "htk_compat": false, "low_freq": 562, "num_mel_bins": 4, 
"preemphasis_coefficient": 0.05, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": false, "num_ceps": 2, "cepstral_lifter": 36.961, "vtln_high": 1982, "vtln_low": 741, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 4.2197, "energy_floor": 3.7252, "frame_length": 0.9375, "frame_shift": 0.8125, "high_freq": 7453, "htk_compat": true, "low_freq": 1561, "num_mel_bins": 4, "preemphasis_coefficient": 0.06, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": false, "num_ceps": 2, "cepstral_lifter": 44.78, "vtln_high": 6612, "vtln_low": 4074, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 4.2736, "energy_floor": 4.9552, "frame_length": 0.75, "frame_shift": 1.0, "high_freq": 5145, "htk_compat": false, "low_freq": 1705, "num_mel_bins": 4, "preemphasis_coefficient": 0.33, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "use_energy": false, "num_ceps": 4, "cepstral_lifter": 70.9332, "vtln_high": 4857, "vtln_low": 2223, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 4.3762, "energy_floor": 4.7209, "frame_length": 0.9375, "frame_shift": 0.0625, "high_freq": 5564, "htk_compat": true, "low_freq": 712, "num_mel_bins": 4, "preemphasis_coefficient": 0.74, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "num_ceps": 2, "cepstral_lifter": 39.2887, "vtln_high": 4353, "vtln_low": 3521, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 4.4229, "energy_floor": 0.4222, "frame_length": 1.0625, "frame_shift": 1.1875, "high_freq": 7822, "htk_compat": true, "low_freq": 4837, "num_mel_bins": 5, "preemphasis_coefficient": 0.04, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": false, "num_ceps": 5, "cepstral_lifter": 36.9181, "vtln_high": 7261, "vtln_low": 5703, "vtln_warp": 1.0, "window_type": "blackman", "dither": 0.0} +{"blackman_coeff": 4.4663, "energy_floor": 3.5767, "frame_length": 1.125, "frame_shift": 1.125, "high_freq": 5844, "htk_compat": false, "low_freq": 799, "num_mel_bins": 7, "preemphasis_coefficient": 0.37, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": false, "num_ceps": 4, "cepstral_lifter": 34.2098, "vtln_high": 4554, "vtln_low": 1148, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 4.53, "energy_floor": 3.1492, "frame_length": 1.0625, "frame_shift": 0.375, "high_freq": 7706, "htk_compat": false, "low_freq": 3813, "num_mel_bins": 6, "preemphasis_coefficient": 0.74, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "num_ceps": 2, "cepstral_lifter": 71.8337, "vtln_high": 7672, "vtln_low": 5265, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 4.5474, "energy_floor": 0.7883, "frame_length": 0.5625, "frame_shift": 1.0, "high_freq": 7283, "htk_compat": false, "low_freq": 2418, "num_mel_bins": 4, "preemphasis_coefficient": 0.68, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, 
"subtract_mean": true, "use_energy": true, "num_ceps": 4, "cepstral_lifter": 70.0635, "vtln_high": 7277, "vtln_low": 7265, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 4.5663, "energy_floor": 1.127, "frame_length": 1.125, "frame_shift": 0.8125, "high_freq": 6069, "htk_compat": true, "low_freq": 167, "num_mel_bins": 8, "preemphasis_coefficient": 0.68, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "use_energy": true, "num_ceps": 6, "cepstral_lifter": 1.624, "vtln_high": 2148, "vtln_low": 461, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 4.5896, "energy_floor": 2.7617, "frame_length": 1.0625, "frame_shift": 0.8125, "high_freq": 3851, "htk_compat": true, "low_freq": 1115, "num_mel_bins": 4, "preemphasis_coefficient": 0.03, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": false, "num_ceps": 4, "cepstral_lifter": 49.7637, "vtln_high": 2897, "vtln_low": 2701, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 4.6128, "energy_floor": 0.1203, "frame_length": 1.1875, "frame_shift": 0.9375, "high_freq": 6901, "htk_compat": false, "low_freq": 3577, "num_mel_bins": 6, "preemphasis_coefficient": 0.25, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "num_ceps": 3, "cepstral_lifter": 70.5509, "vtln_high": 5962, "vtln_low": 4190, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 4.6262, "energy_floor": 4.1656, "frame_length": 1.1875, "frame_shift": 0.8125, "high_freq": 6147, "htk_compat": false, "low_freq": 1684, "num_mel_bins": 6, "preemphasis_coefficient": 0.58, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "use_energy": true, "num_ceps": 5, "cepstral_lifter": 54.2056, "vtln_high": 5259, "vtln_low": 2363, "vtln_warp": 1.0, "window_type": "rectangular", "dither": 0.0} +{"blackman_coeff": 4.6741, "energy_floor": 4.3867, "frame_length": 1.125, "frame_shift": 1.125, "high_freq": 6273, "htk_compat": false, "low_freq": 2481, "num_mel_bins": 4, "preemphasis_coefficient": 0.15, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "use_energy": false, "num_ceps": 4, "cepstral_lifter": 75.6122, "vtln_high": 3701, "vtln_low": 2992, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 4.6765, "energy_floor": 1.2644, "frame_length": 1.125, "frame_shift": 0.75, "high_freq": 5204, "htk_compat": false, "low_freq": 276, "num_mel_bins": 4, "preemphasis_coefficient": 0.04, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": true, "use_energy": false, "num_ceps": 3, "cepstral_lifter": 96.116, "vtln_high": 5148, "vtln_low": 2541, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 4.7216, "energy_floor": 2.4818, "frame_length": 0.8125, "frame_shift": 0.375, "high_freq": 6723, "htk_compat": true, "low_freq": 2352, "num_mel_bins": 6, "preemphasis_coefficient": 0.14, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": false, "num_ceps": 3, "cepstral_lifter": 32.0303, "vtln_high": 5598, "vtln_low": 
2579, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 4.7919, "energy_floor": 2.6435, "frame_length": 0.625, "frame_shift": 0.5, "high_freq": 7971, "htk_compat": false, "low_freq": 1812, "num_mel_bins": 4, "preemphasis_coefficient": 0.65, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": true, "use_energy": false, "num_ceps": 4, "cepstral_lifter": 27.7648, "vtln_high": 7735, "vtln_low": 7419, "vtln_warp": 1.0, "window_type": "povey", "dither": 0.0} +{"blackman_coeff": 4.814, "energy_floor": 0.468, "frame_length": 1.0625, "frame_shift": 0.6875, "high_freq": 5252, "htk_compat": true, "low_freq": 569, "num_mel_bins": 6, "preemphasis_coefficient": 0.85, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "use_energy": false, "num_ceps": 6, "cepstral_lifter": 56.449, "vtln_high": 4397, "vtln_low": 4332, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} +{"blackman_coeff": 4.95, "energy_floor": 4.5916, "frame_length": 1.125, "frame_shift": 1.0625, "high_freq": 5044, "htk_compat": true, "low_freq": 617, "num_mel_bins": 8, "preemphasis_coefficient": 0.89, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": true, "use_energy": true, "num_ceps": 8, "cepstral_lifter": 23.3238, "vtln_high": 2732, "vtln_low": 2677, "vtln_warp": 1.0, "window_type": "hanning", "dither": 0.0} +{"blackman_coeff": 4.9663, "energy_floor": 4.7867, "frame_length": 1.1875, "frame_shift": 0.5, "high_freq": 2424, "htk_compat": false, "low_freq": 350, "num_mel_bins": 4, "preemphasis_coefficient": 0.39, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "use_energy": true, "num_ceps": 2, "cepstral_lifter": 59.4319, "vtln_high": 1202, "vtln_low": 1063, "vtln_warp": 1.0, "window_type": "hamming", "dither": 0.0} diff --git a/test/torchaudio_unittest/assets/kaldi_test_pitch_args.jsonl b/test/torchaudio_unittest/assets/kaldi_test_pitch_args.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..9844bd6c7246a3905e42e9eccd7c93b43e235c1a --- /dev/null +++ b/test/torchaudio_unittest/assets/kaldi_test_pitch_args.jsonl @@ -0,0 +1,5 @@ +{"sample_rate": 8000} +{"sample_rate": 8000, "frames_per_chunk": 200} +{"sample_rate": 8000, "frames_per_chunk": 200, "simulate_first_pass_online": true} +{"sample_rate": 16000} +{"sample_rate": 44100} diff --git a/test/torchaudio_unittest/assets/kaldi_test_spectrogram_args.jsonl b/test/torchaudio_unittest/assets/kaldi_test_spectrogram_args.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..c44245b6c06ef215e9e3011373374ff8f82ceada --- /dev/null +++ b/test/torchaudio_unittest/assets/kaldi_test_spectrogram_args.jsonl @@ -0,0 +1,109 @@ +{"blackman_coeff": 0.0016, "dither": 0, "energy_floor": 4.668, "frame_length": 0.625, "frame_shift": 0.25, "preemphasis_coefficient": 0.82, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": false, "window_type": "povey"} +{"blackman_coeff": 0.0121, "dither": 0, "energy_floor": 4.9643, "frame_length": 0.875, "frame_shift": 0.1875, "preemphasis_coefficient": 0.98, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "window_type": "rectangular"} +{"blackman_coeff": 0.0378, "dither": 0, 
"energy_floor": 3.777, "frame_length": 0.5, "frame_shift": 0.625, "preemphasis_coefficient": 0.76, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "window_type": "blackman"} +{"blackman_coeff": 0.0545, "dither": 0, "energy_floor": 0.0732, "frame_length": 1.0, "frame_shift": 0.75, "preemphasis_coefficient": 0.81, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "hanning"} +{"blackman_coeff": 0.1005, "dither": 0, "energy_floor": 0.3739, "frame_length": 0.5625, "frame_shift": 0.625, "preemphasis_coefficient": 0.19, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "window_type": "blackman"} +{"blackman_coeff": 0.1088, "dither": 0, "energy_floor": 0.6933, "frame_length": 0.5, "frame_shift": 0.75, "preemphasis_coefficient": 0.51, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "window_type": "povey"} +{"blackman_coeff": 0.1777, "dither": 0, "energy_floor": 3.8992, "frame_length": 1.0, "frame_shift": 0.3125, "preemphasis_coefficient": 0.96, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "window_type": "blackman"} +{"blackman_coeff": 0.2384, "dither": 0, "energy_floor": 0.308, "frame_length": 0.375, "frame_shift": 0.25, "preemphasis_coefficient": 0.98, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": true, "window_type": "povey"} +{"blackman_coeff": 0.2669, "dither": 0, "energy_floor": 2.4329, "frame_length": 0.625, "frame_shift": 1.1875, "preemphasis_coefficient": 0.18, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "window_type": "blackman"} +{"blackman_coeff": 0.334, "dither": 0, "energy_floor": 0.5962, "frame_length": 0.25, "frame_shift": 0.5625, "preemphasis_coefficient": 0.38, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "window_type": "hamming"} +{"blackman_coeff": 0.4268, "dither": 0, "energy_floor": 2.4431, "frame_length": 0.5625, "frame_shift": 0.0625, "preemphasis_coefficient": 0.95, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "window_type": "hamming"} +{"blackman_coeff": 0.4774, "dither": 0, "energy_floor": 0.6982, "frame_length": 1.125, "frame_shift": 1.125, "preemphasis_coefficient": 0.27, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": true, "window_type": "povey"} +{"blackman_coeff": 0.4992, "dither": 0, "energy_floor": 3.7665, "frame_length": 0.4375, "frame_shift": 1.125, "preemphasis_coefficient": 0.42, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "window_type": "rectangular"} +{"blackman_coeff": 0.544, "dither": 0, "energy_floor": 1.6641, "frame_length": 0.9375, "frame_shift": 0.875, "preemphasis_coefficient": 0.13, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "rectangular"} +{"blackman_coeff": 0.5785, "dither": 0, "energy_floor": 2.8162, "frame_length": 1.125, "frame_shift": 
1.0625, "preemphasis_coefficient": 0.17, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "window_type": "blackman"} +{"blackman_coeff": 0.8072, "dither": 0, "energy_floor": 4.0404, "frame_length": 0.5, "frame_shift": 1.1875, "preemphasis_coefficient": 0.74, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": true, "window_type": "hamming"} +{"blackman_coeff": 0.8418, "dither": 0, "energy_floor": 4.1771, "frame_length": 0.3125, "frame_shift": 0.25, "preemphasis_coefficient": 0.48, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "window_type": "hanning"} +{"blackman_coeff": 0.8431, "dither": 0, "energy_floor": 0.0728, "frame_length": 0.75, "frame_shift": 0.8125, "preemphasis_coefficient": 0.1, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "window_type": "povey"} +{"blackman_coeff": 0.885, "dither": 0, "energy_floor": 3.9292, "frame_length": 0.375, "frame_shift": 0.75, "preemphasis_coefficient": 0.27, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "window_type": "blackman"} +{"blackman_coeff": 0.9625, "dither": 0, "energy_floor": 2.5481, "frame_length": 0.6875, "frame_shift": 1.0, "preemphasis_coefficient": 0.06, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "window_type": "blackman"} +{"blackman_coeff": 0.9826, "dither": 0, "energy_floor": 0.7377, "frame_length": 0.375, "frame_shift": 0.6875, "preemphasis_coefficient": 0.7, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "window_type": "rectangular"} +{"blackman_coeff": 0.9854, "dither": 0, "energy_floor": 3.8819, "frame_length": 0.25, "frame_shift": 1.0, "preemphasis_coefficient": 0.54, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": true, "window_type": "povey"} +{"blackman_coeff": 1.0303, "dither": 0, "energy_floor": 4.4583, "frame_length": 0.375, "frame_shift": 0.875, "preemphasis_coefficient": 0.39, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "window_type": "blackman"} +{"blackman_coeff": 1.0743, "dither": 0, "energy_floor": 0.4642, "frame_length": 1.125, "frame_shift": 0.625, "preemphasis_coefficient": 0.39, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "hanning"} +{"blackman_coeff": 1.0788, "dither": 0, "energy_floor": 1.442, "frame_length": 0.1875, "frame_shift": 0.3125, "preemphasis_coefficient": 0.53, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "hanning"} +{"blackman_coeff": 1.0816, "dither": 0, "energy_floor": 0.205, "frame_length": 0.1875, "frame_shift": 0.6875, "preemphasis_coefficient": 0.02, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "window_type": "hamming"} +{"blackman_coeff": 1.1385, "dither": 0, "energy_floor": 4.738, "frame_length": 0.625, "frame_shift": 0.3125, "preemphasis_coefficient": 0.23, "raw_energy": 
true, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": false, "window_type": "hanning"} +{"blackman_coeff": 1.3142, "dither": 0, "energy_floor": 4.8914, "frame_length": 0.875, "frame_shift": 0.1875, "preemphasis_coefficient": 0.34, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": false, "window_type": "blackman"} +{"blackman_coeff": 1.3189, "dither": 0, "energy_floor": 3.683, "frame_length": 1.125, "frame_shift": 1.125, "preemphasis_coefficient": 0.88, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": true, "window_type": "hamming"} +{"blackman_coeff": 1.3235, "dither": 0, "energy_floor": 3.8538, "frame_length": 0.25, "frame_shift": 1.0625, "preemphasis_coefficient": 0.07, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "window_type": "blackman"} +{"blackman_coeff": 1.3389, "dither": 0, "energy_floor": 1.6152, "frame_length": 0.375, "frame_shift": 0.5, "preemphasis_coefficient": 0.21, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "window_type": "blackman"} +{"blackman_coeff": 1.3887, "dither": 0, "energy_floor": 3.3198, "frame_length": 0.375, "frame_shift": 0.125, "preemphasis_coefficient": 0.14, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "window_type": "hanning"} +{"blackman_coeff": 1.4127, "dither": 0, "energy_floor": 2.6264, "frame_length": 0.875, "frame_shift": 0.375, "preemphasis_coefficient": 0.69, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "window_type": "povey"} +{"blackman_coeff": 1.5178, "dither": 0, "energy_floor": 2.8631, "frame_length": 1.0, "frame_shift": 0.8125, "preemphasis_coefficient": 0.95, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "window_type": "blackman"} +{"blackman_coeff": 1.5403, "dither": 0, "energy_floor": 0.0133, "frame_length": 1.1875, "frame_shift": 0.25, "preemphasis_coefficient": 0.59, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "window_type": "blackman"} +{"blackman_coeff": 1.5754, "dither": 0, "energy_floor": 0.954, "frame_length": 1.0, "frame_shift": 0.9375, "preemphasis_coefficient": 0.2, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "rectangular"} +{"blackman_coeff": 1.5959, "dither": 0, "energy_floor": 0.9033, "frame_length": 0.75, "frame_shift": 1.0, "preemphasis_coefficient": 0.14, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "window_type": "hanning"} +{"blackman_coeff": 1.6923, "dither": 0, "energy_floor": 3.5626, "frame_length": 0.6875, "frame_shift": 1.0625, "preemphasis_coefficient": 0.27, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "window_type": "rectangular"} +{"blackman_coeff": 1.6972, "dither": 0, "energy_floor": 1.0863, "frame_length": 1.1875, "frame_shift": 0.875, "preemphasis_coefficient": 0.86, "raw_energy": true, "remove_dc_offset": true, 
"round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "window_type": "rectangular"} +{"blackman_coeff": 1.744, "dither": 0, "energy_floor": 0.5308, "frame_length": 0.5, "frame_shift": 0.125, "preemphasis_coefficient": 0.33, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": false, "window_type": "hamming"} +{"blackman_coeff": 1.7642, "dither": 0, "energy_floor": 0.4833, "frame_length": 0.25, "frame_shift": 0.8125, "preemphasis_coefficient": 0.94, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "window_type": "blackman"} +{"blackman_coeff": 1.8072, "dither": 0, "energy_floor": 0.8085, "frame_length": 0.5, "frame_shift": 0.25, "preemphasis_coefficient": 0.96, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "window_type": "hanning"} +{"blackman_coeff": 1.8836, "dither": 0, "energy_floor": 4.5145, "frame_length": 0.875, "frame_shift": 1.0625, "preemphasis_coefficient": 0.4, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "hanning"} +{"blackman_coeff": 1.8946, "dither": 0, "energy_floor": 4.1442, "frame_length": 0.3125, "frame_shift": 0.875, "preemphasis_coefficient": 0.73, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "rectangular"} +{"blackman_coeff": 1.8988, "dither": 0, "energy_floor": 3.0931, "frame_length": 1.0625, "frame_shift": 0.3125, "preemphasis_coefficient": 0.35, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "hanning"} +{"blackman_coeff": 1.9501, "dither": 0, "energy_floor": 4.3519, "frame_length": 0.4375, "frame_shift": 0.25, "preemphasis_coefficient": 0.61, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "window_type": "hanning"} +{"blackman_coeff": 2.0137, "dither": 0, "energy_floor": 3.1007, "frame_length": 0.625, "frame_shift": 1.0625, "preemphasis_coefficient": 0.67, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "rectangular"} +{"blackman_coeff": 2.0175, "dither": 0, "energy_floor": 2.9099, "frame_length": 1.0, "frame_shift": 0.5625, "preemphasis_coefficient": 0.28, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "window_type": "blackman"} +{"blackman_coeff": 2.1114, "dither": 0, "energy_floor": 4.5618, "frame_length": 0.25, "frame_shift": 0.875, "preemphasis_coefficient": 0.61, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": false, "window_type": "povey"} +{"blackman_coeff": 2.1472, "dither": 0, "energy_floor": 0.2, "frame_length": 1.125, "frame_shift": 0.875, "preemphasis_coefficient": 0.58, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": true, "window_type": "hamming"} +{"blackman_coeff": 2.1947, "dither": 0, "energy_floor": 1.8065, "frame_length": 0.875, "frame_shift": 0.75, "preemphasis_coefficient": 0.45, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": false, 
"subtract_mean": false, "window_type": "povey"} +{"blackman_coeff": 2.2457, "dither": 0, "energy_floor": 1.704, "frame_length": 0.75, "frame_shift": 0.5625, "preemphasis_coefficient": 0.98, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": true, "window_type": "hamming"} +{"blackman_coeff": 2.2893, "dither": 0, "energy_floor": 1.0286, "frame_length": 0.25, "frame_shift": 0.5, "preemphasis_coefficient": 0.8, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "window_type": "hanning"} +{"blackman_coeff": 2.3371, "dither": 0, "energy_floor": 4.4192, "frame_length": 0.8125, "frame_shift": 0.625, "preemphasis_coefficient": 0.3, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "rectangular"} +{"blackman_coeff": 2.3831, "dither": 0, "energy_floor": 4.8325, "frame_length": 0.25, "frame_shift": 1.125, "preemphasis_coefficient": 0.34, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "window_type": "povey"} +{"blackman_coeff": 2.423, "dither": 0, "energy_floor": 0.6363, "frame_length": 0.875, "frame_shift": 0.3125, "preemphasis_coefficient": 0.77, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "window_type": "hanning"} +{"blackman_coeff": 2.4378, "dither": 0, "energy_floor": 1.4617, "frame_length": 0.9375, "frame_shift": 0.375, "preemphasis_coefficient": 0.53, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "window_type": "rectangular"} +{"blackman_coeff": 2.4454, "dither": 0, "energy_floor": 1.936, "frame_length": 1.0, "frame_shift": 0.9375, "preemphasis_coefficient": 0.66, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": false, "window_type": "rectangular"} +{"blackman_coeff": 2.448, "dither": 0, "energy_floor": 3.8782, "frame_length": 0.5625, "frame_shift": 1.125, "preemphasis_coefficient": 0.1, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "window_type": "rectangular"} +{"blackman_coeff": 2.5164, "dither": 0, "energy_floor": 2.7455, "frame_length": 0.875, "frame_shift": 0.9375, "preemphasis_coefficient": 0.55, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": false, "window_type": "hanning"} +{"blackman_coeff": 2.5316, "dither": 0, "energy_floor": 2.3286, "frame_length": 0.75, "frame_shift": 0.75, "preemphasis_coefficient": 0.61, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": true, "window_type": "rectangular"} +{"blackman_coeff": 2.5487, "dither": 0, "energy_floor": 3.8457, "frame_length": 1.1875, "frame_shift": 0.9375, "preemphasis_coefficient": 0.63, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "window_type": "hanning"} +{"blackman_coeff": 2.6121, "dither": 0, "energy_floor": 4.3165, "frame_length": 0.6875, "frame_shift": 1.1875, "preemphasis_coefficient": 0.19, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "window_type": 
"rectangular"} +{"blackman_coeff": 2.6988, "dither": 0, "energy_floor": 2.3417, "frame_length": 1.0, "frame_shift": 0.6875, "preemphasis_coefficient": 0.38, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "window_type": "rectangular"} +{"blackman_coeff": 2.7457, "dither": 0, "energy_floor": 1.3662, "frame_length": 0.25, "frame_shift": 0.875, "preemphasis_coefficient": 0.74, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": true, "window_type": "blackman"} +{"blackman_coeff": 2.8577, "dither": 0, "energy_floor": 4.1431, "frame_length": 0.375, "frame_shift": 1.0, "preemphasis_coefficient": 1.0, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "window_type": "hamming"} +{"blackman_coeff": 2.8693, "dither": 0, "energy_floor": 4.3801, "frame_length": 0.75, "frame_shift": 1.0, "preemphasis_coefficient": 0.95, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "rectangular"} +{"blackman_coeff": 2.8888, "dither": 0, "energy_floor": 0.4078, "frame_length": 0.3125, "frame_shift": 0.625, "preemphasis_coefficient": 0.25, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "window_type": "hanning"} +{"blackman_coeff": 2.9074, "dither": 0, "energy_floor": 1.6849, "frame_length": 1.125, "frame_shift": 0.625, "preemphasis_coefficient": 0.79, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "window_type": "blackman"} +{"blackman_coeff": 2.9303, "dither": 0, "energy_floor": 3.5172, "frame_length": 0.5, "frame_shift": 0.5, "preemphasis_coefficient": 0.04, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "rectangular"} +{"blackman_coeff": 3.07, "dither": 0, "energy_floor": 3.5254, "frame_length": 0.75, "frame_shift": 0.875, "preemphasis_coefficient": 0.96, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "window_type": "povey"} +{"blackman_coeff": 3.1297, "dither": 0, "energy_floor": 0.3513, "frame_length": 0.4375, "frame_shift": 0.3125, "preemphasis_coefficient": 0.2, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "window_type": "povey"} +{"blackman_coeff": 3.2523, "dither": 0, "energy_floor": 3.5376, "frame_length": 0.3125, "frame_shift": 0.25, "preemphasis_coefficient": 0.46, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "rectangular"} +{"blackman_coeff": 3.3896, "dither": 0, "energy_floor": 0.4666, "frame_length": 1.125, "frame_shift": 0.25, "preemphasis_coefficient": 0.05, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": false, "window_type": "rectangular"} +{"blackman_coeff": 3.537, "dither": 0, "energy_floor": 1.7032, "frame_length": 0.375, "frame_shift": 0.875, "preemphasis_coefficient": 0.17, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "window_type": "hamming"} +{"blackman_coeff": 3.5378, "dither": 0, 
"energy_floor": 3.6594, "frame_length": 0.25, "frame_shift": 0.625, "preemphasis_coefficient": 0.54, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "rectangular"} +{"blackman_coeff": 3.5847, "dither": 0, "energy_floor": 3.6357, "frame_length": 1.0, "frame_shift": 0.3125, "preemphasis_coefficient": 0.79, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "window_type": "povey"} +{"blackman_coeff": 3.6057, "dither": 0, "energy_floor": 1.6902, "frame_length": 1.0625, "frame_shift": 0.6875, "preemphasis_coefficient": 0.65, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "window_type": "blackman"} +{"blackman_coeff": 3.6498, "dither": 0, "energy_floor": 0.2005, "frame_length": 0.9375, "frame_shift": 1.125, "preemphasis_coefficient": 0.37, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "window_type": "rectangular"} +{"blackman_coeff": 3.6648, "dither": 0, "energy_floor": 4.6742, "frame_length": 0.625, "frame_shift": 1.1875, "preemphasis_coefficient": 0.88, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": true, "window_type": "blackman"} +{"blackman_coeff": 3.6701, "dither": 0, "energy_floor": 3.7451, "frame_length": 0.8125, "frame_shift": 0.25, "preemphasis_coefficient": 0.19, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "blackman"} +{"blackman_coeff": 3.7232, "dither": 0, "energy_floor": 0.4912, "frame_length": 0.375, "frame_shift": 0.875, "preemphasis_coefficient": 0.34, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": true, "window_type": "rectangular"} +{"blackman_coeff": 3.7605, "dither": 0, "energy_floor": 1.6813, "frame_length": 0.25, "frame_shift": 0.5625, "preemphasis_coefficient": 0.27, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "window_type": "rectangular"} +{"blackman_coeff": 3.7759, "dither": 0, "energy_floor": 1.7002, "frame_length": 1.0625, "frame_shift": 0.6875, "preemphasis_coefficient": 0.42, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "hamming"} +{"blackman_coeff": 3.7921, "dither": 0, "energy_floor": 3.4087, "frame_length": 0.25, "frame_shift": 1.0, "preemphasis_coefficient": 0.54, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "window_type": "blackman"} +{"blackman_coeff": 3.7954, "dither": 0, "energy_floor": 3.5651, "frame_length": 0.5, "frame_shift": 0.8125, "preemphasis_coefficient": 0.06, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "window_type": "blackman"} +{"blackman_coeff": 3.799, "dither": 0, "energy_floor": 3.0026, "frame_length": 0.625, "frame_shift": 1.0, "preemphasis_coefficient": 0.82, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": true, "window_type": "hamming"} +{"blackman_coeff": 3.8659, "dither": 0, "energy_floor": 1.7487, "frame_length": 
1.1875, "frame_shift": 0.375, "preemphasis_coefficient": 1.0, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "window_type": "hanning"} +{"blackman_coeff": 3.951, "dither": 0, "energy_floor": 0.3903, "frame_length": 1.125, "frame_shift": 1.0, "preemphasis_coefficient": 0.41, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": true, "window_type": "rectangular"} +{"blackman_coeff": 4.0045, "dither": 0, "energy_floor": 3.061, "frame_length": 0.625, "frame_shift": 1.0625, "preemphasis_coefficient": 0.74, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "window_type": "hanning"} +{"blackman_coeff": 4.0187, "dither": 0, "energy_floor": 4.8148, "frame_length": 0.375, "frame_shift": 0.6875, "preemphasis_coefficient": 0.68, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": false, "window_type": "hanning"} +{"blackman_coeff": 4.032, "dither": 0, "energy_floor": 2.2019, "frame_length": 1.125, "frame_shift": 0.25, "preemphasis_coefficient": 0.78, "raw_energy": true, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "rectangular"} +{"blackman_coeff": 4.0627, "dither": 0, "energy_floor": 4.1729, "frame_length": 0.625, "frame_shift": 1.125, "preemphasis_coefficient": 0.89, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": false, "window_type": "povey"} +{"blackman_coeff": 4.0736, "dither": 0, "energy_floor": 0.9155, "frame_length": 1.0625, "frame_shift": 0.5625, "preemphasis_coefficient": 0.82, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "hamming"} +{"blackman_coeff": 4.1131, "dither": 0, "energy_floor": 3.9204, "frame_length": 0.5, "frame_shift": 0.125, "preemphasis_coefficient": 0.39, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "window_type": "hanning"} +{"blackman_coeff": 4.1816, "dither": 0, "energy_floor": 1.665, "frame_length": 0.8125, "frame_shift": 0.375, "preemphasis_coefficient": 0.37, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "window_type": "rectangular"} +{"blackman_coeff": 4.1897, "dither": 0, "energy_floor": 1.2668, "frame_length": 0.1875, "frame_shift": 0.625, "preemphasis_coefficient": 0.74, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "window_type": "hamming"} +{"blackman_coeff": 4.2217, "dither": 0, "energy_floor": 3.6775, "frame_length": 0.3125, "frame_shift": 0.125, "preemphasis_coefficient": 0.01, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "window_type": "hamming"} +{"blackman_coeff": 4.2785, "dither": 0, "energy_floor": 0.7201, "frame_length": 0.8125, "frame_shift": 0.8125, "preemphasis_coefficient": 0.3, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "window_type": "hanning"} +{"blackman_coeff": 4.3304, "dither": 0, "energy_floor": 1.0538, "frame_length": 0.875, "frame_shift": 1.125, 
"preemphasis_coefficient": 0.92, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": false, "window_type": "povey"} +{"blackman_coeff": 4.3942, "dither": 0, "energy_floor": 3.9813, "frame_length": 0.75, "frame_shift": 0.6875, "preemphasis_coefficient": 0.27, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "blackman"} +{"blackman_coeff": 4.4432, "dither": 0, "energy_floor": 2.0441, "frame_length": 0.5, "frame_shift": 0.6875, "preemphasis_coefficient": 0.77, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "window_type": "hanning"} +{"blackman_coeff": 4.4459, "dither": 0, "energy_floor": 0.5135, "frame_length": 0.25, "frame_shift": 0.1875, "preemphasis_coefficient": 0.29, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "window_type": "hanning"} +{"blackman_coeff": 4.5486, "dither": 0, "energy_floor": 1.3248, "frame_length": 0.1875, "frame_shift": 1.125, "preemphasis_coefficient": 0.91, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": false, "subtract_mean": true, "window_type": "hanning"} +{"blackman_coeff": 4.5535, "dither": 0, "energy_floor": 2.1772, "frame_length": 0.4375, "frame_shift": 0.875, "preemphasis_coefficient": 0.21, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "hanning"} +{"blackman_coeff": 4.5835, "dither": 0, "energy_floor": 0.3781, "frame_length": 0.875, "frame_shift": 0.875, "preemphasis_coefficient": 0.04, "raw_energy": true, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": true, "window_type": "hamming"} +{"blackman_coeff": 4.6297, "dither": 0, "energy_floor": 2.49, "frame_length": 0.5, "frame_shift": 0.25, "preemphasis_coefficient": 0.03, "raw_energy": false, "remove_dc_offset": false, "round_to_power_of_two": true, "snip_edges": true, "subtract_mean": false, "window_type": "rectangular"} +{"blackman_coeff": 4.6749, "dither": 0, "energy_floor": 4.8853, "frame_length": 0.25, "frame_shift": 0.25, "preemphasis_coefficient": 0.48, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": false, "subtract_mean": false, "window_type": "povey"} +{"blackman_coeff": 4.6971, "dither": 0, "energy_floor": 1.3632, "frame_length": 0.875, "frame_shift": 0.9375, "preemphasis_coefficient": 0.44, "raw_energy": false, "remove_dc_offset": true, "round_to_power_of_two": false, "snip_edges": true, "subtract_mean": false, "window_type": "blackman"} diff --git a/test/torchaudio_unittest/assets/mat.ark b/test/torchaudio_unittest/assets/mat.ark new file mode 100644 index 0000000000000000000000000000000000000000..d50ef9d5797b76878b01b21c7f94b7f7e39b3568 Binary files /dev/null and b/test/torchaudio_unittest/assets/mat.ark differ diff --git a/test/torchaudio_unittest/assets/mp3_without_ext b/test/torchaudio_unittest/assets/mp3_without_ext new file mode 100644 index 0000000000000000000000000000000000000000..e4d4e6973e1cd0d8262ad3c88edbcc0464bf4663 Binary files /dev/null and b/test/torchaudio_unittest/assets/mp3_without_ext differ diff --git a/test/torchaudio_unittest/assets/sinewave.wav b/test/torchaudio_unittest/assets/sinewave.wav new file mode 100644 index 
0000000000000000000000000000000000000000..93182c4eba4e6349555d773fecfdefb5a47e0926 Binary files /dev/null and b/test/torchaudio_unittest/assets/sinewave.wav differ diff --git a/test/torchaudio_unittest/assets/sox_effect_test_args.jsonl b/test/torchaudio_unittest/assets/sox_effect_test_args.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..2a223df63542e484f15bcc8be6f8b2592295173f --- /dev/null +++ b/test/torchaudio_unittest/assets/sox_effect_test_args.jsonl @@ -0,0 +1,88 @@ +{"effects": [["allpass", "300", "10"]]} +{"effects": [["band", "300", "10"]]} +{"effects": [["bandpass", "300", "10"]]} +{"effects": [["bandreject", "300", "10"]]} +{"effects": [["bass", "-10"]]} +{"effects": [["bend", ".35,180,.25", ".15,740,.53", "0,-520,.3"]]} +{"effects": [["biquad", "0.4", "0.2", "0.9", "0.7", "0.2", "0.6"]]} +{"effects": [["chorus", "0.7", "0.9", "55", "0.4", "0.25", "2", "-t"]]} +{"effects": [["chorus", "0.6", "0.9", "50", "0.4", "0.25", "2", "-t", "60", "0.32", "0.4", "1.3", "-s"]]} +{"effects": [["chorus", "0.5", "0.9", "50", "0.4", "0.25", "2", "-t", "60", "0.32", "0.4", "2.3", "-t", "40", "0.3", "0.3", "1.3", "-s"]]} +{"effects": [["channels", "1"]]} +{"effects": [["channels", "2"]]} +{"effects": [["channels", "3"]]} +{"effects": [["compand", "0.3,1", "6:-70,-60,-20", "-5", "-90", "0.2"]]} +{"effects": [["compand", ".1,.2", "-inf,-50.1,-inf,-50,-50", "0", "-90", ".1"]]} +{"effects": [["compand", ".1,.1", "-45.1,-45,-inf,0,-inf", "45", "-90", ".1"]]} +{"effects": [["contrast", "0"]]} +{"effects": [["contrast", "25"]]} +{"effects": [["contrast", "50"]]} +{"effects": [["contrast", "75"]]} +{"effects": [["contrast", "100"]]} +{"effects": [["dcshift", "1.0"]]} +{"effects": [["dcshift", "-1.0"]]} +{"effects": [["deemph"]], "input_sample_rate": 44100} +{"effects": [["delay", "1.5", "+1"]]} +{"effects": [["dither", "-s"]]} +{"effects": [["dither", "-S"]]} +{"effects": [["divide"]]} +{"effects": [["downsample", "2"]], "input_sample_rate": 8000, "output_sample_rate": 4000} +{"effects": [["earwax"]], "input_sample_rate": 44100} +{"effects": [["echo", "0.8", "0.88", "60", "0.4"]]} +{"effects": [["echo", "0.8", "0.88", "6", "0.4"]]} +{"effects": [["echo", "0.8", "0.9", "1000", "0.3"]]} +{"effects": [["echo", "0.8", "0.9", "1000", "0.3", "1800", "0.25"]]} +{"effects": [["echos", "0.8", "0.7", "700", "0.25", "700", "0.3"]]} +{"effects": [["echos", "0.8", "0.7", "700", "0.25", "900", "0.3"]]} +{"effects": [["echos", "0.8", "0.7", "40", "0.25", "63", "0.3"]]} +{"effects": [["equalizer", "300", "10", "5"]]} +{"effects": [["fade", "q", "3"]]} +{"effects": [["fade", "h", "3"]]} +{"effects": [["fade", "t", "3"]]} +{"effects": [["fade", "l", "3"]]} +{"effects": [["fade", "p", "3"]]} +{"effects": [["fir", "0.0195", "-0.082", "0.234", "0.891", "-0.145", "0.043"]]} +{"effects": [["fir", "/sox_effect_test_fir_coeffs.txt"]]} +{"effects": [["flanger"]]} +{"effects": [["gain", "-n"]]} +{"effects": [["gain", "-n", "-3"]]} +{"effects": [["gain", "-l", "-6"]]} +{"effects": [["highpass", "-1", "300"]]} +{"effects": [["highpass", "-2", "300"]]} +{"effects": [["hilbert"]]} +{"effects": [["loudness"]]} +{"effects": [["lowpass", "-1", "300"]]} +{"effects": [["lowpass", "-2", "300"]]} +{"effects": [["mcompand", "0.005,0.1 -47,-40,-34,-34,-17,-33", "100", "0.003,0.05 -47,-40,-34,-34,-17,-33", "400", "0.000625,0.0125 -47,-40,-34,-34,-15,-33", "1600", "0.0001,0.025 -47,-40,-34,-34,-31,-31,-0,-30", "6400", "0,0.025 -38,-31,-28,-28,-0,-25"]], "input_sample_rate": 44100} +{"effects": [["norm"]]} +{"effects": 
[["oops"]]} +{"effects": [["overdrive"]]} +{"effects": [["pad"]]} +{"effects": [["phaser"]]} +{"effects": [["pitch", "6.48"], ["rate", "8030"]], "output_sample_rate": 8030} +{"effects": [["pitch", "-6.50"], ["rate", "7970"]], "output_sample_rate": 7970} +{"effects": [["rate", "4567"]], "output_sample_rate": 4567} +{"effects": [["remix", "6", "7", "8", "0"]], "num_channels": 8} +{"effects": [["remix", "1-3,7", "3"]], "num_channels": 8} +{"effects": [["repeat"]]} +{"effects": [["reverb"]]} +{"effects": [["reverse"]]} +{"effects": [["riaa"]], "input_sample_rate": 44100} +{"effects": [["silence", "0"]]} +{"effects": [["sinc", "3k"]]} +{"effects": [["speed", "1.3"]], "input_sample_rate": 4000, "output_sample_rate": 5200} +{"effects": [["speed", "0.7"]], "input_sample_rate": 4000, "output_sample_rate": 2800} +{"effects": [["stat"]]} +{"effects": [["stats"]]} +{"effects": [["stretch"]]} +{"effects": [["swap"]]} +{"effects": [["synth"]]} +{"effects": [["tempo", "0.9"]]} +{"effects": [["tempo", "1.1"]]} +{"effects": [["treble", "3"]]} +{"effects": [["tremolo", "300", "40"]]} +{"effects": [["tremolo", "300", "50"]]} +{"effects": [["trim", "0", "0.1"]]} +{"effects": [["upsample", "2"]], "input_sample_rate": 8000, "output_sample_rate": 16000} +{"effects": [["vad"]]} +{"effects": [["vol", "3"]]} diff --git a/test/torchaudio_unittest/assets/sox_effect_test_fir_coeffs.txt b/test/torchaudio_unittest/assets/sox_effect_test_fir_coeffs.txt new file mode 100644 index 0000000000000000000000000000000000000000..903a607d3bb081a3add904c496dd01dcfd2e52e2 --- /dev/null +++ b/test/torchaudio_unittest/assets/sox_effect_test_fir_coeffs.txt @@ -0,0 +1 @@ +0.0195 -0.082 0.234 0.891 -0.145 0.043 diff --git a/test/torchaudio_unittest/assets/steam-train-whistle-daniel_simon.mp3 b/test/torchaudio_unittest/assets/steam-train-whistle-daniel_simon.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..8977e7247c21f587c1e4a52f7eab87935c28c33e Binary files /dev/null and b/test/torchaudio_unittest/assets/steam-train-whistle-daniel_simon.mp3 differ diff --git a/test/torchaudio_unittest/assets/steam-train-whistle-daniel_simon.wav b/test/torchaudio_unittest/assets/steam-train-whistle-daniel_simon.wav new file mode 100644 index 0000000000000000000000000000000000000000..773dc0f79c12e6b94d8ba0e6fd21ba0238ae3c4a Binary files /dev/null and b/test/torchaudio_unittest/assets/steam-train-whistle-daniel_simon.wav differ diff --git a/test/torchaudio_unittest/assets/vad-go-mono-32000.wav b/test/torchaudio_unittest/assets/vad-go-mono-32000.wav new file mode 100644 index 0000000000000000000000000000000000000000..277d4bb0b9b98ff6782344eb1c4cc2f056fac58b Binary files /dev/null and b/test/torchaudio_unittest/assets/vad-go-mono-32000.wav differ diff --git a/test/torchaudio_unittest/assets/vad-go-stereo-44100.wav b/test/torchaudio_unittest/assets/vad-go-stereo-44100.wav new file mode 100644 index 0000000000000000000000000000000000000000..107de5ac6934f3d52b64da3f27e489ab6d1f04fd Binary files /dev/null and b/test/torchaudio_unittest/assets/vad-go-stereo-44100.wav differ diff --git a/test/torchaudio_unittest/assets/vec_flt.ark b/test/torchaudio_unittest/assets/vec_flt.ark new file mode 100644 index 0000000000000000000000000000000000000000..ed74a0b2859aaf888099e9b2defe417462946102 Binary files /dev/null and b/test/torchaudio_unittest/assets/vec_flt.ark differ diff --git a/test/torchaudio_unittest/assets/vec_int.ark b/test/torchaudio_unittest/assets/vec_int.ark new file mode 100644 index 
0000000000000000000000000000000000000000..d8892ad47a33155619c66eba9e8f0d6df6d1caee Binary files /dev/null and b/test/torchaudio_unittest/assets/vec_int.ark differ diff --git a/test/torchaudio_unittest/assets/wav2vec2/fairseq/generate_hubert_model_config.py b/test/torchaudio_unittest/assets/wav2vec2/fairseq/generate_hubert_model_config.py new file mode 100644 index 0000000000000000000000000000000000000000..28bf3349bf5cbf71885da6ba5d39c38f05363b1b --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/fairseq/generate_hubert_model_config.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +"""Generate the conf JSONs from a fairseq pretrained weight file, consumed by unit tests + +Note: + The current configuration files were generated on fairseq e47a4c84 + +Usage: +1. Download pretrained parameters from https://github.com/pytorch/fairseq/tree/main/examples/hubert +2. Run this script and save the resulting JSON configuration in the assets directory. + +Example: + +``` +python generate_hubert_model_config.py \ + --model-file hubert_base_ls960.pt \ + > hubert_base_ls960.json + +python generate_hubert_model_config.py \ + --model-file hubert_large_ll60k.pt \ + > hubert_large_ll60k.json + +python generate_hubert_model_config.py \ + --model-file hubert_large_ll60k_finetune_ls960.pt \ + > hubert_large_ll60k_finetune_ls960.json + +python generate_hubert_model_config.py \ + --model-file hubert_xtralarge_ll60k.pt \ + > hubert_xtralarge_ll60k.json + +python generate_hubert_model_config.py \ + --model-file hubert_xtralarge_ll60k_finetune_ls960.pt \ + > hubert_xtralarge_ll60k_finetune_ls960.json +``` +""" +import json +import argparse + + +def _parse_args(): + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + '--model-file', + required=True, + help=( + 'A pt file from ' + 'https://github.com/pytorch/fairseq/tree/main/examples/hubert' + ) + ) + return parser.parse_args() + + +def _load(model_file): + import fairseq + from omegaconf import OmegaConf + + models, cfg, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file]) + model = models[0] + cfg = OmegaConf.to_container(cfg) + return model, cfg + + +def _main(): + args = _parse_args() + model, cfg = _load(args.model_file) + + if model.__class__.__name__ == 'HubertModel': + cfg['task']['data'] = '/foo/bar' + cfg['task']['label_dir'] = None + conf = { + '_name': 'hubert', + 'model': cfg['model'], + 'task': cfg['task'], + 'num_classes': model.num_classes, + } + elif model.__class__.__name__ == 'HubertCtc': + conf = cfg['model'] + del conf['w2v_path'] + keep = ['_name', 'task', 'model'] + for key in list(k for k in conf['w2v_args'] if k not in keep): + del conf['w2v_args'][key] + conf['data'] = '/foo/bar/' + conf['w2v_args']['task']['data'] = '/foo/bar' + conf['w2v_args']['task']['labels'] = [] + conf['w2v_args']['task']['label_dir'] = '/foo/bar' + print(json.dumps(conf, indent=4, sort_keys=True)) + + +if __name__ == '__main__': + _main() diff --git a/test/torchaudio_unittest/assets/wav2vec2/fairseq/generate_wav2vec2_model_config.py b/test/torchaudio_unittest/assets/wav2vec2/fairseq/generate_wav2vec2_model_config.py new file mode 100644 index 0000000000000000000000000000000000000000..2d651a1b77d094d7b565fcd55bf317f83c69b4ac --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/fairseq/generate_wav2vec2_model_config.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +"""Generate the conf JSON from a fairseq pretrained weight file, consumed by unit tests + +Usage: +1. 
Download pretrained parameters from https://github.com/pytorch/fairseq/tree/main/examples/wav2vec +2. Download the dict from https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt + and put it in the same directory as the parameter files. +3. Run this script and save the resulting JSON configuration in the assets directory. + +Example: + +``` +# Pretrained +python generate_wav2vec2_model_config.py \ + --model-file wav2vec_small.pt \ + > wav2vec_small.json + +python generate_wav2vec2_model_config.py \ + --model-file libri960_big.pt \ + > libri960_big.json + +python generate_wav2vec2_model_config.py \ + --model-file wav2vec_vox_new.pt \ + > wav2vec_vox_new.json + +# Fine-tuned +python generate_wav2vec2_model_config.py \ + --model-file wav2vec_small_960h.pt \ + > wav2vec_small_960h.json + +python generate_wav2vec2_model_config.py \ + --model-file wav2vec_big_960h.pt \ + > wav2vec_large_960h.json + +python generate_wav2vec2_model_config.py \ + --model-file wav2vec2_vox_960h_new.pt \ + > wav2vec_large_lv60_960h.json + +python generate_wav2vec2_model_config.py \ + --model-file wav2vec_vox_960h_pl.pt \ + > wav2vec_large_lv60_self_960h.json +``` +""" +import os +import json +import argparse + + +def _parse_args(): + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + '--model-file', + required=True, + help=( + 'A pt file from ' + 'https://github.com/pytorch/fairseq/tree/main/examples/wav2vec' + ) + ) + parser.add_argument( + '--dict-dir', + help=( + 'Directory where `dict.ltr.txt` file is found. ' + 'Default: the directory of the given model.' + ) + ) + args = parser.parse_args() + if args.dict_dir is None: + args.dict_dir = os.path.dirname(args.model_file) + return args + + +def _to_json(conf): + import yaml + from omegaconf import OmegaConf + return yaml.safe_load(OmegaConf.to_yaml(conf)) + + +def _load(model_file, dict_dir): + import fairseq + + overrides = {'data': dict_dir} + _, args, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [model_file], arg_overrides=overrides + ) + return _to_json(args['model']) + + +def _main(): + args = _parse_args() + conf = _load(args.model_file, args.dict_dir) + + if conf['_name'] == 'wav2vec_ctc': + del conf['data'] + del conf['w2v_args']['task']['data'] + conf['w2v_args'] = { + key: conf['w2v_args'][key] for key in ['model', 'task'] + } + + print(json.dumps(conf, indent=4, sort_keys=True)) + + +if __name__ == '__main__': + _main() diff --git a/test/torchaudio_unittest/assets/wav2vec2/fairseq/hubert_base_ls960.json b/test/torchaudio_unittest/assets/wav2vec2/fairseq/hubert_base_ls960.json new file mode 100644 index 0000000000000000000000000000000000000000..7c6d7ad30d8fe4334628a19d279c133744993593 --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/fairseq/hubert_base_ls960.json @@ -0,0 +1,69 @@ +{ + "_name": "hubert", + "model": { + "_name": "hubert", + "activation_dropout": 0.0, + "activation_fn": "gelu", + "attention_dropout": 0.1, + "conv_bias": false, + "conv_feature_layers": "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", + "conv_pos": 128, + "conv_pos_groups": 16, + "dropout": 0.1, + "dropout_features": 0.1, + "dropout_input": 0.1, + "encoder_attention_heads": 12, + "encoder_embed_dim": 768, + "encoder_ffn_embed_dim": 3072, + "encoder_layerdrop": 0.05, + "encoder_layers": 12, + "extractor_mode": "default", + "feature_grad_mult": 0.1, + "final_dim": 256, + "label_rate": 50, + "latent_temp": [ + 2.0, + 0.5, + 0.999995 + ], + "layer_norm_first": 
false, + "logit_temp": 0.1, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_length": 10, + "mask_min_space": 1, + "mask_other": 0.0, + "mask_prob": 0.8, + "mask_selection": "static", + "no_mask_channel_overlap": false, + "no_mask_overlap": false, + "skip_masked": false, + "skip_nomask": false, + "target_glu": false, + "untie_final_proj": false + }, + "num_classes": [ + 504 + ], + "task": { + "_name": "hubert_pretraining", + "data": "/foo/bar", + "enable_padding": false, + "fine_tuning": false, + "label_dir": null, + "label_rate": 50, + "labels": [ + "layer6.km500" + ], + "max_sample_size": 250000, + "min_sample_size": 32000, + "normalize": false, + "pad_audio": false, + "random_crop": true, + "sample_rate": 16000, + "single_target": false + } +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/fairseq/hubert_large_ll60k.json b/test/torchaudio_unittest/assets/wav2vec2/fairseq/hubert_large_ll60k.json new file mode 100644 index 0000000000000000000000000000000000000000..a1b1481020b7b67dcfeffce7cc26accbb952625e --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/fairseq/hubert_large_ll60k.json @@ -0,0 +1,68 @@ +{ + "_name": "hubert", + "model": { + "_name": "hubert", + "activation_dropout": 0.0, + "activation_fn": "gelu", + "attention_dropout": 0.0, + "conv_bias": false, + "conv_feature_layers": "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", + "conv_pos": 128, + "conv_pos_groups": 16, + "dropout": 0.0, + "dropout_features": 0.0, + "dropout_input": 0.0, + "encoder_attention_heads": 16, + "encoder_embed_dim": 1024, + "encoder_ffn_embed_dim": 4096, + "encoder_layerdrop": 0.0, + "encoder_layers": 24, + "extractor_mode": "layer_norm", + "feature_grad_mult": 1.0, + "final_dim": 768, + "label_rate": 50, + "latent_temp": [ + 2.0, + 0.5, + 0.999995 + ], + "layer_norm_first": true, + "logit_temp": 0.1, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_length": 10, + "mask_min_space": 1, + "mask_other": 0.0, + "mask_prob": 0.8, + "mask_selection": "static", + "no_mask_channel_overlap": false, + "no_mask_overlap": false, + "skip_masked": false, + "skip_nomask": true, + "target_glu": false, + "untie_final_proj": true + }, + "num_classes": [ + 504 + ], + "task": { + "_name": "hubert_pretraining", + "data": "/foo/bar", + "enable_padding": false, + "label_dir": null, + "label_rate": 50, + "labels": [ + "lyr9.km500" + ], + "max_sample_size": 250000, + "min_sample_size": 32000, + "normalize": true, + "pad_audio": false, + "random_crop": true, + "sample_rate": 16000, + "single_target": false + } +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/fairseq/hubert_large_ll60k_finetune_ls960.json b/test/torchaudio_unittest/assets/wav2vec2/fairseq/hubert_large_ll60k_finetune_ls960.json new file mode 100644 index 0000000000000000000000000000000000000000..c50c8f1a3d1f24a8d13177b46287f4efa69e8a1a --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/fairseq/hubert_large_ll60k_finetune_ls960.json @@ -0,0 +1,89 @@ +{ + "_name": "hubert_ctc", + "activation_dropout": 0.1, + "apply_mask": true, + "attention_dropout": 0.0, + "data": "/foo/bar/", + "dropout": 0.0, + "dropout_input": 0.0, + "feature_grad_mult": 0.0, + "final_dropout": 0.0, + "freeze_finetune_updates": 10000, + "layerdrop": 0.1, + "mask_channel_length": 64, + "mask_channel_other": 0.0, + "mask_channel_prob": 
0.25, + "mask_channel_selection": "static", + "mask_length": 10, + "mask_other": 0.0, + "mask_prob": 0.5, + "mask_selection": "static", + "no_mask_channel_overlap": false, + "no_mask_overlap": false, + "no_pretrained_weights": false, + "normalize": true, + "w2v_args": { + "_name": null, + "model": { + "_name": "hubert", + "activation_dropout": 0.1, + "activation_fn": "gelu", + "attention_dropout": 0.0, + "conv_bias": false, + "conv_feature_layers": "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", + "conv_pos": 128, + "conv_pos_groups": 16, + "dropout": 0.0, + "dropout_features": 0.0, + "dropout_input": 0.0, + "encoder_attention_heads": 16, + "encoder_embed_dim": 1024, + "encoder_ffn_embed_dim": 4096, + "encoder_layerdrop": 0.1, + "encoder_layers": 24, + "extractor_mode": "layer_norm", + "feature_grad_mult": 0.0, + "final_dim": 768, + "label_rate": 50, + "latent_temp": [ + 2.0, + 0.5, + 0.999995 + ], + "layer_norm_first": true, + "logit_temp": 0.1, + "mask_channel_length": 64, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.25, + "mask_channel_selection": "static", + "mask_length": 10, + "mask_min_space": 1, + "mask_other": 0.0, + "mask_prob": 0.5, + "mask_selection": "static", + "no_mask_channel_overlap": false, + "no_mask_overlap": false, + "skip_masked": false, + "skip_nomask": true, + "target_glu": false, + "untie_final_proj": true + }, + "task": { + "_name": "hubert_pretraining", + "data": "/foo/bar", + "enable_padding": false, + "fine_tuning": false, + "label_dir": "/foo/bar", + "label_rate": 50, + "labels": [], + "max_sample_size": 250000, + "min_sample_size": 32000, + "normalize": true, + "pad_audio": false, + "random_crop": true, + "sample_rate": 16000, + "single_target": false + } + } +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/fairseq/hubert_xtralarge_ll60k.json b/test/torchaudio_unittest/assets/wav2vec2/fairseq/hubert_xtralarge_ll60k.json new file mode 100644 index 0000000000000000000000000000000000000000..2ade77dd413cb71ea32becc78d7797fd2dfcd287 --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/fairseq/hubert_xtralarge_ll60k.json @@ -0,0 +1,68 @@ +{ + "_name": "hubert", + "model": { + "_name": "hubert", + "activation_dropout": 0.0, + "activation_fn": "gelu", + "attention_dropout": 0.0, + "conv_bias": false, + "conv_feature_layers": "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", + "conv_pos": 128, + "conv_pos_groups": 16, + "dropout": 0.0, + "dropout_features": 0.0, + "dropout_input": 0.0, + "encoder_attention_heads": 16, + "encoder_embed_dim": 1280, + "encoder_ffn_embed_dim": 5120, + "encoder_layerdrop": 0.0, + "encoder_layers": 48, + "extractor_mode": "layer_norm", + "feature_grad_mult": 1.0, + "final_dim": 1024, + "label_rate": 50, + "latent_temp": [ + 2.0, + 0.5, + 0.999995 + ], + "layer_norm_first": true, + "logit_temp": 0.1, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_length": 10, + "mask_min_space": 1, + "mask_other": 0.0, + "mask_prob": 0.8, + "mask_selection": "static", + "no_mask_channel_overlap": false, + "no_mask_overlap": false, + "skip_masked": false, + "skip_nomask": true, + "target_glu": false, + "untie_final_proj": true + }, + "num_classes": [ + 504 + ], + "task": { + "_name": "hubert_pretraining", + "data": "/foo/bar", + "enable_padding": false, + "label_dir": null, + "label_rate": 50, + "labels": [ + "lyr9.km500" + ], + "max_sample_size": 250000, + "min_sample_size": 32000, + 
"normalize": true, + "pad_audio": false, + "random_crop": true, + "sample_rate": 16000, + "single_target": false + } +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/fairseq/hubert_xtralarge_ll60k_finetune_ls960.json b/test/torchaudio_unittest/assets/wav2vec2/fairseq/hubert_xtralarge_ll60k_finetune_ls960.json new file mode 100644 index 0000000000000000000000000000000000000000..9a830702037b313e4dd51758f8c14e9e1f7b70d3 --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/fairseq/hubert_xtralarge_ll60k_finetune_ls960.json @@ -0,0 +1,89 @@ +{ + "_name": "hubert_ctc", + "activation_dropout": 0.1, + "apply_mask": true, + "attention_dropout": 0.0, + "data": "/foo/bar/", + "dropout": 0.0, + "dropout_input": 0.0, + "feature_grad_mult": 0.0, + "final_dropout": 0.0, + "freeze_finetune_updates": 10000, + "layerdrop": 0.1, + "mask_channel_length": 64, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.25, + "mask_channel_selection": "static", + "mask_length": 10, + "mask_other": 0.0, + "mask_prob": 0.5, + "mask_selection": "static", + "no_mask_channel_overlap": false, + "no_mask_overlap": false, + "no_pretrained_weights": false, + "normalize": true, + "w2v_args": { + "_name": null, + "model": { + "_name": "hubert", + "activation_dropout": 0.1, + "activation_fn": "gelu", + "attention_dropout": 0.0, + "conv_bias": false, + "conv_feature_layers": "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", + "conv_pos": 128, + "conv_pos_groups": 16, + "dropout": 0.0, + "dropout_features": 0.0, + "dropout_input": 0.0, + "encoder_attention_heads": 16, + "encoder_embed_dim": 1280, + "encoder_ffn_embed_dim": 5120, + "encoder_layerdrop": 0.1, + "encoder_layers": 48, + "extractor_mode": "layer_norm", + "feature_grad_mult": 0.0, + "final_dim": 1024, + "label_rate": 50, + "latent_temp": [ + 2.0, + 0.5, + 0.999995 + ], + "layer_norm_first": true, + "logit_temp": 0.1, + "mask_channel_length": 64, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.25, + "mask_channel_selection": "static", + "mask_length": 10, + "mask_min_space": 1, + "mask_other": 0.0, + "mask_prob": 0.5, + "mask_selection": "static", + "no_mask_channel_overlap": false, + "no_mask_overlap": false, + "skip_masked": false, + "skip_nomask": true, + "target_glu": false, + "untie_final_proj": true + }, + "task": { + "_name": "hubert_pretraining", + "data": "/foo/bar", + "enable_padding": false, + "fine_tuning": false, + "label_dir": "/foo/bar", + "label_rate": 50, + "labels": [], + "max_sample_size": 250000, + "min_sample_size": 32000, + "normalize": true, + "pad_audio": false, + "random_crop": true, + "sample_rate": 16000, + "single_target": false + } + } +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/fairseq/libri960_big.json b/test/torchaudio_unittest/assets/wav2vec2/fairseq/libri960_big.json new file mode 100644 index 0000000000000000000000000000000000000000..3add61134ef6b5d8a2a030c871e959f5331ec8a3 --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/fairseq/libri960_big.json @@ -0,0 +1,54 @@ +{ + "_name": "wav2vec2", + "activation_dropout": 0.0, + "activation_fn": "gelu", + "attention_dropout": 0.1, + "codebook_negatives": 0, + "conv_bias": false, + "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2", + "conv_pos": 128, + "conv_pos_groups": 16, + "cross_sample_negatives": 0, + "dropout": 0.0, + "dropout_features": 0.1, + "dropout_input": 0.1, + "encoder_attention_heads": 16, + "encoder_embed_dim": 1024, + "encoder_ffn_embed_dim": 4096, + "encoder_layerdrop": 
0.2, + "encoder_layers": 24, + "extractor_mode": "default", + "feature_grad_mult": 0.1, + "final_dim": 768, + "latent_dim": 0, + "latent_groups": 2, + "latent_temp": [ + 2.0, + 0.5, + 0.999995 + ], + "latent_vars": 320, + "layer_norm_first": false, + "logit_temp": 0.1, + "mask_channel_before": false, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_length": 10, + "mask_min_space": 1, + "mask_other": 0.0, + "mask_prob": 0.65, + "mask_selection": "static", + "negatives_from_everywhere": false, + "no_mask_channel_overlap": false, + "no_mask_overlap": false, + "num_negatives": 100, + "quantize_input": false, + "quantize_targets": true, + "quantizer_depth": 1, + "quantizer_factor": 3, + "same_quantizer": false, + "target_glu": false +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_large_960h.json b/test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_large_960h.json new file mode 100644 index 0000000000000000000000000000000000000000..d46b38eeafa1db5c11a92e81a9726582094cd498 --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_large_960h.json @@ -0,0 +1,146 @@ +{ + "_name": "wav2vec_ctc", + "activation_dropout": 0.1, + "apply_mask": true, + "attention_dropout": 0.0, + "blank_mode": "add", + "blank_weight": 0.0, + "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]", + "dropout": 0.0, + "dropout_input": 0.0, + "encoder_embed_dim": 512, + "feature_grad_mult": 0.0, + "final_dropout": 0.0, + "freeze_finetune_updates": 10000, + "layerdrop": 0.2, + "mask_channel_before": false, + "mask_channel_length": 64, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.1, + "mask_channel_selection": "static", + "mask_length": 10, + "mask_min_space": 1, + "mask_other": 0.0, + "mask_prob": 0.5, + "mask_selection": "static", + "no_mask_channel_overlap": false, + "no_mask_overlap": false, + "no_pretrained_weights": false, + "normalize": false, + "w2v_args": { + "model": { + "_name": "wav2vec2", + "activation_dropout": 0.0, + "activation_fn": "gelu", + "attention_dropout": 0.1, + "codebook_negatives": 0, + "conv_bias": false, + "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2", + "conv_pos": 128, + "conv_pos_groups": 16, + "cross_sample_negatives": 0, + "dropout": 0.0, + "dropout_features": 0.1, + "dropout_input": 0.1, + "encoder_attention_heads": 16, + "encoder_embed_dim": 1024, + "encoder_ffn_embed_dim": 4096, + "encoder_layerdrop": 0.2, + "encoder_layers": 24, + "extractor_mode": "default", + "feature_grad_mult": 0.1, + "final_dim": 768, + "latent_dim": 0, + "latent_groups": 2, + "latent_temp": [ + 2.0, + 0.5, + 0.999995 + ], + "latent_vars": 320, + "layer_norm_first": false, + "logit_temp": 0.1, + "mask_channel_before": false, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_length": 10, + "mask_min_space": 1, + "mask_other": 0.0, + "mask_prob": 0.65, + "mask_selection": "static", + "negatives_from_everywhere": false, + "no_mask_channel_overlap": false, + "no_mask_overlap": false, + "num_negatives": 100, + "quantize_input": false, + "quantize_targets": true, + "quantizer_depth": 1, + "quantizer_factor": 3, + "same_quantizer": false, + "target_glu": false + }, + "task": { + "_name": "audio_pretraining", + "autoregressive": false, + 
"binarized_dataset": false, + "enable_padding": false, + "eval_wer": false, + "eval_wer_config": { + "beam": 5, + "constraints": null, + "decoding_format": null, + "diverse_beam_groups": -1, + "diverse_beam_strength": 0.5, + "diversity_rate": -1.0, + "iter_decode_eos_penalty": 0.0, + "iter_decode_force_max_iter": false, + "iter_decode_max_iter": 10, + "iter_decode_with_beam": 1, + "iter_decode_with_external_reranker": false, + "lenpen": 1.0, + "lm_path": null, + "lm_weight": 0.0, + "match_source_len": false, + "max_len_a": 0.0, + "max_len_b": 200, + "min_len": 1, + "nbest": 1, + "no_beamable_mm": false, + "no_early_stop": false, + "no_repeat_ngram_size": 0, + "no_seed_provided": false, + "prefix_size": 0, + "print_alignment": null, + "print_step": false, + "replace_unk": null, + "retain_dropout": false, + "retain_dropout_modules": null, + "retain_iter_history": false, + "sacrebleu": false, + "sampling": false, + "sampling_topk": -1, + "sampling_topp": -1.0, + "score_reference": false, + "temperature": 1.0, + "unkpen": 0.0, + "unnormalized": false + }, + "eval_wer_post_process": "letter", + "eval_wer_tokenizer": null, + "inferred_w2v_config": null, + "labels": null, + "max_sample_size": 320000, + "min_sample_size": 32000, + "normalize": false, + "num_batch_buckets": 0, + "precompute_mask_indices": false, + "sample_rate": 16000, + "tpu": true + } + }, + "w2v_path": "???" +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_large_lv60k_960h.json b/test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_large_lv60k_960h.json new file mode 100644 index 0000000000000000000000000000000000000000..8ba161b4a6dbf98cf4ebd2b3b837d92b3205a725 --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_large_lv60k_960h.json @@ -0,0 +1,146 @@ +{ + "_name": "wav2vec_ctc", + "activation_dropout": 0.1, + "apply_mask": true, + "attention_dropout": 0.0, + "blank_mode": "add", + "blank_weight": 0.0, + "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]", + "dropout": 0.0, + "dropout_input": 0.0, + "encoder_embed_dim": 512, + "feature_grad_mult": 0.0, + "final_dropout": 0.0, + "freeze_finetune_updates": 10000, + "layerdrop": 0.1, + "mask_channel_before": false, + "mask_channel_length": 64, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.25, + "mask_channel_selection": "static", + "mask_length": 10, + "mask_min_space": 1, + "mask_other": 0.0, + "mask_prob": 0.5, + "mask_selection": "static", + "no_mask_channel_overlap": false, + "no_mask_overlap": false, + "no_pretrained_weights": false, + "normalize": true, + "w2v_args": { + "model": { + "_name": "wav2vec2", + "activation_dropout": 0.0, + "activation_fn": "gelu", + "attention_dropout": 0.1, + "codebook_negatives": 0, + "conv_bias": true, + "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2", + "conv_pos": 128, + "conv_pos_groups": 16, + "cross_sample_negatives": 0, + "dropout": 0.0, + "dropout_features": 0.1, + "dropout_input": 0.1, + "encoder_attention_heads": 16, + "encoder_embed_dim": 1024, + "encoder_ffn_embed_dim": 4096, + "encoder_layerdrop": 0.0, + "encoder_layers": 24, + "extractor_mode": "layer_norm", + "feature_grad_mult": 1.0, + "final_dim": 768, + "latent_dim": 0, + "latent_groups": 2, + "latent_temp": [ + 2.0, + 0.1, + 0.999995 + ], + "latent_vars": 320, + "layer_norm_first": true, + "logit_temp": 0.1, + "mask_channel_before": false, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 
0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_length": 10, + "mask_min_space": 1, + "mask_other": 0.0, + "mask_prob": 0.65, + "mask_selection": "static", + "negatives_from_everywhere": false, + "no_mask_channel_overlap": false, + "no_mask_overlap": false, + "num_negatives": 100, + "quantize_input": false, + "quantize_targets": true, + "quantizer_depth": 1, + "quantizer_factor": 3, + "same_quantizer": false, + "target_glu": false + }, + "task": { + "_name": "audio_pretraining", + "autoregressive": false, + "binarized_dataset": false, + "enable_padding": false, + "eval_wer": false, + "eval_wer_config": { + "beam": 5, + "constraints": null, + "decoding_format": null, + "diverse_beam_groups": -1, + "diverse_beam_strength": 0.5, + "diversity_rate": -1.0, + "iter_decode_eos_penalty": 0.0, + "iter_decode_force_max_iter": false, + "iter_decode_max_iter": 10, + "iter_decode_with_beam": 1, + "iter_decode_with_external_reranker": false, + "lenpen": 1.0, + "lm_path": null, + "lm_weight": 0.0, + "match_source_len": false, + "max_len_a": 0.0, + "max_len_b": 200, + "min_len": 1, + "nbest": 1, + "no_beamable_mm": false, + "no_early_stop": false, + "no_repeat_ngram_size": 0, + "no_seed_provided": false, + "prefix_size": 0, + "print_alignment": null, + "print_step": false, + "replace_unk": null, + "retain_dropout": false, + "retain_dropout_modules": null, + "retain_iter_history": false, + "sacrebleu": false, + "sampling": false, + "sampling_topk": -1, + "sampling_topp": -1.0, + "score_reference": false, + "temperature": 1.0, + "unkpen": 0.0, + "unnormalized": false + }, + "eval_wer_post_process": "letter", + "eval_wer_tokenizer": null, + "inferred_w2v_config": null, + "labels": null, + "max_sample_size": 320000, + "min_sample_size": 32000, + "normalize": true, + "num_batch_buckets": 0, + "precompute_mask_indices": false, + "sample_rate": 16000, + "tpu": true + } + }, + "w2v_path": "???" 
+} diff --git a/test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_large_lv60k_self_960h.json b/test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_large_lv60k_self_960h.json new file mode 100644 index 0000000000000000000000000000000000000000..a65ccfe8cb52bfa4f3141389a1444bccac069fe7 --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_large_lv60k_self_960h.json @@ -0,0 +1,146 @@ +{ + "_name": "wav2vec_ctc", + "activation_dropout": 0.1, + "apply_mask": true, + "attention_dropout": 0.0, + "blank_mode": "add", + "blank_weight": 0.0, + "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]", + "dropout": 0.0, + "dropout_input": 0.0, + "encoder_embed_dim": 768, + "feature_grad_mult": 0.0, + "final_dropout": 0.0, + "freeze_finetune_updates": 10000, + "layerdrop": 0.1, + "mask_channel_before": false, + "mask_channel_length": 64, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.1, + "mask_channel_selection": "static", + "mask_length": 10, + "mask_min_space": 1, + "mask_other": 0.0, + "mask_prob": 0.1, + "mask_selection": "static", + "no_mask_channel_overlap": false, + "no_mask_overlap": false, + "no_pretrained_weights": false, + "normalize": true, + "w2v_args": { + "model": { + "_name": "wav2vec2", + "activation_dropout": 0.0, + "activation_fn": "gelu", + "attention_dropout": 0.1, + "codebook_negatives": 0, + "conv_bias": true, + "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2", + "conv_pos": 128, + "conv_pos_groups": 16, + "cross_sample_negatives": 0, + "dropout": 0.0, + "dropout_features": 0.1, + "dropout_input": 0.1, + "encoder_attention_heads": 16, + "encoder_embed_dim": 1024, + "encoder_ffn_embed_dim": 4096, + "encoder_layerdrop": 0.0, + "encoder_layers": 24, + "extractor_mode": "layer_norm", + "feature_grad_mult": 1.0, + "final_dim": 768, + "latent_dim": 0, + "latent_groups": 2, + "latent_temp": [ + 2.0, + 0.1, + 0.999995 + ], + "latent_vars": 320, + "layer_norm_first": true, + "logit_temp": 0.1, + "mask_channel_before": false, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_length": 10, + "mask_min_space": 1, + "mask_other": 0.0, + "mask_prob": 0.65, + "mask_selection": "static", + "negatives_from_everywhere": false, + "no_mask_channel_overlap": false, + "no_mask_overlap": false, + "num_negatives": 100, + "quantize_input": false, + "quantize_targets": true, + "quantizer_depth": 1, + "quantizer_factor": 3, + "same_quantizer": false, + "target_glu": false + }, + "task": { + "_name": "audio_pretraining", + "autoregressive": false, + "binarized_dataset": false, + "enable_padding": false, + "eval_wer": false, + "eval_wer_config": { + "beam": 5, + "constraints": null, + "decoding_format": null, + "diverse_beam_groups": -1, + "diverse_beam_strength": 0.5, + "diversity_rate": -1.0, + "iter_decode_eos_penalty": 0.0, + "iter_decode_force_max_iter": false, + "iter_decode_max_iter": 10, + "iter_decode_with_beam": 1, + "iter_decode_with_external_reranker": false, + "lenpen": 1.0, + "lm_path": null, + "lm_weight": 0.0, + "match_source_len": false, + "max_len_a": 0.0, + "max_len_b": 200, + "min_len": 1, + "nbest": 1, + "no_beamable_mm": false, + "no_early_stop": false, + "no_repeat_ngram_size": 0, + "no_seed_provided": false, + "prefix_size": 0, + "print_alignment": null, + "print_step": false, + "replace_unk": null, + "retain_dropout": false, + 
"retain_dropout_modules": null, + "retain_iter_history": false, + "sacrebleu": false, + "sampling": false, + "sampling_topk": -1, + "sampling_topp": -1.0, + "score_reference": false, + "temperature": 1.0, + "unkpen": 0.0, + "unnormalized": false + }, + "eval_wer_post_process": "letter", + "eval_wer_tokenizer": null, + "inferred_w2v_config": null, + "labels": null, + "max_sample_size": 320000, + "min_sample_size": 32000, + "normalize": true, + "num_batch_buckets": 0, + "precompute_mask_indices": false, + "sample_rate": 16000, + "tpu": true + } + }, + "w2v_path": "/private/home/abaevski/models/wav2vec2/wav2vec_vox_new.pt" +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_small.json b/test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_small.json new file mode 100644 index 0000000000000000000000000000000000000000..8fb7ff28e6c3bec4d8e355f7e9843566fb4c7ad4 --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_small.json @@ -0,0 +1,54 @@ +{ + "_name": "wav2vec2", + "activation_dropout": 0.0, + "activation_fn": "gelu", + "attention_dropout": 0.1, + "codebook_negatives": 0, + "conv_bias": false, + "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2", + "conv_pos": 128, + "conv_pos_groups": 16, + "cross_sample_negatives": 0, + "dropout": 0.1, + "dropout_features": 0.1, + "dropout_input": 0.1, + "encoder_attention_heads": 12, + "encoder_embed_dim": 768, + "encoder_ffn_embed_dim": 3072, + "encoder_layerdrop": 0.05, + "encoder_layers": 12, + "extractor_mode": "default", + "feature_grad_mult": 0.1, + "final_dim": 256, + "latent_dim": 0, + "latent_groups": 2, + "latent_temp": [ + 2.0, + 0.5, + 0.999995 + ], + "latent_vars": 320, + "layer_norm_first": false, + "logit_temp": 0.1, + "mask_channel_before": false, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_length": 10, + "mask_min_space": 1, + "mask_other": 0.0, + "mask_prob": 0.65, + "mask_selection": "static", + "negatives_from_everywhere": false, + "no_mask_channel_overlap": false, + "no_mask_overlap": false, + "num_negatives": 100, + "quantize_input": false, + "quantize_targets": true, + "quantizer_depth": 1, + "quantizer_factor": 3, + "same_quantizer": false, + "target_glu": false +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_small_960h.json b/test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_small_960h.json new file mode 100644 index 0000000000000000000000000000000000000000..f0ee5af801d7ede979f6d393bf341fcba5fc9d4c --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_small_960h.json @@ -0,0 +1,146 @@ +{ + "_name": "wav2vec_ctc", + "activation_dropout": 0.1, + "apply_mask": true, + "attention_dropout": 0.0, + "blank_mode": "add", + "blank_weight": 0.0, + "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]", + "dropout": 0.0, + "dropout_input": 0.0, + "encoder_embed_dim": 512, + "feature_grad_mult": 0.0, + "final_dropout": 0.0, + "freeze_finetune_updates": 0, + "layerdrop": 0.1, + "mask_channel_before": false, + "mask_channel_length": 64, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.1, + "mask_channel_selection": "static", + "mask_length": 10, + "mask_min_space": 1, + "mask_other": 0.0, + "mask_prob": 0.5, + "mask_selection": "static", + "no_mask_channel_overlap": false, + "no_mask_overlap": false, + "no_pretrained_weights": false, 
+ "normalize": false, + "w2v_args": { + "model": { + "_name": "wav2vec2", + "activation_dropout": 0.0, + "activation_fn": "gelu", + "attention_dropout": 0.1, + "codebook_negatives": 0, + "conv_bias": false, + "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2", + "conv_pos": 128, + "conv_pos_groups": 16, + "cross_sample_negatives": 0, + "dropout": 0.1, + "dropout_features": 0.1, + "dropout_input": 0.1, + "encoder_attention_heads": 12, + "encoder_embed_dim": 768, + "encoder_ffn_embed_dim": 3072, + "encoder_layerdrop": 0.05, + "encoder_layers": 12, + "extractor_mode": "default", + "feature_grad_mult": 0.1, + "final_dim": 256, + "latent_dim": 0, + "latent_groups": 2, + "latent_temp": [ + 2, + 0.5, + 0.999995 + ], + "latent_vars": 320, + "layer_norm_first": false, + "logit_temp": 0.1, + "mask_channel_before": false, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_length": 10, + "mask_min_space": 1, + "mask_other": 0.0, + "mask_prob": 0.65, + "mask_selection": "static", + "negatives_from_everywhere": false, + "no_mask_channel_overlap": false, + "no_mask_overlap": false, + "num_negatives": 100, + "quantize_input": false, + "quantize_targets": true, + "quantizer_depth": 1, + "quantizer_factor": 3, + "same_quantizer": false, + "target_glu": false + }, + "task": { + "_name": "audio_pretraining", + "autoregressive": false, + "binarized_dataset": false, + "enable_padding": false, + "eval_wer": false, + "eval_wer_config": { + "beam": 5, + "constraints": null, + "decoding_format": null, + "diverse_beam_groups": -1, + "diverse_beam_strength": 0.5, + "diversity_rate": -1.0, + "iter_decode_eos_penalty": 0.0, + "iter_decode_force_max_iter": false, + "iter_decode_max_iter": 10, + "iter_decode_with_beam": 1, + "iter_decode_with_external_reranker": false, + "lenpen": 1.0, + "lm_path": null, + "lm_weight": 0.0, + "match_source_len": false, + "max_len_a": 0.0, + "max_len_b": 200, + "min_len": 1, + "nbest": 1, + "no_beamable_mm": false, + "no_early_stop": false, + "no_repeat_ngram_size": 0, + "no_seed_provided": false, + "prefix_size": 0, + "print_alignment": null, + "print_step": false, + "replace_unk": null, + "retain_dropout": false, + "retain_dropout_modules": null, + "retain_iter_history": false, + "sacrebleu": false, + "sampling": false, + "sampling_topk": -1, + "sampling_topp": -1.0, + "score_reference": false, + "temperature": 1.0, + "unkpen": 0.0, + "unnormalized": false + }, + "eval_wer_post_process": "letter", + "eval_wer_tokenizer": null, + "inferred_w2v_config": null, + "labels": null, + "max_sample_size": 250000, + "min_sample_size": 32000, + "normalize": false, + "num_batch_buckets": 0, + "precompute_mask_indices": false, + "sample_rate": 16000, + "tpu": true + } + }, + "w2v_path": "???" 
+} diff --git a/test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_vox_new.json b/test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_vox_new.json new file mode 100644 index 0000000000000000000000000000000000000000..d58e303a75843afeeca10c1b06be7b59fda34132 --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_vox_new.json @@ -0,0 +1,54 @@ +{ + "_name": "wav2vec2", + "activation_dropout": 0.0, + "activation_fn": "gelu", + "attention_dropout": 0.1, + "codebook_negatives": 0, + "conv_bias": true, + "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2", + "conv_pos": 128, + "conv_pos_groups": 16, + "cross_sample_negatives": 0, + "dropout": 0.0, + "dropout_features": 0.1, + "dropout_input": 0.1, + "encoder_attention_heads": 16, + "encoder_embed_dim": 1024, + "encoder_ffn_embed_dim": 4096, + "encoder_layerdrop": 0.0, + "encoder_layers": 24, + "extractor_mode": "layer_norm", + "feature_grad_mult": 1.0, + "final_dim": 768, + "latent_dim": 0, + "latent_groups": 2, + "latent_temp": [ + 2.0, + 0.1, + 0.999995 + ], + "latent_vars": 320, + "layer_norm_first": true, + "logit_temp": 0.1, + "mask_channel_before": false, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_length": 10, + "mask_min_space": 1, + "mask_other": 0.0, + "mask_prob": 0.65, + "mask_selection": "static", + "negatives_from_everywhere": false, + "no_mask_channel_overlap": false, + "no_mask_overlap": false, + "num_negatives": 100, + "quantize_input": false, + "quantize_targets": true, + "quantizer_depth": 1, + "quantizer_factor": 3, + "same_quantizer": false, + "target_glu": false +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/fairseq/xlsr_53_56k.json b/test/torchaudio_unittest/assets/wav2vec2/fairseq/xlsr_53_56k.json new file mode 100644 index 0000000000000000000000000000000000000000..098f60d55a397fb2ac9859b697c121983b2a8eca --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/fairseq/xlsr_53_56k.json @@ -0,0 +1,51 @@ +{ + "_name": "wav2vec2", + "activation_dropout": 0.0, + "activation_fn": "gelu", + "attention_dropout": 0.0, + "codebook_negatives": 0, + "conv_bias": true, + "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2", + "conv_pos": 128, + "conv_pos_groups": 16, + "cross_sample_negatives": 0, + "dropout": 0.0, + "dropout_features": 0.0, + "dropout_input": 0.0, + "encoder_attention_heads": 16, + "encoder_embed_dim": 1024, + "encoder_ffn_embed_dim": 4096, + "encoder_layerdrop": 0.0, + "encoder_layers": 24, + "extractor_mode": "layer_norm", + "feature_grad_mult": 1.0, + "final_dim": 768, + "latent_dim": 0, + "latent_groups": 2, + "latent_temp": [ + 2.0, + 0.1, + 0.999995 + ], + "latent_vars": 320, + "layer_norm_first": true, + "logit_temp": 0.1, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_length": 10, + "mask_min_space": 1, + "mask_other": 0.0, + "mask_prob": 0.65, + "mask_selection": "static", + "negatives_from_everywhere": false, + "no_mask_channel_overlap": false, + "no_mask_overlap": false, + "num_negatives": 100, + "quantize_input": false, + "quantize_targets": true, + "same_quantizer": false, + "target_glu": false +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-base-10k-voxpopuli.json 
b/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-base-10k-voxpopuli.json new file mode 100644 index 0000000000000000000000000000000000000000..927428cb10f0940cc4581b6d739492c7ed685f94 --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-base-10k-voxpopuli.json @@ -0,0 +1,68 @@ +{ + "activation_dropout": 0.1, + "apply_spec_augment": true, + "architectures": [ + "Wav2Vec2Model" + ], + "attention_dropout": 0.1, + "bos_token_id": 1, + "conv_bias": false, + "conv_dim": [ + 512, + 512, + 512, + 512, + 512, + 512, + 512 + ], + "conv_kernel": [ + 10, + 3, + 3, + 3, + 3, + 2, + 2 + ], + "conv_stride": [ + 5, + 2, + 2, + 2, + 2, + 2, + 2 + ], + "ctc_loss_reduction": "sum", + "ctc_zero_infinity": false, + "do_stable_layer_norm": false, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + "feat_extract_dropout": 0.0, + "feat_extract_norm": "group", + "feat_proj_dropout": 0.1, + "final_dropout": 0.1, + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout": 0.1, + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-05, + "layerdrop": 0.1, + "mask_feature_length": 10, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_prob": 0.05, + "model_type": "wav2vec2", + "num_attention_heads": 12, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_feat_extract_layers": 7, + "num_hidden_layers": 12, + "pad_token_id": 0, + "transformers_version": "4.5.1", + "vocab_size": 32 +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-base-960h.json b/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-base-960h.json new file mode 100644 index 0000000000000000000000000000000000000000..3a1c7f1a4b3dd14987fde1314a940d33fd9f3c79 --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-base-960h.json @@ -0,0 +1,68 @@ +{ + "activation_dropout": 0.1, + "apply_spec_augment": true, + "architectures": [ + "Wav2Vec2ForCTC" + ], + "attention_dropout": 0.1, + "bos_token_id": 1, + "conv_bias": false, + "conv_dim": [ + 512, + 512, + 512, + 512, + 512, + 512, + 512 + ], + "conv_kernel": [ + 10, + 3, + 3, + 3, + 3, + 2, + 2 + ], + "conv_stride": [ + 5, + 2, + 2, + 2, + 2, + 2, + 2 + ], + "ctc_loss_reduction": "sum", + "ctc_zero_infinity": false, + "do_stable_layer_norm": false, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + "feat_extract_dropout": 0.0, + "feat_extract_norm": "group", + "feat_proj_dropout": 0.1, + "final_dropout": 0.1, + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout": 0.1, + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-05, + "layerdrop": 0.1, + "mask_feature_length": 10, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_prob": 0.05, + "model_type": "wav2vec2", + "num_attention_heads": 12, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_feat_extract_layers": 7, + "num_hidden_layers": 12, + "pad_token_id": 0, + "transformers_version": "4.5.1", + "vocab_size": 32 +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-base.json b/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-base.json new file mode 100644 index 0000000000000000000000000000000000000000..7927c2e4a98c264917030ff5d7c40937496ef70f --- /dev/null 
+++ b/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-base.json @@ -0,0 +1,77 @@ +{ + "activation_dropout": 0.0, + "apply_spec_augment": true, + "architectures": [ + "Wav2Vec2Model" + ], + "attention_dropout": 0.1, + "bos_token_id": 1, + "conv_bias": false, + "conv_dim": [ + 512, + 512, + 512, + 512, + 512, + 512, + 512 + ], + "conv_kernel": [ + 10, + 3, + 3, + 3, + 3, + 2, + 2 + ], + "conv_stride": [ + 5, + 2, + 2, + 2, + 2, + 2, + 2 + ], + "ctc_loss_reduction": "sum", + "ctc_zero_infinity": false, + "do_stable_layer_norm": false, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + "feat_extract_norm": "group", + "feat_proj_dropout": 0.1, + "final_dropout": 0.0, + "freeze_feat_extract_train": true, + "gradient_checkpointing": true, + "hidden_act": "gelu", + "hidden_dropout": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-05, + "layerdrop": 0.05, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_feature_length": 10, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_min_space": 1, + "mask_time_other": 0.0, + "mask_time_prob": 0.05, + "mask_time_selection": "static", + "model_type": "wav2vec2", + "no_mask_channel_overlap": false, + "no_mask_time_overlap": false, + "num_attention_heads": 12, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_feat_extract_layers": 7, + "num_hidden_layers": 12, + "pad_token_id": 0, + "transformers_version": "4.5.1", + "vocab_size": 32 +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large-960h-lv60-self.json b/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large-960h-lv60-self.json new file mode 100644 index 0000000000000000000000000000000000000000..e9d79893ac5ada65d0ccecc35e11db1ecb759417 --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large-960h-lv60-self.json @@ -0,0 +1,68 @@ +{ + "activation_dropout": 0.1, + "apply_spec_augment": true, + "architectures": [ + "Wav2Vec2ForCTC" + ], + "attention_dropout": 0.1, + "bos_token_id": 1, + "conv_bias": true, + "conv_dim": [ + 512, + 512, + 512, + 512, + 512, + 512, + 512 + ], + "conv_kernel": [ + 10, + 3, + 3, + 3, + 3, + 2, + 2 + ], + "conv_stride": [ + 5, + 2, + 2, + 2, + 2, + 2, + 2 + ], + "ctc_loss_reduction": "sum", + "ctc_zero_infinity": false, + "do_stable_layer_norm": true, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + "feat_extract_dropout": 0.0, + "feat_extract_norm": "layer", + "feat_proj_dropout": 0.1, + "final_dropout": 0.1, + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout": 0.1, + "hidden_dropout_prob": 0.1, + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-05, + "layerdrop": 0.1, + "mask_feature_length": 10, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_prob": 0.05, + "model_type": "wav2vec2", + "num_attention_heads": 16, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_feat_extract_layers": 7, + "num_hidden_layers": 24, + "pad_token_id": 0, + "transformers_version": "4.5.1", + "vocab_size": 32 +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large-960h-lv60.json b/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large-960h-lv60.json new file mode 
100644 index 0000000000000000000000000000000000000000..e9d79893ac5ada65d0ccecc35e11db1ecb759417 --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large-960h-lv60.json @@ -0,0 +1,68 @@ +{ + "activation_dropout": 0.1, + "apply_spec_augment": true, + "architectures": [ + "Wav2Vec2ForCTC" + ], + "attention_dropout": 0.1, + "bos_token_id": 1, + "conv_bias": true, + "conv_dim": [ + 512, + 512, + 512, + 512, + 512, + 512, + 512 + ], + "conv_kernel": [ + 10, + 3, + 3, + 3, + 3, + 2, + 2 + ], + "conv_stride": [ + 5, + 2, + 2, + 2, + 2, + 2, + 2 + ], + "ctc_loss_reduction": "sum", + "ctc_zero_infinity": false, + "do_stable_layer_norm": true, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + "feat_extract_dropout": 0.0, + "feat_extract_norm": "layer", + "feat_proj_dropout": 0.1, + "final_dropout": 0.1, + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout": 0.1, + "hidden_dropout_prob": 0.1, + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-05, + "layerdrop": 0.1, + "mask_feature_length": 10, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_prob": 0.05, + "model_type": "wav2vec2", + "num_attention_heads": 16, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_feat_extract_layers": 7, + "num_hidden_layers": 24, + "pad_token_id": 0, + "transformers_version": "4.5.1", + "vocab_size": 32 +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large-960h.json b/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large-960h.json new file mode 100644 index 0000000000000000000000000000000000000000..e50233c00e5838b921439569c3fd344f121f8861 --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large-960h.json @@ -0,0 +1,68 @@ +{ + "activation_dropout": 0.1, + "apply_spec_augment": true, + "architectures": [ + "Wav2Vec2ForCTC" + ], + "attention_dropout": 0.1, + "bos_token_id": 1, + "conv_bias": false, + "conv_dim": [ + 512, + 512, + 512, + 512, + 512, + 512, + 512 + ], + "conv_kernel": [ + 10, + 3, + 3, + 3, + 3, + 2, + 2 + ], + "conv_stride": [ + 5, + 2, + 2, + 2, + 2, + 2, + 2 + ], + "ctc_loss_reduction": "sum", + "ctc_zero_infinity": false, + "do_stable_layer_norm": false, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + "feat_extract_dropout": 0.0, + "feat_extract_norm": "group", + "feat_proj_dropout": 0.1, + "final_dropout": 0.1, + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout": 0.1, + "hidden_dropout_prob": 0.1, + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-05, + "layerdrop": 0.1, + "mask_feature_length": 10, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_prob": 0.05, + "model_type": "wav2vec2", + "num_attention_heads": 16, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_feat_extract_layers": 7, + "num_hidden_layers": 24, + "pad_token_id": 0, + "transformers_version": "4.5.1", + "vocab_size": 32 +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large-lv60.json b/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large-lv60.json new file mode 100644 index 0000000000000000000000000000000000000000..ba3d2c2e217bae535d0ab6d3a837aa6a12de08f7 --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large-lv60.json @@ 
-0,0 +1,68 @@ +{ + "activation_dropout": 0.1, + "apply_spec_augment": true, + "architectures": [ + "Wav2Vec2Model" + ], + "attention_dropout": 0.1, + "bos_token_id": 1, + "conv_bias": true, + "conv_dim": [ + 512, + 512, + 512, + 512, + 512, + 512, + 512 + ], + "conv_kernel": [ + 10, + 3, + 3, + 3, + 3, + 2, + 2 + ], + "conv_stride": [ + 5, + 2, + 2, + 2, + 2, + 2, + 2 + ], + "ctc_loss_reduction": "sum", + "ctc_zero_infinity": false, + "do_stable_layer_norm": true, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + "feat_extract_dropout": 0.0, + "feat_extract_norm": "layer", + "feat_proj_dropout": 0.1, + "final_dropout": 0.1, + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout": 0.1, + "hidden_dropout_prob": 0.1, + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-05, + "layerdrop": 0.1, + "mask_feature_length": 10, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_prob": 0.05, + "model_type": "wav2vec2", + "num_attention_heads": 16, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_feat_extract_layers": 7, + "num_hidden_layers": 24, + "pad_token_id": 0, + "transformers_version": "4.5.1", + "vocab_size": 32 +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large-xlsr-53-german.json b/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large-xlsr-53-german.json new file mode 100644 index 0000000000000000000000000000000000000000..120f142b865d7c5513714a4f6523e4b2c5b8b491 --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large-xlsr-53-german.json @@ -0,0 +1,68 @@ +{ + "activation_dropout": 0.1, + "apply_spec_augment": true, + "architectures": [ + "Wav2Vec2ForCTC" + ], + "attention_dropout": 0.1, + "bos_token_id": 1, + "conv_bias": true, + "conv_dim": [ + 512, + 512, + 512, + 512, + 512, + 512, + 512 + ], + "conv_kernel": [ + 10, + 3, + 3, + 3, + 3, + 2, + 2 + ], + "conv_stride": [ + 5, + 2, + 2, + 2, + 2, + 2, + 2 + ], + "ctc_loss_reduction": "sum", + "ctc_zero_infinity": false, + "do_stable_layer_norm": true, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + "feat_extract_dropout": 0.0, + "feat_extract_norm": "layer", + "feat_proj_dropout": 0.1, + "final_dropout": 0.1, + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout": 0.1, + "hidden_dropout_prob": 0.1, + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-05, + "layerdrop": 0.1, + "mask_feature_length": 10, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_prob": 0.05, + "model_type": "wav2vec2", + "num_attention_heads": 16, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_feat_extract_layers": 7, + "num_hidden_layers": 24, + "pad_token_id": 0, + "transformers_version": "4.5.1", + "vocab_size": 36 +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large-xlsr-53.json b/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large-xlsr-53.json new file mode 100644 index 0000000000000000000000000000000000000000..80f0f61e8e5d84ec4594928504e3c5112c6dd302 --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large-xlsr-53.json @@ -0,0 +1,75 @@ +{ + "activation_dropout": 0.0, + "apply_spec_augment": true, + "architectures": [ + "Wav2Vec2Model" + ], + "attention_dropout": 0.1, + "bos_token_id": 1, + 
"conv_bias": true, + "conv_dim": [ + 512, + 512, + 512, + 512, + 512, + 512, + 512 + ], + "conv_kernel": [ + 10, + 3, + 3, + 3, + 3, + 2, + 2 + ], + "conv_stride": [ + 5, + 2, + 2, + 2, + 2, + 2, + 2 + ], + "ctc_loss_reduction": "sum", + "ctc_zero_infinity": false, + "do_stable_layer_norm": true, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + "feat_extract_dropout": 0.0, + "feat_extract_norm": "layer", + "feat_proj_dropout": 0.1, + "final_dropout": 0.0, + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout": 0.1, + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-05, + "layerdrop": 0.1, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_feature_length": 10, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_min_space": 1, + "mask_time_other": 0.0, + "mask_time_prob": 0.075, + "mask_time_selection": "static", + "model_type": "wav2vec2", + "num_attention_heads": 16, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_feat_extract_layers": 7, + "num_hidden_layers": 24, + "pad_token_id": 0, + "transformers_version": "4.5.1", + "vocab_size": 32 +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large.json b/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large.json new file mode 100644 index 0000000000000000000000000000000000000000..db1635eb0cbf1d5ad71f6686c893397fe714a241 --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/huggingface/facebook/wav2vec2-large.json @@ -0,0 +1,68 @@ +{ + "activation_dropout": 0.1, + "apply_spec_augment": true, + "architectures": [ + "Wav2Vec2Model" + ], + "attention_dropout": 0.1, + "bos_token_id": 1, + "conv_bias": false, + "conv_dim": [ + 512, + 512, + 512, + 512, + 512, + 512, + 512 + ], + "conv_kernel": [ + 10, + 3, + 3, + 3, + 3, + 2, + 2 + ], + "conv_stride": [ + 5, + 2, + 2, + 2, + 2, + 2, + 2 + ], + "ctc_loss_reduction": "sum", + "ctc_zero_infinity": false, + "do_stable_layer_norm": false, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + "feat_extract_dropout": 0.0, + "feat_extract_norm": "group", + "feat_proj_dropout": 0.1, + "final_dropout": 0.1, + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout": 0.1, + "hidden_dropout_prob": 0.1, + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-05, + "layerdrop": 0.1, + "mask_feature_length": 10, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_prob": 0.05, + "model_type": "wav2vec2", + "num_attention_heads": 16, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_feat_extract_layers": 7, + "num_hidden_layers": 24, + "pad_token_id": 0, + "transformers_version": "4.5.1", + "vocab_size": 32 +} diff --git a/test/torchaudio_unittest/assets/wav2vec2/huggingface/generate_huggingface_model_config.py b/test/torchaudio_unittest/assets/wav2vec2/huggingface/generate_huggingface_model_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d2217930104f6071fdf5a547e1ff02693c5b7b49 --- /dev/null +++ b/test/torchaudio_unittest/assets/wav2vec2/huggingface/generate_huggingface_model_config.py @@ -0,0 +1,37 @@ +import os +import json + +from transformers import Wav2Vec2Model + +_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def _main(): + keys = [ + # 
pretrained + "facebook/wav2vec2-base", + "facebook/wav2vec2-large", + "facebook/wav2vec2-large-lv60", + "facebook/wav2vec2-base-10k-voxpopuli", + "facebook/wav2vec2-large-xlsr-53", + # finetuned + "facebook/wav2vec2-base-960h", + "facebook/wav2vec2-large-960h", + "facebook/wav2vec2-large-960h-lv60", + "facebook/wav2vec2-large-960h-lv60-self", + "facebook/wav2vec2-large-xlsr-53-german", + ] + for key in keys: + path = os.path.join(_THIS_DIR, f'{key}.json') + print('Generating ', path) + cfg = Wav2Vec2Model.from_pretrained(key).config + cfg = json.loads(cfg.to_json_string()) + del cfg['_name_or_path'] + + with open(path, 'w') as file_: + file_.write(json.dumps(cfg, indent=4, sort_keys=True)) + file_.write('\n') + + +if __name__ == '__main__': + _main() diff --git a/test/torchaudio_unittest/backend/__init__.py b/test/torchaudio_unittest/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/torchaudio_unittest/backend/common.py b/test/torchaudio_unittest/backend/common.py new file mode 100644 index 0000000000000000000000000000000000000000..dc8e90bf77517775bc94e06cca506cb7cb89a203 --- /dev/null +++ b/test/torchaudio_unittest/backend/common.py @@ -0,0 +1,25 @@ +from torchaudio_unittest.common_utils import sox_utils + + +def get_encoding(ext, dtype): + exts = { + 'mp3', + 'flac', + 'vorbis', + } + encodings = { + 'float32': 'PCM_F', + 'int32': 'PCM_S', + 'int16': 'PCM_S', + 'uint8': 'PCM_U', + } + return ext.upper() if ext in exts else encodings[dtype] + + +def get_bits_per_sample(ext, dtype): + bits_per_samples = { + 'flac': 24, + 'mp3': 0, + 'vorbis': 0, + } + return bits_per_samples.get(ext, sox_utils.get_bit_depth(dtype)) diff --git a/test/torchaudio_unittest/backend/soundfile/__init__.py b/test/torchaudio_unittest/backend/soundfile/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/torchaudio_unittest/backend/soundfile/common.py b/test/torchaudio_unittest/backend/soundfile/common.py new file mode 100644 index 0000000000000000000000000000000000000000..c6b014dd4ca7d5b75070545ca6484d81b96dc2ba --- /dev/null +++ b/test/torchaudio_unittest/backend/soundfile/common.py @@ -0,0 +1,57 @@ +import itertools +from unittest import skipIf + +from parameterized import parameterized +from torchaudio._internal.module_utils import is_module_available + + +def name_func(func, _, params): + return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}' + + +def dtype2subtype(dtype): + return { + "float64": "DOUBLE", + "float32": "FLOAT", + "int32": "PCM_32", + "int16": "PCM_16", + "uint8": "PCM_U8", + "int8": "PCM_S8", + }[dtype] + + +def skipIfFormatNotSupported(fmt): + fmts = [] + if is_module_available("soundfile"): + import soundfile + + fmts = soundfile.available_formats() + return skipIf(fmt not in fmts, f'"{fmt}" is not supported by soundfile') + return skipIf(True, '"soundfile" not available.') + + +def parameterize(*params): + return parameterized.expand(list(itertools.product(*params)), name_func=name_func) + + +def fetch_wav_subtype(dtype, encoding, bits_per_sample): + subtype = { + (None, None): dtype2subtype(dtype), + (None, 8): "PCM_U8", + ('PCM_U', None): "PCM_U8", + ('PCM_U', 8): "PCM_U8", + ('PCM_S', None): "PCM_32", + ('PCM_S', 16): "PCM_16", + ('PCM_S', 32): "PCM_32", + ('PCM_F', None): "FLOAT", + ('PCM_F', 32): "FLOAT", + ('PCM_F', 64): "DOUBLE", + ('ULAW', None): "ULAW", + ('ULAW', 8): "ULAW", + ('ALAW', 
None): "ALAW", + ('ALAW', 8): "ALAW", + }.get((encoding, bits_per_sample)) + if subtype: + return subtype + raise ValueError( + f"wav does not support ({encoding}, {bits_per_sample}).") diff --git a/test/torchaudio_unittest/backend/soundfile/info_test.py b/test/torchaudio_unittest/backend/soundfile/info_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b6b8722410a468286c059fd1501f554682242160 --- /dev/null +++ b/test/torchaudio_unittest/backend/soundfile/info_test.py @@ -0,0 +1,190 @@ +from unittest.mock import patch +import warnings +import tarfile + +import torch +from torchaudio.backend import soundfile_backend +from torchaudio._internal import module_utils as _mod_utils + +from torchaudio_unittest.common_utils import ( + TempDirMixin, + PytorchTestCase, + skipIfNoModule, + get_wav_data, + save_wav, + nested_params, +) +from torchaudio_unittest.backend.common import ( + get_bits_per_sample, + get_encoding, +) +from .common import skipIfFormatNotSupported, parameterize + +if _mod_utils.is_module_available("soundfile"): + import soundfile + + +@skipIfNoModule("soundfile") +class TestInfo(TempDirMixin, PytorchTestCase): + @parameterize( + ["float32", "int32", "int16", "uint8"], [8000, 16000], [1, 2], + ) + def test_wav(self, dtype, sample_rate, num_channels): + """`soundfile_backend.info` can check wav file correctly""" + duration = 1 + path = self.get_temp_path("data.wav") + data = get_wav_data( + dtype, num_channels, normalize=False, num_frames=duration * sample_rate + ) + save_wav(path, data, sample_rate) + info = soundfile_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == get_bits_per_sample("wav", dtype) + assert info.encoding == get_encoding("wav", dtype) + + @parameterize([8000, 16000], [1, 2]) + @skipIfFormatNotSupported("FLAC") + def test_flac(self, sample_rate, num_channels): + """`soundfile_backend.info` can check flac file correctly""" + duration = 1 + num_frames = sample_rate * duration + data = torch.randn(num_frames, num_channels).numpy() + path = self.get_temp_path("data.flac") + soundfile.write(path, data, sample_rate) + + info = soundfile_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == num_frames + assert info.num_channels == num_channels + assert info.bits_per_sample == 16 + assert info.encoding == "FLAC" + + @parameterize([8000, 16000], [1, 2]) + @skipIfFormatNotSupported("OGG") + def test_ogg(self, sample_rate, num_channels): + """`soundfile_backend.info` can check ogg file correctly""" + duration = 1 + num_frames = sample_rate * duration + data = torch.randn(num_frames, num_channels).numpy() + path = self.get_temp_path("data.ogg") + soundfile.write(path, data, sample_rate) + + info = soundfile_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == 0 + assert info.encoding == "VORBIS" + + @nested_params( + [8000, 16000], + [1, 2], + [ + ('PCM_24', 24), + ('PCM_32', 32) + ], + ) + @skipIfFormatNotSupported("NIST") + def test_sphere(self, sample_rate, num_channels, subtype_and_bit_depth): + """`soundfile_backend.info` can check sph file correctly""" + duration = 1 + num_frames = sample_rate * duration + data = torch.randn(num_frames, num_channels).numpy() + path = self.get_temp_path("data.nist") + subtype, bits_per_sample = 
subtype_and_bit_depth + soundfile.write(path, data, sample_rate, subtype=subtype) + + info = soundfile_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == bits_per_sample + assert info.encoding == "PCM_S" + + def test_unknown_subtype_warning(self): + """soundfile_backend.info issues a warning when the subtype is unknown + + This will happen if a new subtype is supported in SoundFile: the _SUBTYPE_TO_BITS_PER_SAMPLE + dict should be updated. + """ + def _mock_info_func(_): + class MockSoundFileInfo: + samplerate = 8000 + frames = 356 + channels = 2 + subtype = 'UNSEEN_SUBTYPE' + format = 'UNKNOWN' + return MockSoundFileInfo() + + with patch("soundfile.info", _mock_info_func): + with warnings.catch_warnings(record=True) as w: + info = soundfile_backend.info("foo") + assert len(w) == 1 + assert "UNSEEN_SUBTYPE subtype is unknown to TorchAudio" in str(w[-1].message) + assert info.bits_per_sample == 0 + + +@skipIfNoModule("soundfile") +class TestFileObject(TempDirMixin, PytorchTestCase): + def _test_fileobj(self, ext, subtype, bits_per_sample): + """Query audio via file-like object works""" + duration = 2 + sample_rate = 16000 + num_channels = 2 + num_frames = sample_rate * duration + path = self.get_temp_path(f'test.{ext}') + + data = torch.randn(num_frames, num_channels).numpy() + soundfile.write(path, data, sample_rate, subtype=subtype) + + with open(path, 'rb') as fileobj: + info = soundfile_backend.info(fileobj) + assert info.sample_rate == sample_rate + assert info.num_frames == num_frames + assert info.num_channels == num_channels + assert info.bits_per_sample == bits_per_sample + assert info.encoding == ("FLAC" if ext == 'flac' else "PCM_S") + + def test_fileobj_wav(self): + """Query audio via file-like object works""" + self._test_fileobj('wav', 'PCM_16', 16) + + @skipIfFormatNotSupported("FLAC") + def test_fileobj_flac(self): + """Query audio via file-like object works""" + self._test_fileobj('flac', 'PCM_16', 16) + + def _test_tarobj(self, ext, subtype, bits_per_sample): + """Query compressed audio via file-like object works""" + duration = 2 + sample_rate = 16000 + num_channels = 2 + num_frames = sample_rate * duration + audio_file = f'test.{ext}' + audio_path = self.get_temp_path(audio_file) + archive_path = self.get_temp_path('archive.tar.gz') + + data = torch.randn(num_frames, num_channels).numpy() + soundfile.write(audio_path, data, sample_rate, subtype=subtype) + + with tarfile.TarFile(archive_path, 'w') as tarobj: + tarobj.add(audio_path, arcname=audio_file) + with tarfile.TarFile(archive_path, 'r') as tarobj: + fileobj = tarobj.extractfile(audio_file) + info = soundfile_backend.info(fileobj) + assert info.sample_rate == sample_rate + assert info.num_frames == num_frames + assert info.num_channels == num_channels + assert info.bits_per_sample == bits_per_sample + assert info.encoding == ("FLAC" if ext == 'flac' else "PCM_S") + + def test_tarobj_wav(self): + """Query compressed audio via file-like object works""" + self._test_tarobj('wav', 'PCM_16', 16) + + @skipIfFormatNotSupported("FLAC") + def test_tarobj_flac(self): + """Query compressed audio via file-like object works""" + self._test_tarobj('flac', 'PCM_16', 16) diff --git a/test/torchaudio_unittest/backend/soundfile/load_test.py b/test/torchaudio_unittest/backend/soundfile/load_test.py new file mode 100644 index
0000000000000000000000000000000000000000..0e3a240d26596440ce38322b3dc26164849ff9ae --- /dev/null +++ b/test/torchaudio_unittest/backend/soundfile/load_test.py @@ -0,0 +1,357 @@ +import os +import tarfile +from unittest.mock import patch + +import torch +from torchaudio._internal import module_utils as _mod_utils +from torchaudio.backend import soundfile_backend +from parameterized import parameterized + +from torchaudio_unittest.common_utils import ( + TempDirMixin, + PytorchTestCase, + skipIfNoModule, + get_wav_data, + normalize_wav, + load_wav, + save_wav, +) +from .common import ( + parameterize, + dtype2subtype, + skipIfFormatNotSupported, +) + +if _mod_utils.is_module_available("soundfile"): + import soundfile + + +def _get_mock_path( + ext: str, dtype: str, sample_rate: int, num_channels: int, num_frames: int, +): + return f"{dtype}_{sample_rate}_{num_channels}_{num_frames}.{ext}" + + +def _get_mock_params(path: str): + filename, ext = path.split(".") + parts = filename.split("_") + return { + "ext": ext, + "dtype": parts[0], + "sample_rate": int(parts[1]), + "num_channels": int(parts[2]), + "num_frames": int(parts[3]), + } + + +class SoundFileMock: + def __init__(self, path, mode): + assert mode == "r" + self.path = path + self._params = _get_mock_params(path) + self._start = None + + @property + def samplerate(self): + return self._params["sample_rate"] + + @property + def format(self): + if self._params["ext"] == "wav": + return "WAV" + if self._params["ext"] == "flac": + return "FLAC" + if self._params["ext"] == "ogg": + return "OGG" + if self._params["ext"] in ["sph", "nis", "nist"]: + return "NIST" + + @property + def subtype(self): + if self._params["ext"] == "ogg": + return "VORBIS" + return dtype2subtype(self._params["dtype"]) + + def _prepare_read(self, start, stop, frames): + assert stop is None + self._start = start + return frames + + def read(self, frames, dtype, always_2d): + assert always_2d + data = get_wav_data( + dtype, + self._params["num_channels"], + normalize=False, + num_frames=self._params["num_frames"], + channels_first=False, + ).numpy() + return data[self._start:self._start + frames] + + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + pass + + +class MockedLoadTest(PytorchTestCase): + def assert_dtype( + self, ext, dtype, sample_rate, num_channels, normalize, channels_first + ): + """When format is WAV or NIST, normalize=False will return the native dtype Tensor, otherwise float32""" + num_frames = 3 * sample_rate + path = _get_mock_path(ext, dtype, sample_rate, num_channels, num_frames) + expected_dtype = ( + torch.float32 + if normalize or ext not in ["wav", "nist"] + else getattr(torch, dtype) + ) + with patch("soundfile.SoundFile", SoundFileMock): + found, sr = soundfile_backend.load( + path, normalize=normalize, channels_first=channels_first + ) + assert found.dtype == expected_dtype + assert sample_rate == sr + + @parameterize( + ["uint8", "int16", "int32", "float32", "float64"], + [8000, 16000], + [1, 2], + [True, False], + [True, False], + ) + def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first): + """Returns native dtype when normalize=False else float32""" + self.assert_dtype( + "wav", dtype, sample_rate, num_channels, normalize, channels_first + ) + + @parameterize( + ["int8", "int16", "int32"], [8000, 16000], [1, 2], [True, False], [True, False], + ) + def test_sphere(self, dtype, sample_rate, num_channels, normalize, channels_first): + """Returns float32 always""" + self.assert_dtype( + 
"sph", dtype, sample_rate, num_channels, normalize, channels_first + ) + + @parameterize([8000, 16000], [1, 2], [True, False], [True, False]) + def test_ogg(self, sample_rate, num_channels, normalize, channels_first): + """Returns float32 always""" + self.assert_dtype( + "ogg", "int16", sample_rate, num_channels, normalize, channels_first + ) + + @parameterize([8000, 16000], [1, 2], [True, False], [True, False]) + def test_flac(self, sample_rate, num_channels, normalize, channels_first): + """`soundfile_backend.load` can load ogg format.""" + self.assert_dtype( + "flac", "int16", sample_rate, num_channels, normalize, channels_first + ) + + +class LoadTestBase(TempDirMixin, PytorchTestCase): + def assert_wav( + self, + dtype, + sample_rate, + num_channels, + normalize, + channels_first=True, + duration=1, + ): + """`soundfile_backend.load` can load wav format correctly. + + Wav data loaded with soundfile backend should match those with scipy + """ + path = self.get_temp_path("reference.wav") + num_frames = duration * sample_rate + data = get_wav_data( + dtype, + num_channels, + normalize=normalize, + num_frames=num_frames, + channels_first=channels_first, + ) + save_wav(path, data, sample_rate, channels_first=channels_first) + expected = load_wav(path, normalize=normalize, channels_first=channels_first)[0] + data, sr = soundfile_backend.load( + path, normalize=normalize, channels_first=channels_first + ) + assert sr == sample_rate + self.assertEqual(data, expected) + + def assert_sphere( + self, dtype, sample_rate, num_channels, channels_first=True, duration=1, + ): + """`soundfile_backend.load` can load SPHERE format correctly.""" + path = self.get_temp_path("reference.sph") + num_frames = duration * sample_rate + raw = get_wav_data( + dtype, + num_channels, + num_frames=num_frames, + normalize=False, + channels_first=False, + ) + soundfile.write( + path, raw, sample_rate, subtype=dtype2subtype(dtype), format="NIST" + ) + expected = normalize_wav(raw.t() if channels_first else raw) + data, sr = soundfile_backend.load(path, channels_first=channels_first) + assert sr == sample_rate + self.assertEqual(data, expected, atol=1e-4, rtol=1e-8) + + def assert_flac( + self, dtype, sample_rate, num_channels, channels_first=True, duration=1, + ): + """`soundfile_backend.load` can load FLAC format correctly.""" + path = self.get_temp_path("reference.flac") + num_frames = duration * sample_rate + raw = get_wav_data( + dtype, + num_channels, + num_frames=num_frames, + normalize=False, + channels_first=False, + ) + soundfile.write(path, raw, sample_rate) + expected = normalize_wav(raw.t() if channels_first else raw) + data, sr = soundfile_backend.load(path, channels_first=channels_first) + assert sr == sample_rate + self.assertEqual(data, expected, atol=1e-4, rtol=1e-8) + + +@skipIfNoModule("soundfile") +class TestLoad(LoadTestBase): + """Test the correctness of `soundfile_backend.load` for various formats""" + + @parameterize( + ["float32", "int32", "int16"], + [8000, 16000], + [1, 2], + [False, True], + [False, True], + ) + def test_wav(self, dtype, sample_rate, num_channels, normalize, channels_first): + """`soundfile_backend.load` can load wav format correctly.""" + self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first) + + @parameterize( + ["int16"], [16000], [2], [False], + ) + def test_wav_large(self, dtype, sample_rate, num_channels, normalize): + """`soundfile_backend.load` can load large wav file correctly.""" + two_hours = 2 * 60 * 60 + self.assert_wav(dtype, sample_rate, 
num_channels, normalize, duration=two_hours) + + @parameterize(["float32", "int32", "int16"], [4, 8, 16, 32], [False, True]) + def test_multiple_channels(self, dtype, num_channels, channels_first): + """`soundfile_backend.load` can load wav file with more than 2 channels.""" + sample_rate = 8000 + normalize = False + self.assert_wav(dtype, sample_rate, num_channels, normalize, channels_first) + + @parameterize(["int32", "int16"], [8000, 16000], [1, 2], [False, True]) + @skipIfFormatNotSupported("NIST") + def test_sphere(self, dtype, sample_rate, num_channels, channels_first): + """`soundfile_backend.load` can load sphere format correctly.""" + self.assert_sphere(dtype, sample_rate, num_channels, channels_first) + + @parameterize(["int32", "int16"], [8000, 16000], [1, 2], [False, True]) + @skipIfFormatNotSupported("FLAC") + def test_flac(self, dtype, sample_rate, num_channels, channels_first): + """`soundfile_backend.load` can load flac format correctly.""" + self.assert_flac(dtype, sample_rate, num_channels, channels_first) + + +@skipIfNoModule("soundfile") +class TestLoadFormat(TempDirMixin, PytorchTestCase): + """Given `format` parameter, `so.load` can load files without extension""" + original = None + path = None + + def _make_file(self, format_): + sample_rate = 8000 + path_with_ext = self.get_temp_path(f'test.{format_}') + data = get_wav_data('float32', num_channels=2).numpy().T + soundfile.write(path_with_ext, data, sample_rate) + expected = soundfile.read(path_with_ext, dtype='float32')[0].T + path = os.path.splitext(path_with_ext)[0] + os.rename(path_with_ext, path) + return path, expected + + def _test_format(self, format_): + """Providing format allows to read file without extension""" + path, expected = self._make_file(format_) + found, _ = soundfile_backend.load(path) + self.assertEqual(found, expected) + + @parameterized.expand([ + ('WAV', ), ('wav', ), + ]) + def test_wav(self, format_): + self._test_format(format_) + + @parameterized.expand([ + ('FLAC', ), ('flac',), + ]) + @skipIfFormatNotSupported("FLAC") + def test_flac(self, format_): + self._test_format(format_) + + +@skipIfNoModule("soundfile") +class TestFileObject(TempDirMixin, PytorchTestCase): + def _test_fileobj(self, ext): + """Loading audio via file-like object works""" + sample_rate = 16000 + path = self.get_temp_path(f'test.{ext}') + + data = get_wav_data('float32', num_channels=2).numpy().T + soundfile.write(path, data, sample_rate) + expected = soundfile.read(path, dtype='float32')[0].T + + with open(path, 'rb') as fileobj: + found, sr = soundfile_backend.load(fileobj) + assert sr == sample_rate + self.assertEqual(expected, found) + + def test_fileobj_wav(self): + """Loading audio via file-like object works""" + self._test_fileobj('wav') + + @skipIfFormatNotSupported("FLAC") + def test_fileobj_flac(self): + """Loading audio via file-like object works""" + self._test_fileobj('flac') + + def _test_tarfile(self, ext): + """Loading audio via file-like object works""" + sample_rate = 16000 + audio_file = f'test.{ext}' + audio_path = self.get_temp_path(audio_file) + archive_path = self.get_temp_path('archive.tar.gz') + + data = get_wav_data('float32', num_channels=2).numpy().T + soundfile.write(audio_path, data, sample_rate) + expected = soundfile.read(audio_path, dtype='float32')[0].T + + with tarfile.TarFile(archive_path, 'w') as tarobj: + tarobj.add(audio_path, arcname=audio_file) + with tarfile.TarFile(archive_path, 'r') as tarobj: + fileobj = tarobj.extractfile(audio_file) + found, sr = 
soundfile_backend.load(fileobj) + + assert sr == sample_rate + self.assertEqual(expected, found) + + def test_tarfile_wav(self): + """Loading audio via file-like object works""" + self._test_tarfile('wav') + + @skipIfFormatNotSupported("FLAC") + def test_tarfile_flac(self): + """Loading audio via file-like object works""" + self._test_tarfile('flac') diff --git a/test/torchaudio_unittest/backend/soundfile/save_test.py b/test/torchaudio_unittest/backend/soundfile/save_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e4e8c93631c8d9a36420132cffb54caf9c552273 --- /dev/null +++ b/test/torchaudio_unittest/backend/soundfile/save_test.py @@ -0,0 +1,295 @@ +import io +from unittest.mock import patch + +from torchaudio._internal import module_utils as _mod_utils +from torchaudio.backend import soundfile_backend + +from torchaudio_unittest.common_utils import ( + TempDirMixin, + PytorchTestCase, + skipIfNoModule, + get_wav_data, + load_wav, + nested_params, +) +from .common import ( + fetch_wav_subtype, + parameterize, + skipIfFormatNotSupported, +) + +if _mod_utils.is_module_available("soundfile"): + import soundfile + + +class MockedSaveTest(PytorchTestCase): + @nested_params( + ["float32", "int32", "int16", "uint8"], + [8000, 16000], + [1, 2], + [False, True], + [ + (None, None), + ('PCM_U', None), + ('PCM_U', 8), + ('PCM_S', None), + ('PCM_S', 16), + ('PCM_S', 32), + ('PCM_F', None), + ('PCM_F', 32), + ('PCM_F', 64), + ('ULAW', None), + ('ULAW', 8), + ('ALAW', None), + ('ALAW', 8), + ], + ) + @patch("soundfile.write") + def test_wav(self, dtype, sample_rate, num_channels, channels_first, + enc_params, mocked_write): + """soundfile_backend.save passes correct subtype to soundfile.write when WAV""" + filepath = "foo.wav" + input_tensor = get_wav_data( + dtype, + num_channels, + num_frames=3 * sample_rate, + normalize=dtype == "float32", + channels_first=channels_first, + ).t() + + encoding, bits_per_sample = enc_params + soundfile_backend.save( + filepath, input_tensor, sample_rate, channels_first=channels_first, + encoding=encoding, bits_per_sample=bits_per_sample + ) + + # on +Py3.8 call_args.kwargs is more descreptive + args = mocked_write.call_args[1] + assert args["file"] == filepath + assert args["samplerate"] == sample_rate + assert args["subtype"] == fetch_wav_subtype( + dtype, encoding, bits_per_sample) + assert args["format"] is None + self.assertEqual( + args["data"], input_tensor.t() if channels_first else input_tensor + ) + + @patch("soundfile.write") + def assert_non_wav( + self, fmt, dtype, sample_rate, num_channels, channels_first, mocked_write, + encoding=None, bits_per_sample=None, + ): + """soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE""" + filepath = f"foo.{fmt}" + input_tensor = get_wav_data( + dtype, + num_channels, + num_frames=3 * sample_rate, + normalize=False, + channels_first=channels_first, + ).t() + expected_data = input_tensor.t() if channels_first else input_tensor + + soundfile_backend.save( + filepath, input_tensor, sample_rate, channels_first, + encoding=encoding, bits_per_sample=bits_per_sample, + ) + + # on +Py3.8 call_args.kwargs is more descreptive + args = mocked_write.call_args[1] + assert args["file"] == filepath + assert args["samplerate"] == sample_rate + if fmt in ["sph", "nist", "nis"]: + assert args["format"] == "NIST" + else: + assert args["format"] is None + self.assertEqual(args["data"], expected_data) + + @nested_params( + ["sph", "nist", "nis"], + ["int32", "int16"], + [8000, 
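The mocking pattern used by `MockedSaveTest`, reduced to its core (a sketch; the path and tensor are placeholders):

import torch
from unittest.mock import patch

with patch('soundfile.write') as mocked_write:
    soundfile_backend.save('foo.wav', torch.zeros(1, 8000), 8000)
    kwargs = mocked_write.call_args[1]     # keyword arguments passed to soundfile.write
    assert kwargs['file'] == 'foo.wav'
    assert kwargs['samplerate'] == 8000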
16000], + [1, 2], + [False, True], + [ + ('PCM_S', 8), + ('PCM_S', 16), + ('PCM_S', 24), + ('PCM_S', 32), + ('ULAW', 8), + ('ALAW', 8), + ('ALAW', 16), + ('ALAW', 24), + ('ALAW', 32), + ], + ) + def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first, enc_params): + """soundfile_backend.save passes default format and subtype (None-s) to + soundfile.write when not WAV""" + encoding, bits_per_sample = enc_params + self.assert_non_wav(fmt, dtype, sample_rate, num_channels, + channels_first, encoding=encoding, + bits_per_sample=bits_per_sample) + + @parameterize( + ["int32", "int16"], [8000, 16000], [1, 2], [False, True], + [8, 16, 24], + ) + def test_flac(self, dtype, sample_rate, num_channels, + channels_first, bits_per_sample): + """soundfile_backend.save passes default format and subtype (None-s) to + soundfile.write when not WAV""" + self.assert_non_wav("flac", dtype, sample_rate, num_channels, + channels_first, bits_per_sample=bits_per_sample) + + @parameterize( + ["int32", "int16"], [8000, 16000], [1, 2], [False, True], + ) + def test_ogg(self, dtype, sample_rate, num_channels, channels_first): + """soundfile_backend.save passes default format and subtype (None-s) to + soundfile.write when not WAV""" + self.assert_non_wav("ogg", dtype, sample_rate, num_channels, channels_first) + + +@skipIfNoModule("soundfile") +class SaveTestBase(TempDirMixin, PytorchTestCase): + def assert_wav(self, dtype, sample_rate, num_channels, num_frames): + """`soundfile_backend.save` can save wav format.""" + path = self.get_temp_path("data.wav") + expected = get_wav_data( + dtype, num_channels, num_frames=num_frames, normalize=False + ) + soundfile_backend.save(path, expected, sample_rate) + found, sr = load_wav(path, normalize=False) + assert sample_rate == sr + self.assertEqual(found, expected) + + def _assert_non_wav(self, fmt, dtype, sample_rate, num_channels): + """`soundfile_backend.save` can save non-wav format. + + Due to precision missmatch, and the lack of alternative way to decode the + resulting files without using soundfile, only meta data are validated. + """ + num_frames = sample_rate * 3 + path = self.get_temp_path(f"data.{fmt}") + expected = get_wav_data( + dtype, num_channels, num_frames=num_frames, normalize=False + ) + soundfile_backend.save(path, expected, sample_rate) + sinfo = soundfile.info(path) + assert sinfo.format == fmt.upper() + assert sinfo.frames == num_frames + assert sinfo.channels == num_channels + assert sinfo.samplerate == sample_rate + + def assert_flac(self, dtype, sample_rate, num_channels): + """`soundfile_backend.save` can save flac format.""" + self._assert_non_wav("flac", dtype, sample_rate, num_channels) + + def assert_sphere(self, dtype, sample_rate, num_channels): + """`soundfile_backend.save` can save sph format.""" + self._assert_non_wav("nist", dtype, sample_rate, num_channels) + + def assert_ogg(self, dtype, sample_rate, num_channels): + """`soundfile_backend.save` can save ogg format. + + As we cannot inspect the OGG format (it's lossy), we only check the metadata. 
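A sketch of the metadata-only validation used for formats that cannot be compared sample-by-sample (`path` and the expected values are placeholders):

sinfo = soundfile.info(path)
assert sinfo.format == 'FLAC'
assert sinfo.samplerate == 8000
assert sinfo.channels == 2
assert sinfo.frames == 3 * 8000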
+ """ + self._assert_non_wav("ogg", dtype, sample_rate, num_channels) + + +@skipIfNoModule("soundfile") +class TestSave(SaveTestBase): + @parameterize( + ["float32", "int32", "int16"], [8000, 16000], [1, 2], + ) + def test_wav(self, dtype, sample_rate, num_channels): + """`soundfile_backend.save` can save wav format.""" + self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) + + @parameterize( + ["float32", "int32", "int16"], [4, 8, 16, 32], + ) + def test_multiple_channels(self, dtype, num_channels): + """`soundfile_backend.save` can save wav with more than 2 channels.""" + sample_rate = 8000 + self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) + + @parameterize( + ["int32", "int16"], [8000, 16000], [1, 2], + ) + @skipIfFormatNotSupported("NIST") + def test_sphere(self, dtype, sample_rate, num_channels): + """`soundfile_backend.save` can save sph format.""" + self.assert_sphere(dtype, sample_rate, num_channels) + + @parameterize( + [8000, 16000], [1, 2], + ) + @skipIfFormatNotSupported("FLAC") + def test_flac(self, sample_rate, num_channels): + """`soundfile_backend.save` can save flac format.""" + self.assert_flac("float32", sample_rate, num_channels) + + @parameterize( + [8000, 16000], [1, 2], + ) + @skipIfFormatNotSupported("OGG") + def test_ogg(self, sample_rate, num_channels): + """`soundfile_backend.save` can save ogg/vorbis format.""" + self.assert_ogg("float32", sample_rate, num_channels) + + +@skipIfNoModule("soundfile") +class TestSaveParams(TempDirMixin, PytorchTestCase): + """Test the correctness of optional parameters of `soundfile_backend.save`""" + + @parameterize([True, False]) + def test_channels_first(self, channels_first): + """channels_first swaps axes""" + path = self.get_temp_path("data.wav") + data = get_wav_data("int32", 2, channels_first=channels_first) + soundfile_backend.save(path, data, 8000, channels_first=channels_first) + found = load_wav(path)[0] + expected = data if channels_first else data.transpose(1, 0) + self.assertEqual(found, expected, atol=1e-4, rtol=1e-8) + + +@skipIfNoModule("soundfile") +class TestFileObject(TempDirMixin, PytorchTestCase): + def _test_fileobj(self, ext): + """Saving audio to file-like object works""" + sample_rate = 16000 + path = self.get_temp_path(f'test.{ext}') + + subtype = 'FLOAT' if ext == 'wav' else None + data = get_wav_data('float32', num_channels=2) + soundfile.write(path, data.numpy().T, sample_rate, subtype=subtype) + expected = soundfile.read(path, dtype='float32')[0] + + fileobj = io.BytesIO() + soundfile_backend.save(fileobj, data, sample_rate, format=ext) + fileobj.seek(0) + found, sr = soundfile.read(fileobj, dtype='float32') + + assert sr == sample_rate + self.assertEqual(expected, found, atol=1e-4, rtol=1e-8) + + def test_fileobj_wav(self): + """Saving audio via file-like object works""" + self._test_fileobj('wav') + + @skipIfFormatNotSupported("FLAC") + def test_fileobj_flac(self): + """Saving audio via file-like object works""" + self._test_fileobj('flac') + + @skipIfFormatNotSupported("NIST") + def test_fileobj_nist(self): + """Saving audio via file-like object works""" + self._test_fileobj('NIST') + + @skipIfFormatNotSupported("OGG") + def test_fileobj_ogg(self): + """Saving audio via file-like object works""" + self._test_fileobj('OGG') diff --git a/test/torchaudio_unittest/backend/sox_io/__init__.py b/test/torchaudio_unittest/backend/sox_io/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/test/torchaudio_unittest/backend/sox_io/common.py b/test/torchaudio_unittest/backend/sox_io/common.py new file mode 100644 index 0000000000000000000000000000000000000000..c2538b2bc41d1c3c67356ea83e5238d45f37f4f2 --- /dev/null +++ b/test/torchaudio_unittest/backend/sox_io/common.py @@ -0,0 +1,14 @@ +def name_func(func, _, params): + return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}' + + +def get_enc_params(dtype): + if dtype == 'float32': + return 'PCM_F', 32 + if dtype == 'int32': + return 'PCM_S', 32 + if dtype == 'int16': + return 'PCM_S', 16 + if dtype == 'uint8': + return 'PCM_U', 8 + raise ValueError(f'Unexpected dtype: {dtype}') diff --git a/test/torchaudio_unittest/backend/sox_io/info_test.py b/test/torchaudio_unittest/backend/sox_io/info_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0405d1ba9cee13076f151526aa4cf34c83d7c195 --- /dev/null +++ b/test/torchaudio_unittest/backend/sox_io/info_test.py @@ -0,0 +1,537 @@ +from contextlib import contextmanager +import io +import os +import itertools +import tarfile + +from parameterized import parameterized +from torchaudio.backend import sox_io_backend +from torchaudio.utils.sox_utils import get_buffer_size, set_buffer_size +from torchaudio._internal import module_utils as _mod_utils + +from torchaudio_unittest.backend.common import ( + get_bits_per_sample, + get_encoding, +) +from torchaudio_unittest.common_utils import ( + TempDirMixin, + HttpServerMixin, + PytorchTestCase, + skipIfNoExec, + skipIfNoModule, + skipIfNoSox, + get_asset_path, + get_wav_data, + save_wav, + sox_utils, +) +from .common import ( + name_func, +) + + +if _mod_utils.is_module_available("requests"): + import requests + + +@skipIfNoExec('sox') +@skipIfNoSox +class TestInfo(TempDirMixin, PytorchTestCase): + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2], + )), name_func=name_func) + def test_wav(self, dtype, sample_rate, num_channels): + """`sox_io_backend.info` can check wav file correctly""" + duration = 1 + path = self.get_temp_path('data.wav') + data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) + save_wav(path, data, sample_rate) + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == sox_utils.get_bit_depth(dtype) + assert info.encoding == get_encoding('wav', dtype) + + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [4, 8, 16, 32], + )), name_func=name_func) + def test_wav_multiple_channels(self, dtype, sample_rate, num_channels): + """`sox_io_backend.info` can check wav file with channels more than 2 correctly""" + duration = 1 + path = self.get_temp_path('data.wav') + data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) + save_wav(path, data, sample_rate) + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == sox_utils.get_bit_depth(dtype) + assert info.encoding == get_encoding('wav', dtype) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + [96, 128, 160, 192, 224, 256, 320], + )), name_func=name_func) + def test_mp3(self, sample_rate, num_channels, bit_rate): + """`sox_io_backend.info` 
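Typical use of the `get_enc_params` helper defined above, shown as a sketch (`path` and `data` are placeholders):

enc, bps = get_enc_params('int16')     # -> ('PCM_S', 16)
sox_io_backend.save(path, data, 8000, encoding=enc, bits_per_sample=bps)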
can check mp3 file correctly""" + duration = 1 + path = self.get_temp_path('data.mp3') + sox_utils.gen_audio_file( + path, sample_rate, num_channels, + compression=bit_rate, duration=duration, + ) + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + # mp3 does not preserve the number of samples + # assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats + assert info.encoding == "MP3" + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + list(range(9)), + )), name_func=name_func) + def test_flac(self, sample_rate, num_channels, compression_level): + """`sox_io_backend.info` can check flac file correctly""" + duration = 1 + path = self.get_temp_path('data.flac') + sox_utils.gen_audio_file( + path, sample_rate, num_channels, + compression=compression_level, duration=duration, + ) + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == 24 # FLAC standard + assert info.encoding == "FLAC" + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + [-1, 0, 1, 2, 3, 3.6, 5, 10], + )), name_func=name_func) + def test_vorbis(self, sample_rate, num_channels, quality_level): + """`sox_io_backend.info` can check vorbis file correctly""" + duration = 1 + path = self.get_temp_path('data.vorbis') + sox_utils.gen_audio_file( + path, sample_rate, num_channels, + compression=quality_level, duration=duration, + ) + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats + assert info.encoding == "VORBIS" + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + [16, 32], + )), name_func=name_func) + def test_sphere(self, sample_rate, num_channels, bits_per_sample): + """`sox_io_backend.info` can check sph file correctly""" + duration = 1 + path = self.get_temp_path('data.sph') + sox_utils.gen_audio_file( + path, sample_rate, num_channels, duration=duration, + bit_depth=bits_per_sample) + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == bits_per_sample + assert info.encoding == "PCM_S" + + @parameterized.expand(list(itertools.product( + ['int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2], + )), name_func=name_func) + def test_amb(self, dtype, sample_rate, num_channels): + """`sox_io_backend.info` can check amb file correctly""" + duration = 1 + path = self.get_temp_path('data.amb') + bits_per_sample = sox_utils.get_bit_depth(dtype) + sox_utils.gen_audio_file( + path, sample_rate, num_channels, + bit_depth=bits_per_sample, duration=duration) + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == bits_per_sample + assert info.encoding == get_encoding("amb", dtype) + + def test_amr_nb(self): + """`sox_io_backend.info` can check amr-nb file correctly""" + duration = 1 + num_channels = 1 + sample_rate = 8000 + path = self.get_temp_path('data.amr-nb') + 
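The convention these assertions rely on, in isolation (a sketch; `path` points to an mp3 file):

info = sox_io_backend.info(path)
assert info.encoding == 'MP3'
assert info.bits_per_sample == 0   # bit depth is not meaningful for compressed formats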
sox_utils.gen_audio_file( + path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, + duration=duration) + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == 0 + assert info.encoding == "AMR_NB" + + def test_ulaw(self): + """`sox_io_backend.info` can check ulaw file correctly""" + duration = 1 + num_channels = 1 + sample_rate = 8000 + path = self.get_temp_path('data.wav') + sox_utils.gen_audio_file( + path, sample_rate=sample_rate, num_channels=num_channels, + bit_depth=8, encoding='u-law', + duration=duration) + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == 8 + assert info.encoding == "ULAW" + + def test_alaw(self): + """`sox_io_backend.info` can check alaw file correctly""" + duration = 1 + num_channels = 1 + sample_rate = 8000 + path = self.get_temp_path('data.wav') + sox_utils.gen_audio_file( + path, sample_rate=sample_rate, num_channels=num_channels, + bit_depth=8, encoding='a-law', + duration=duration) + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == 8 + assert info.encoding == "ALAW" + + def test_gsm(self): + """`sox_io_backend.info` can check gsm file correctly""" + duration = 1 + num_channels = 1 + sample_rate = 8000 + path = self.get_temp_path('data.gsm') + sox_utils.gen_audio_file( + path, sample_rate=sample_rate, num_channels=num_channels, + duration=duration) + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_channels == num_channels + assert info.bits_per_sample == 0 + assert info.encoding == "GSM" + + def test_htk(self): + """`sox_io_backend.info` can check HTK file correctly""" + duration = 1 + num_channels = 1 + sample_rate = 8000 + path = self.get_temp_path('data.htk') + sox_utils.gen_audio_file( + path, sample_rate=sample_rate, num_channels=num_channels, + bit_depth=16, duration=duration) + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == 16 + assert info.encoding == "PCM_S" + + +@skipIfNoSox +class TestInfoOpus(PytorchTestCase): + @parameterized.expand(list(itertools.product( + ['96k'], + [1, 2], + [0, 5, 10], + )), name_func=name_func) + def test_opus(self, bitrate, num_channels, compression_level): + """`sox_io_backend.info` can check opus file correcty""" + path = get_asset_path('io', f'{bitrate}_{compression_level}_{num_channels}ch.opus') + info = sox_io_backend.info(path) + assert info.sample_rate == 48000 + assert info.num_frames == 32768 + assert info.num_channels == num_channels + assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats + assert info.encoding == "OPUS" + + +@skipIfNoSox +class TestLoadWithoutExtension(PytorchTestCase): + def test_mp3(self): + """Providing `format` allows to read mp3 without extension + + libsox does not check header for mp3 + + https://github.com/pytorch/audio/issues/1040 + + The file was generated with the following command + ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext + """ + path = 
get_asset_path("mp3_without_ext") + sinfo = sox_io_backend.info(path, format="mp3") + assert sinfo.sample_rate == 16000 + assert sinfo.num_frames == 81216 + assert sinfo.num_channels == 1 + assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats + assert sinfo.encoding == "MP3" + + +class FileObjTestBase(TempDirMixin): + def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None): + path = self.get_temp_path(f'test.{ext}') + bit_depth = sox_utils.get_bit_depth(dtype) + duration = num_frames / sample_rate + comment_file = self._gen_comment_file(comments) if comments else None + + sox_utils.gen_audio_file( + path, sample_rate, num_channels=num_channels, + encoding=sox_utils.get_encoding(dtype), + bit_depth=bit_depth, + duration=duration, + comment_file=comment_file, + ) + return path + + def _gen_comment_file(self, comments): + comment_path = self.get_temp_path("comment.txt") + with open(comment_path, "w") as file_: + file_.writelines(comments) + return comment_path + + +@skipIfNoSox +@skipIfNoExec('sox') +class TestFileObject(FileObjTestBase, PytorchTestCase): + def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None): + path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames, comments=comments) + format_ = ext if ext in ['mp3'] else None + with open(path, 'rb') as fileobj: + return sox_io_backend.info(fileobj, format_) + + def _query_bytesio(self, ext, dtype, sample_rate, num_channels, num_frames): + path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames) + format_ = ext if ext in ['mp3'] else None + with open(path, 'rb') as file_: + fileobj = io.BytesIO(file_.read()) + return sox_io_backend.info(fileobj, format_) + + def _query_tarfile(self, ext, dtype, sample_rate, num_channels, num_frames): + audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames) + audio_file = os.path.basename(audio_path) + archive_path = self.get_temp_path('archive.tar.gz') + with tarfile.TarFile(archive_path, 'w') as tarobj: + tarobj.add(audio_path, arcname=audio_file) + format_ = ext if ext in ['mp3'] else None + with tarfile.TarFile(archive_path, 'r') as tarobj: + fileobj = tarobj.extractfile(audio_file) + return sox_io_backend.info(fileobj, format_) + + @contextmanager + def _set_buffer_size(self, buffer_size): + try: + original_buffer_size = get_buffer_size() + set_buffer_size(buffer_size) + yield + finally: + set_buffer_size(original_buffer_size) + + @parameterized.expand([ + ('wav', "float32"), + ('wav', "int32"), + ('wav', "int16"), + ('wav', "uint8"), + ('mp3', "float32"), + ('flac', "float32"), + ('vorbis', "float32"), + ('amb', "int16"), + ]) + def test_fileobj(self, ext, dtype): + """Querying audio via file object works""" + sample_rate = 16000 + num_frames = 3 * sample_rate + num_channels = 2 + sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames) + + bits_per_sample = get_bits_per_sample(ext, dtype) + num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames + + assert sinfo.sample_rate == sample_rate + assert sinfo.num_channels == num_channels + assert sinfo.num_frames == num_frames + assert sinfo.bits_per_sample == bits_per_sample + assert sinfo.encoding == get_encoding(ext, dtype) + + @parameterized.expand([ + ('vorbis', "float32"), + ]) + def test_fileobj_large_header(self, ext, dtype): + """ + For audio file with header size exceeding default buffer size: + - Querying audio via file object without enlarging buffer size 
fails. + - Querying audio via file object after enlarging buffer size succeeds. + """ + sample_rate = 16000 + num_frames = 3 * sample_rate + num_channels = 2 + comments = "metadata=" + " ".join(["value" for _ in range(1000)]) + + with self.assertRaisesRegex(RuntimeError, "^Error loading audio file:"): + sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments) + + with self._set_buffer_size(16384): + sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments) + bits_per_sample = get_bits_per_sample(ext, dtype) + num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames + + assert sinfo.sample_rate == sample_rate + assert sinfo.num_channels == num_channels + assert sinfo.num_frames == num_frames + assert sinfo.bits_per_sample == bits_per_sample + assert sinfo.encoding == get_encoding(ext, dtype) + + @parameterized.expand([ + ('wav', "float32"), + ('wav', "int32"), + ('wav', "int16"), + ('wav', "uint8"), + ('mp3', "float32"), + ('flac', "float32"), + ('vorbis', "float32"), + ('amb', "int16"), + ]) + def test_bytesio(self, ext, dtype): + """Querying audio via ByteIO object works for small data""" + sample_rate = 16000 + num_frames = 3 * sample_rate + num_channels = 2 + sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames) + + bits_per_sample = get_bits_per_sample(ext, dtype) + num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames + + assert sinfo.sample_rate == sample_rate + assert sinfo.num_channels == num_channels + assert sinfo.num_frames == num_frames + assert sinfo.bits_per_sample == bits_per_sample + assert sinfo.encoding == get_encoding(ext, dtype) + + @parameterized.expand([ + ('wav', "float32"), + ('wav', "int32"), + ('wav', "int16"), + ('wav', "uint8"), + ('mp3', "float32"), + ('flac', "float32"), + ('vorbis', "float32"), + ('amb', "int16"), + ]) + def test_bytesio_tiny(self, ext, dtype): + """Querying audio via ByteIO object works for small data""" + sample_rate = 8000 + num_frames = 4 + num_channels = 2 + sinfo = self._query_bytesio(ext, dtype, sample_rate, num_channels, num_frames) + + bits_per_sample = get_bits_per_sample(ext, dtype) + num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames + + assert sinfo.sample_rate == sample_rate + assert sinfo.num_channels == num_channels + assert sinfo.num_frames == num_frames + assert sinfo.bits_per_sample == bits_per_sample + assert sinfo.encoding == get_encoding(ext, dtype) + + @parameterized.expand([ + ('wav', "float32"), + ('wav', "int32"), + ('wav', "int16"), + ('wav', "uint8"), + ('mp3', "float32"), + ('flac', "float32"), + ('vorbis', "float32"), + ('amb', "int16"), + ]) + def test_tarfile(self, ext, dtype): + """Querying compressed audio via file-like object works""" + sample_rate = 16000 + num_frames = 3.0 * sample_rate + num_channels = 2 + sinfo = self._query_tarfile(ext, dtype, sample_rate, num_channels, num_frames) + + bits_per_sample = get_bits_per_sample(ext, dtype) + num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames + + assert sinfo.sample_rate == sample_rate + assert sinfo.num_channels == num_channels + assert sinfo.num_frames == num_frames + assert sinfo.bits_per_sample == bits_per_sample + assert sinfo.encoding == get_encoding(ext, dtype) + + +@skipIfNoSox +@skipIfNoExec('sox') +@skipIfNoModule("requests") +class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase): + def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames): + audio_path = self._gen_file(ext, dtype, 
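The buffer-size workaround exercised above, outside the test harness (a sketch; 16384 matches the value used in the test, and `fileobj` is a vorbis stream with an oversized header):

original = get_buffer_size()
set_buffer_size(16384)              # enlarge the probe buffer
try:
    sinfo = sox_io_backend.info(fileobj)
finally:
    set_buffer_size(original)       # always restore the previous value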
sample_rate, num_channels, num_frames) + audio_file = os.path.basename(audio_path) + + url = self.get_url(audio_file) + format_ = ext if ext in ['mp3'] else None + with requests.get(url, stream=True) as resp: + return sox_io_backend.info(resp.raw, format=format_) + + @parameterized.expand([ + ('wav', "float32"), + ('wav', "int32"), + ('wav', "int16"), + ('wav', "uint8"), + ('mp3', "float32"), + ('flac', "float32"), + ('vorbis', "float32"), + ('amb', "int16"), + ]) + def test_requests(self, ext, dtype): + """Querying compressed audio via requests works""" + sample_rate = 16000 + num_frames = 3.0 * sample_rate + num_channels = 2 + sinfo = self._query_http(ext, dtype, sample_rate, num_channels, num_frames) + + bits_per_sample = get_bits_per_sample(ext, dtype) + num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames + + assert sinfo.sample_rate == sample_rate + assert sinfo.num_channels == num_channels + assert sinfo.num_frames == num_frames + assert sinfo.bits_per_sample == bits_per_sample + assert sinfo.encoding == get_encoding(ext, dtype) + + +@skipIfNoSox +class TestInfoNoSuchFile(PytorchTestCase): + def test_info_fail(self): + """ + When attempted to get info on a non-existing file, error message must contain the file path. + """ + path = "non_existing_audio.wav" + with self.assertRaisesRegex(RuntimeError, "^Error loading audio file: failed to open file {0}$".format(path)): + sox_io_backend.info(path) diff --git a/test/torchaudio_unittest/backend/sox_io/load_test.py b/test/torchaudio_unittest/backend/sox_io/load_test.py new file mode 100644 index 0000000000000000000000000000000000000000..824d012cfd3298ff462d0c30c85b6e8cce892e4c --- /dev/null +++ b/test/torchaudio_unittest/backend/sox_io/load_test.py @@ -0,0 +1,535 @@ +import io +import itertools +import tarfile + +from parameterized import parameterized +from torchaudio.backend import sox_io_backend +from torchaudio._internal import module_utils as _mod_utils + +from torchaudio_unittest.common_utils import ( + TempDirMixin, + HttpServerMixin, + PytorchTestCase, + skipIfNoExec, + skipIfNoModule, + skipIfNoSox, + get_asset_path, + get_wav_data, + load_wav, + save_wav, + sox_utils, +) +from .common import ( + name_func, +) + + +if _mod_utils.is_module_available("requests"): + import requests + + +class LoadTestBase(TempDirMixin, PytorchTestCase): + def assert_format( + self, + format: str, + sample_rate: float, + num_channels: int, + compression: float = None, + bit_depth: int = None, + duration: float = 1, + normalize: bool = True, + encoding: str = None, + atol: float = 4e-05, + rtol: float = 1.3e-06, + ): + """`sox_io_backend.load` can load given format correctly. + + file encodings introduce delay and boundary effects so + we create a reference wav file from the original file format + + x + | + | 1. Generate given format with Sox + | + v 2. Convert to wav with Sox + given format ----------------------> wav + | | + | 3. Load with torchaudio | 4. Load with scipy + | | + v v + tensor ----------> x <----------- tensor + 5. Compare + + Underlying assumptions are; + i. Conversion of given format to wav with Sox preserves data. + ii. Loading wav file with scipy is correct. + + By combining i & ii, step 2. and 4. allows to load reference given format + data without using torchaudio + """ + + path = self.get_temp_path(f'1.original.{format}') + ref_path = self.get_temp_path('2.reference.wav') + + # 1. 
Generate the given format with sox + sox_utils.gen_audio_file( + path, sample_rate, num_channels, encoding=encoding, + compression=compression, bit_depth=bit_depth, duration=duration, + ) + # 2. Convert to wav with sox + wav_bit_depth = 32 if bit_depth == 24 else None # for 24-bit wav + sox_utils.convert_audio_file(path, ref_path, bit_depth=wav_bit_depth) + # 3. Load the given format with torchaudio + data, sr = sox_io_backend.load(path, normalize=normalize) + # 4. Load wav with scipy + data_ref = load_wav(ref_path, normalize=normalize)[0] + # 5. Compare + assert sr == sample_rate + self.assertEqual(data, data_ref, atol=atol, rtol=rtol) + + def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration): + """`sox_io_backend.load` can load wav format correctly. + + Wav data loaded with sox_io backend should match those with scipy + """ + path = self.get_temp_path('reference.wav') + data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate) + save_wav(path, data, sample_rate) + expected = load_wav(path, normalize=normalize)[0] + data, sr = sox_io_backend.load(path, normalize=normalize) + assert sr == sample_rate + self.assertEqual(data, expected) + + +@skipIfNoExec('sox') +@skipIfNoSox +class TestLoad(LoadTestBase): + """Test the correctness of `sox_io_backend.load` for various formats""" + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2], + [False, True], + )), name_func=name_func) + def test_wav(self, dtype, sample_rate, num_channels, normalize): + """`sox_io_backend.load` can load wav format correctly.""" + self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + [False, True], + )), name_func=name_func) + def test_24bit_wav(self, sample_rate, num_channels, normalize): + """`sox_io_backend.load` can load 24bit wav format correctly. 
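Because torch has no 24-bit integer dtype, 24-bit samples are widened on load, as the surrounding test's docstring states (a sketch; `path` is a 24-bit PCM wav file):

import torch

waveform, sr = sox_io_backend.load(path, normalize=False)
assert waveform.dtype == torch.int32   # 24-bit PCM is returned as int32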
Corectly casts it to ``int32`` tensor dtype.""" + self.assert_format("wav", sample_rate, num_channels, bit_depth=24, normalize=normalize, duration=1) + + @parameterized.expand(list(itertools.product( + ['int16'], + [16000], + [2], + [False], + )), name_func=name_func) + def test_wav_large(self, dtype, sample_rate, num_channels, normalize): + """`sox_io_backend.load` can load large wav file correctly.""" + two_hours = 2 * 60 * 60 + self.assert_wav(dtype, sample_rate, num_channels, normalize, two_hours) + + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [4, 8, 16, 32], + )), name_func=name_func) + def test_multiple_channels(self, dtype, num_channels): + """`sox_io_backend.load` can load wav file with more than 2 channels.""" + sample_rate = 8000 + normalize = False + self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1) + + @parameterized.expand(list(itertools.product( + [8000, 16000, 44100], + [1, 2], + [96, 128, 160, 192, 224, 256, 320], + )), name_func=name_func) + def test_mp3(self, sample_rate, num_channels, bit_rate): + """`sox_io_backend.load` can load mp3 format correctly.""" + self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=1, atol=5e-05) + + @parameterized.expand(list(itertools.product( + [16000], + [2], + [128], + )), name_func=name_func) + def test_mp3_large(self, sample_rate, num_channels, bit_rate): + """`sox_io_backend.load` can load large mp3 file correctly.""" + two_hours = 2 * 60 * 60 + self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=two_hours, atol=5e-05) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + list(range(9)), + )), name_func=name_func) + def test_flac(self, sample_rate, num_channels, compression_level): + """`sox_io_backend.load` can load flac format correctly.""" + self.assert_format("flac", sample_rate, num_channels, compression=compression_level, bit_depth=16, duration=1) + + @parameterized.expand(list(itertools.product( + [16000], + [2], + [0], + )), name_func=name_func) + def test_flac_large(self, sample_rate, num_channels, compression_level): + """`sox_io_backend.load` can load large flac file correctly.""" + two_hours = 2 * 60 * 60 + self.assert_format( + "flac", sample_rate, num_channels, compression=compression_level, bit_depth=16, duration=two_hours) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + [-1, 0, 1, 2, 3, 3.6, 5, 10], + )), name_func=name_func) + def test_vorbis(self, sample_rate, num_channels, quality_level): + """`sox_io_backend.load` can load vorbis format correctly.""" + self.assert_format("vorbis", sample_rate, num_channels, compression=quality_level, bit_depth=16, duration=1) + + @parameterized.expand(list(itertools.product( + [16000], + [2], + [10], + )), name_func=name_func) + def test_vorbis_large(self, sample_rate, num_channels, quality_level): + """`sox_io_backend.load` can load large vorbis file correctly.""" + two_hours = 2 * 60 * 60 + self.assert_format( + "vorbis", sample_rate, num_channels, compression=quality_level, bit_depth=16, duration=two_hours) + + @parameterized.expand(list(itertools.product( + ['96k'], + [1, 2], + [0, 5, 10], + )), name_func=name_func) + def test_opus(self, bitrate, num_channels, compression_level): + """`sox_io_backend.load` can load opus file correctly.""" + ops_path = get_asset_path('io', f'{bitrate}_{compression_level}_{num_channels}ch.opus') + wav_path = 
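`frame_offset` and `num_frames` select a region of the decoded signal, equivalent to loading everything and slicing, which is what `TestLoadParams` below verifies (a sketch; `path` is a placeholder):

full, sr = sox_io_backend.load(path)
part, _ = sox_io_backend.load(path, frame_offset=100, num_frames=50)
# part matches full[:, 100:150]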
self.get_temp_path(f'{bitrate}_{compression_level}_{num_channels}ch.opus.wav') + sox_utils.convert_audio_file(ops_path, wav_path) + + expected, sample_rate = load_wav(wav_path) + found, sr = sox_io_backend.load(ops_path) + + assert sample_rate == sr + self.assertEqual(expected, found) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + )), name_func=name_func) + def test_sphere(self, sample_rate, num_channels): + """`sox_io_backend.load` can load sph format correctly.""" + self.assert_format("sph", sample_rate, num_channels, bit_depth=32, duration=1) + + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16'], + [8000, 16000], + [1, 2], + [False, True], + )), name_func=name_func) + def test_amb(self, dtype, sample_rate, num_channels, normalize): + """`sox_io_backend.load` can load amb format correctly.""" + bit_depth = sox_utils.get_bit_depth(dtype) + encoding = sox_utils.get_encoding(dtype) + self.assert_format( + "amb", sample_rate, num_channels, bit_depth=bit_depth, duration=1, encoding=encoding, normalize=normalize) + + def test_amr_nb(self): + """`sox_io_backend.load` can load amr_nb format correctly.""" + self.assert_format("amr-nb", sample_rate=8000, num_channels=1, bit_depth=32, duration=1) + + +@skipIfNoExec('sox') +@skipIfNoSox +class TestLoadParams(TempDirMixin, PytorchTestCase): + """Test the correctness of frame parameters of `sox_io_backend.load`""" + original = None + path = None + + def setUp(self): + super().setUp() + sample_rate = 8000 + self.original = get_wav_data('float32', num_channels=2) + self.path = self.get_temp_path('test.wav') + save_wav(self.path, self.original, sample_rate) + + @parameterized.expand(list(itertools.product( + [0, 1, 10, 100, 1000], + [-1, 1, 10, 100, 1000], + )), name_func=name_func) + def test_frame(self, frame_offset, num_frames): + """num_frames and frame_offset correctly specify the region of data""" + found, _ = sox_io_backend.load(self.path, frame_offset, num_frames) + frame_end = None if num_frames == -1 else frame_offset + num_frames + self.assertEqual(found, self.original[:, frame_offset:frame_end]) + + @parameterized.expand([(True, ), (False, )], name_func=name_func) + def test_channels_first(self, channels_first): + """channels_first swaps axes""" + found, _ = sox_io_backend.load(self.path, channels_first=channels_first) + expected = self.original if channels_first else self.original.transpose(1, 0) + self.assertEqual(found, expected) + + +@skipIfNoSox +class TestLoadWithoutExtension(PytorchTestCase): + def test_mp3(self): + """Providing format allows to read mp3 without extension + + libsox does not check header for mp3 + + https://github.com/pytorch/audio/issues/1040 + + The file was generated with the following command + ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext + """ + path = get_asset_path("mp3_without_ext") + _, sr = sox_io_backend.load(path, format="mp3") + assert sr == 16000 + + +class CloggedFileObj: + def __init__(self, fileobj): + self.fileobj = fileobj + self.buffer = b'' + + def read(self, n): + if not self.buffer: + self.buffer += self.fileobj.read(n) + ret = self.buffer[:2] + self.buffer = self.buffer[2:] + return ret + + +@skipIfNoSox +@skipIfNoExec('sox') +class TestFileObject(TempDirMixin, PytorchTestCase): + """ + In this test suite, the result of file-like object input is compared against file path input, + because `load` function is rigrously tested for file path inputs to match libsox's result, + """ + @parameterized.expand([ 
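What `CloggedFileObj` above does, shown concretely (a sketch): every `read` call returns at most two bytes, regardless of how many were requested, simulating a stream that yields less data than asked for.

clogged = CloggedFileObj(io.BytesIO(b'abcdef'))
clogged.read(4)   # returns b'ab'
clogged.read(4)   # returns b'cd'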
+ ('wav', None), + ('mp3', 128), + ('mp3', 320), + ('flac', 0), + ('flac', 5), + ('flac', 8), + ('vorbis', -1), + ('vorbis', 10), + ('amb', None), + ]) + def test_fileobj(self, ext, compression): + """Loading audio via file object returns the same result as via file path.""" + sample_rate = 16000 + format_ = ext if ext in ['mp3'] else None + path = self.get_temp_path(f'test.{ext}') + + sox_utils.gen_audio_file( + path, sample_rate, num_channels=2, + compression=compression) + expected, _ = sox_io_backend.load(path) + + with open(path, 'rb') as fileobj: + found, sr = sox_io_backend.load(fileobj, format=format_) + + assert sr == sample_rate + self.assertEqual(expected, found) + + @parameterized.expand([ + ('wav', None), + ('mp3', 128), + ('mp3', 320), + ('flac', 0), + ('flac', 5), + ('flac', 8), + ('vorbis', -1), + ('vorbis', 10), + ('amb', None), + ]) + def test_bytesio(self, ext, compression): + """Loading audio via BytesIO object returns the same result as via file path.""" + sample_rate = 16000 + format_ = ext if ext in ['mp3'] else None + path = self.get_temp_path(f'test.{ext}') + + sox_utils.gen_audio_file( + path, sample_rate, num_channels=2, + compression=compression) + expected, _ = sox_io_backend.load(path) + + with open(path, 'rb') as file_: + fileobj = io.BytesIO(file_.read()) + found, sr = sox_io_backend.load(fileobj, format=format_) + + assert sr == sample_rate + self.assertEqual(expected, found) + + @parameterized.expand([ + ('wav', None), + ('mp3', 128), + ('mp3', 320), + ('flac', 0), + ('flac', 5), + ('flac', 8), + ('vorbis', -1), + ('vorbis', 10), + ('amb', None), + ]) + def test_bytesio_clogged(self, ext, compression): + """Loading audio via clogged file object returns the same result as via file path. + + This test case validates the case where fileobject returns shorter bytes than requeted. + """ + sample_rate = 16000 + format_ = ext if ext in ['mp3'] else None + path = self.get_temp_path(f'test.{ext}') + + sox_utils.gen_audio_file( + path, sample_rate, num_channels=2, + compression=compression) + expected, _ = sox_io_backend.load(path) + + with open(path, 'rb') as file_: + fileobj = CloggedFileObj(io.BytesIO(file_.read())) + found, sr = sox_io_backend.load(fileobj, format=format_) + + assert sr == sample_rate + self.assertEqual(expected, found) + + @parameterized.expand([ + ('wav', None), + ('mp3', 128), + ('mp3', 320), + ('flac', 0), + ('flac', 5), + ('flac', 8), + ('vorbis', -1), + ('vorbis', 10), + ('amb', None), + ]) + def test_bytesio_tiny(self, ext, compression): + """Loading very small audio via file object returns the same result as via file path. 
+ """ + sample_rate = 16000 + format_ = ext if ext in ['mp3'] else None + path = self.get_temp_path(f'test.{ext}') + + sox_utils.gen_audio_file( + path, sample_rate, num_channels=2, + compression=compression, duration=1 / 1600) + expected, _ = sox_io_backend.load(path) + + with open(path, 'rb') as file_: + fileobj = io.BytesIO(file_.read()) + found, sr = sox_io_backend.load(fileobj, format=format_) + + assert sr == sample_rate + self.assertEqual(expected, found) + + @parameterized.expand([ + ('wav', None), + ('mp3', 128), + ('mp3', 320), + ('flac', 0), + ('flac', 5), + ('flac', 8), + ('vorbis', -1), + ('vorbis', 10), + ('amb', None), + ]) + def test_tarfile(self, ext, compression): + """Loading compressed audio via file-like object returns the same result as via file path.""" + sample_rate = 16000 + format_ = ext if ext in ['mp3'] else None + audio_file = f'test.{ext}' + audio_path = self.get_temp_path(audio_file) + archive_path = self.get_temp_path('archive.tar.gz') + + sox_utils.gen_audio_file( + audio_path, sample_rate, num_channels=2, + compression=compression) + expected, _ = sox_io_backend.load(audio_path) + + with tarfile.TarFile(archive_path, 'w') as tarobj: + tarobj.add(audio_path, arcname=audio_file) + with tarfile.TarFile(archive_path, 'r') as tarobj: + fileobj = tarobj.extractfile(audio_file) + found, sr = sox_io_backend.load(fileobj, format=format_) + + assert sr == sample_rate + self.assertEqual(expected, found) + + +@skipIfNoSox +@skipIfNoExec('sox') +@skipIfNoModule("requests") +class TestFileObjectHttp(HttpServerMixin, PytorchTestCase): + @parameterized.expand([ + ('wav', None), + ('mp3', 128), + ('mp3', 320), + ('flac', 0), + ('flac', 5), + ('flac', 8), + ('vorbis', -1), + ('vorbis', 10), + ('amb', None), + ]) + def test_requests(self, ext, compression): + sample_rate = 16000 + format_ = ext if ext in ['mp3'] else None + audio_file = f'test.{ext}' + audio_path = self.get_temp_path(audio_file) + + sox_utils.gen_audio_file( + audio_path, sample_rate, num_channels=2, compression=compression) + expected, _ = sox_io_backend.load(audio_path) + + url = self.get_url(audio_file) + with requests.get(url, stream=True) as resp: + found, sr = sox_io_backend.load(resp.raw, format=format_) + + assert sr == sample_rate + self.assertEqual(expected, found) + + @parameterized.expand(list(itertools.product( + [0, 1, 10, 100, 1000], + [-1, 1, 10, 100, 1000], + )), name_func=name_func) + def test_frame(self, frame_offset, num_frames): + """num_frames and frame_offset correctly specify the region of data""" + sample_rate = 8000 + audio_file = 'test.wav' + audio_path = self.get_temp_path(audio_file) + + original = get_wav_data('float32', num_channels=2) + save_wav(audio_path, original, sample_rate) + frame_end = None if num_frames == -1 else frame_offset + num_frames + expected = original[:, frame_offset:frame_end] + + url = self.get_url(audio_file) + with requests.get(url, stream=True) as resp: + found, sr = sox_io_backend.load(resp.raw, frame_offset, num_frames) + + assert sr == sample_rate + self.assertEqual(expected, found) + + +@skipIfNoSox +class TestLoadNoSuchFile(PytorchTestCase): + def test_load_fail(self): + """ + When attempted to load a non-existing file, error message must contain the file path. 
+ """ + path = "non_existing_audio.wav" + with self.assertRaisesRegex(RuntimeError, "^Error loading audio file: failed to open file {0}$".format(path)): + sox_io_backend.load(path) diff --git a/test/torchaudio_unittest/backend/sox_io/roundtrip_test.py b/test/torchaudio_unittest/backend/sox_io/roundtrip_test.py new file mode 100644 index 0000000000000000000000000000000000000000..32c920eea1ecaa1d82a426dbc4ecebbcc79d3f5f --- /dev/null +++ b/test/torchaudio_unittest/backend/sox_io/roundtrip_test.py @@ -0,0 +1,54 @@ +import itertools + +from torchaudio.backend import sox_io_backend +from parameterized import parameterized + +from torchaudio_unittest.common_utils import ( + TempDirMixin, + PytorchTestCase, + skipIfNoExec, + skipIfNoSox, + get_wav_data, +) +from .common import ( + name_func, + get_enc_params, +) + + +@skipIfNoExec('sox') +@skipIfNoSox +class TestRoundTripIO(TempDirMixin, PytorchTestCase): + """save/load round trip should not degrade data for lossless formats""" + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2], + )), name_func=name_func) + def test_wav(self, dtype, sample_rate, num_channels): + """save/load round trip should not degrade data for wav formats""" + original = get_wav_data(dtype, num_channels, normalize=False) + enc, bps = get_enc_params(dtype) + data = original + for i in range(10): + path = self.get_temp_path(f'{i}.wav') + sox_io_backend.save(path, data, sample_rate, encoding=enc, bits_per_sample=bps) + data, sr = sox_io_backend.load(path, normalize=False) + assert sr == sample_rate + self.assertEqual(original, data) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + list(range(9)), + )), name_func=name_func) + def test_flac(self, sample_rate, num_channels, compression_level): + """save/load round trip should not degrade data for flac formats""" + original = get_wav_data('float32', num_channels) + data = original + for i in range(10): + path = self.get_temp_path(f'{i}.flac') + sox_io_backend.save(path, data, sample_rate, compression=compression_level) + data, sr = sox_io_backend.load(path) + assert sr == sample_rate + self.assertEqual(original, data) diff --git a/test/torchaudio_unittest/backend/sox_io/save_test.py b/test/torchaudio_unittest/backend/sox_io/save_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9592e78d2d0a537822b424acdc720213acf1a6f1 --- /dev/null +++ b/test/torchaudio_unittest/backend/sox_io/save_test.py @@ -0,0 +1,402 @@ +import io +import os +import unittest + +import torch +from torchaudio.backend import sox_io_backend +from parameterized import parameterized + +from torchaudio_unittest.common_utils import ( + TempDirMixin, + TorchaudioTestCase, + PytorchTestCase, + skipIfNoExec, + skipIfNoSox, + get_wav_data, + load_wav, + save_wav, + sox_utils, + nested_params, +) +from .common import ( + name_func, + get_enc_params, +) + + +def _get_sox_encoding(encoding): + encodings = { + 'PCM_F': 'floating-point', + 'PCM_S': 'signed-integer', + 'PCM_U': 'unsigned-integer', + 'ULAW': 'u-law', + 'ALAW': 'a-law', + } + return encodings.get(encoding) + + +class SaveTestBase(TempDirMixin, TorchaudioTestCase): + def assert_save_consistency( + self, + format: str, + *, + compression: float = None, + encoding: str = None, + bits_per_sample: int = None, + sample_rate: float = 8000, + num_channels: int = 2, + num_frames: float = 3 * 8000, + src_dtype: str = 'int32', + test_mode: str = "path", + ): + """`save` function produces file that is 
comparable with `sox` command + + To compare that the file produced by `save` function agains the file produced by + the equivalent `sox` command, we need to load both files. + But there are many formats that cannot be opened with common Python modules (like + SciPy). + So we use `sox` command to prepare the original data and convert the saved files + into a format that SciPy can read (PCM wav). + The following diagram illustrates this process. The difference is 2.1. and 3.1. + + This assumes that + - loading data with SciPy preserves the data well. + - converting the resulting files into WAV format with `sox` preserve the data well. + + x + | 1. Generate source wav file with SciPy + | + v + -------------- wav ---------------- + | | + | 2.1. load with scipy | 3.1. Convert to the target + | then save it into the target | format depth with sox + | format with torchaudio | + v v + target format target format + | | + | 2.2. Convert to wav with sox | 3.2. Convert to wav with sox + | | + v v + wav wav + | | + | 2.3. load with scipy | 3.3. load with scipy + | | + v v + tensor -------> compare <--------- tensor + + """ + cmp_encoding = 'floating-point' + cmp_bit_depth = 32 + + src_path = self.get_temp_path('1.source.wav') + tgt_path = self.get_temp_path(f'2.1.torchaudio.{format}') + tst_path = self.get_temp_path('2.2.result.wav') + sox_path = self.get_temp_path(f'3.1.sox.{format}') + ref_path = self.get_temp_path('3.2.ref.wav') + + # 1. Generate original wav + data = get_wav_data(src_dtype, num_channels, normalize=False, num_frames=num_frames) + save_wav(src_path, data, sample_rate) + + # 2.1. Convert the original wav to target format with torchaudio + data = load_wav(src_path, normalize=False)[0] + if test_mode == "path": + sox_io_backend.save( + tgt_path, data, sample_rate, + compression=compression, encoding=encoding, bits_per_sample=bits_per_sample) + elif test_mode == "fileobj": + with open(tgt_path, 'bw') as file_: + sox_io_backend.save( + file_, data, sample_rate, + format=format, compression=compression, + encoding=encoding, bits_per_sample=bits_per_sample) + elif test_mode == "bytesio": + file_ = io.BytesIO() + sox_io_backend.save( + file_, data, sample_rate, + format=format, compression=compression, + encoding=encoding, bits_per_sample=bits_per_sample) + file_.seek(0) + with open(tgt_path, 'bw') as f: + f.write(file_.read()) + else: + raise ValueError(f"Unexpected test mode: {test_mode}") + # 2.2. Convert the target format to wav with sox + sox_utils.convert_audio_file( + tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth) + # 2.3. Load with SciPy + found = load_wav(tst_path, normalize=False)[0] + + # 3.1. Convert the original wav to target format with sox + sox_encoding = _get_sox_encoding(encoding) + sox_utils.convert_audio_file( + src_path, sox_path, + compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample) + # 3.2. Convert the target format to wav with sox + sox_utils.convert_audio_file( + sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth) + # 3.3. 
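The `test_mode="bytesio"` path above, reduced to its essence (a sketch; `waveform` is a placeholder): when saving to an in-memory buffer there is no file extension, so the container format must be given explicitly.

buf = io.BytesIO()
sox_io_backend.save(buf, waveform, 8000, format='flac')
buf.seek(0)
reloaded, sr = sox_io_backend.load(buf)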
Load with SciPy + expected = load_wav(ref_path, normalize=False)[0] + + self.assertEqual(found, expected) + + +@skipIfNoExec('sox') +@skipIfNoSox +class SaveTest(SaveTestBase): + @nested_params( + ["path", "fileobj", "bytesio"], + [ + ('PCM_U', 8), + ('PCM_S', 16), + ('PCM_S', 32), + ('PCM_F', 32), + ('PCM_F', 64), + ('ULAW', 8), + ('ALAW', 8), + ], + ) + def test_save_wav(self, test_mode, enc_params): + encoding, bits_per_sample = enc_params + self.assert_save_consistency( + "wav", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode) + + @nested_params( + ["path", "fileobj", "bytesio"], + [ + ('float32', ), + ('int32', ), + ('int16', ), + ('uint8', ), + ], + ) + def test_save_wav_dtype(self, test_mode, params): + dtype, = params + self.assert_save_consistency( + "wav", src_dtype=dtype, test_mode=test_mode) + + @nested_params( + ["path", "fileobj", "bytesio"], + [ + None, + -4.2, + -0.2, + 0, + 0.2, + 96, + 128, + 160, + 192, + 224, + 256, + 320, + ], + ) + def test_save_mp3(self, test_mode, bit_rate): + if test_mode in ["fileobj", "bytesio"]: + if bit_rate is not None and bit_rate < 1: + raise unittest.SkipTest( + "mp3 format with variable bit rate is known to " + "not yield the exact same result as sox command.") + self.assert_save_consistency( + "mp3", compression=bit_rate, test_mode=test_mode) + + @nested_params( + ["path", "fileobj", "bytesio"], + [8, 16, 24], + [ + None, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + ], + ) + def test_save_flac(self, test_mode, bits_per_sample, compression_level): + self.assert_save_consistency( + "flac", compression=compression_level, + bits_per_sample=bits_per_sample, test_mode=test_mode) + + @nested_params( + ["path", "fileobj", "bytesio"], + ) + def test_save_htk(self, test_mode): + self.assert_save_consistency("htk", test_mode=test_mode, num_channels=1) + + @nested_params( + ["path", "fileobj", "bytesio"], + [ + None, + -1, + 0, + 1, + 2, + 3, + 3.6, + 5, + 10, + ], + ) + def test_save_vorbis(self, test_mode, quality_level): + self.assert_save_consistency( + "vorbis", compression=quality_level, test_mode=test_mode) + + @nested_params( + ["path", "fileobj", "bytesio"], + [ + ('PCM_S', 8, ), + ('PCM_S', 16, ), + ('PCM_S', 24, ), + ('PCM_S', 32, ), + ('ULAW', 8), + ('ALAW', 8), + ('ALAW', 16), + ('ALAW', 24), + ('ALAW', 32), + ], + ) + def test_save_sphere(self, test_mode, enc_params): + encoding, bits_per_sample = enc_params + self.assert_save_consistency( + "sph", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode) + + @nested_params( + ["path", "fileobj", "bytesio"], + [ + ('PCM_U', 8, ), + ('PCM_S', 16, ), + ('PCM_S', 24, ), + ('PCM_S', 32, ), + ('PCM_F', 32, ), + ('PCM_F', 64, ), + ('ULAW', 8, ), + ('ALAW', 8, ), + ], + ) + def test_save_amb(self, test_mode, enc_params): + encoding, bits_per_sample = enc_params + self.assert_save_consistency( + "amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode) + + @nested_params( + ["path", "fileobj", "bytesio"], + [ + None, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + ], + ) + def test_save_amr_nb(self, test_mode, bit_rate): + self.assert_save_consistency( + "amr-nb", compression=bit_rate, num_channels=1, test_mode=test_mode) + + @nested_params( + ["path", "fileobj", "bytesio"], + ) + def test_save_gsm(self, test_mode): + self.assert_save_consistency( + "gsm", num_channels=1, test_mode=test_mode) + with self.assertRaises( + RuntimeError, msg="gsm format only supports single channel audio."): + self.assert_save_consistency( + "gsm", 
num_channels=2, test_mode=test_mode) + with self.assertRaises( + RuntimeError, msg="gsm format only supports a sampling rate of 8kHz."): + self.assert_save_consistency( + "gsm", sample_rate=16000, test_mode=test_mode) + + @parameterized.expand([ + ("wav", "PCM_S", 16), + ("mp3", ), + ("flac", ), + ("vorbis", ), + ("sph", "PCM_S", 16), + ("amr-nb", ), + ("amb", "PCM_S", 16), + ], name_func=name_func) + def test_save_large(self, format, encoding=None, bits_per_sample=None): + """`sox_io_backend.save` can save large files.""" + sample_rate = 8000 + one_hour = 60 * 60 * sample_rate + self.assert_save_consistency( + format, num_channels=1, sample_rate=8000, num_frames=one_hour, + encoding=encoding, bits_per_sample=bits_per_sample) + + @parameterized.expand([ + (32, ), + (64, ), + (128, ), + (256, ), + ], name_func=name_func) + def test_save_multi_channels(self, num_channels): + """`sox_io_backend.save` can save audio with many channels""" + self.assert_save_consistency( + "wav", encoding="PCM_S", bits_per_sample=16, + num_channels=num_channels) + + +@skipIfNoExec('sox') +@skipIfNoSox +class TestSaveParams(TempDirMixin, PytorchTestCase): + """Test the correctness of optional parameters of `sox_io_backend.save`""" + @parameterized.expand([(True, ), (False, )], name_func=name_func) + def test_save_channels_first(self, channels_first): + """channels_first swaps axes""" + path = self.get_temp_path('data.wav') + data = get_wav_data( + 'int16', 2, channels_first=channels_first, normalize=False) + sox_io_backend.save( + path, data, 8000, channels_first=channels_first) + found = load_wav(path, normalize=False)[0] + expected = data if channels_first else data.transpose(1, 0) + self.assertEqual(found, expected) + + @parameterized.expand([ + 'float32', 'int32', 'int16', 'uint8' + ], name_func=name_func) + def test_save_noncontiguous(self, dtype): + """Noncontiguous tensors are saved correctly""" + path = self.get_temp_path('data.wav') + enc, bps = get_enc_params(dtype) + expected = get_wav_data(dtype, 4, normalize=False)[::2, ::2] + assert not expected.is_contiguous() + sox_io_backend.save( + path, expected, 8000, encoding=enc, bits_per_sample=bps) + found = load_wav(path, normalize=False)[0] + self.assertEqual(found, expected) + + @parameterized.expand([ + 'float32', 'int32', 'int16', 'uint8', + ]) + def test_save_tensor_preserve(self, dtype): + """save function should not alter Tensor""" + path = self.get_temp_path('data.wav') + expected = get_wav_data(dtype, 4, normalize=False)[::2, ::2] + + data = expected.clone() + sox_io_backend.save(path, data, 8000) + + self.assertEqual(data, expected) + + +@skipIfNoSox +class TestSaveNonExistingDirectory(PytorchTestCase): + def test_save_fail(self): + """ + When attempted to save into a non-existing dir, error message must contain the file path. 
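The `channels_first` convention these parameter tests pin down (a sketch; the paths and tensor are placeholders): torchaudio's default layout is `(channels, frames)`, and `channels_first=False` interprets the input as `(frames, channels)` instead.

import torch

data = torch.zeros(2, 8000)                                       # 2 channels, 1 s at 8 kHz
sox_io_backend.save('a.wav', data, 8000, channels_first=True)
sox_io_backend.save('b.wav', data.transpose(0, 1), 8000, channels_first=False)   # same audio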
+        """
+        path = os.path.join("non_existing_directory", "foo.wav")
+        with self.assertRaisesRegex(RuntimeError, "^Error saving audio file: failed to open file {0}$".format(path)):
+            sox_io_backend.save(path, torch.zeros(1, 1), 8000)
diff --git a/test/torchaudio_unittest/backend/sox_io/smoke_test.py b/test/torchaudio_unittest/backend/sox_io/smoke_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..656cfd9fbeb18cd6654db6435776576db1409224
--- /dev/null
+++ b/test/torchaudio_unittest/backend/sox_io/smoke_test.py
@@ -0,0 +1,155 @@
+import io
+import itertools
+import unittest
+
+from torchaudio.utils import sox_utils
+from torchaudio.backend import sox_io_backend
+from torchaudio._internal.module_utils import is_sox_available
+from parameterized import parameterized
+
+from torchaudio_unittest.common_utils import (
+    TempDirMixin,
+    TorchaudioTestCase,
+    skipIfNoSox,
+    get_wav_data,
+)
+from .common import name_func
+
+
+skipIfNoMP3 = unittest.skipIf(
+    not is_sox_available() or
+    'mp3' not in sox_utils.list_read_formats() or
+    'mp3' not in sox_utils.list_write_formats(),
+    '"sox_io" backend does not support MP3')
+
+
+@skipIfNoSox
+class SmokeTest(TempDirMixin, TorchaudioTestCase):
+    """Run smoke tests on various audio formats
+
+    The purpose of this test suite is to verify that sox_io_backend functionalities do not exhibit
+    abnormal behaviors.
+
+    This test suite should be able to run without any additional tools (such as the sox command);
+    however, without such tools, the correctness of each function cannot be verified.
+    """
+    def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype='float32'):
+        duration = 1
+        num_frames = sample_rate * duration
+        path = self.get_temp_path(f'test.{ext}')
+        original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)
+
+        # 1. run save
+        sox_io_backend.save(path, original, sample_rate, compression=compression)
+        # 2. run info
+        info = sox_io_backend.info(path)
+        assert info.sample_rate == sample_rate
+        assert info.num_channels == num_channels
+        # 3. run load
+        loaded, sr = sox_io_backend.load(path, normalize=False)
+        assert sr == sample_rate
+        assert loaded.shape[0] == num_channels
+
+    @parameterized.expand(list(itertools.product(
+        ['float32', 'int32', 'int16', 'uint8'],
+        [8000, 16000],
+        [1, 2],
+    )), name_func=name_func)
+    def test_wav(self, dtype, sample_rate, num_channels):
+        """Run smoke test on wav format"""
+        self.run_smoke_test('wav', sample_rate, num_channels, dtype=dtype)
+
+    @parameterized.expand(list(itertools.product(
+        [8000, 16000],
+        [1, 2],
+        [-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
+    )))
+    @skipIfNoMP3
+    def test_mp3(self, sample_rate, num_channels, bit_rate):
+        """Run smoke test on mp3 format"""
+        self.run_smoke_test('mp3', sample_rate, num_channels, compression=bit_rate)
+
+    @parameterized.expand(list(itertools.product(
+        [8000, 16000],
+        [1, 2],
+        [-1, 0, 1, 2, 3, 3.6, 5, 10],
+    )))
+    def test_vorbis(self, sample_rate, num_channels, quality_level):
+        """Run smoke test on vorbis format"""
+        self.run_smoke_test('vorbis', sample_rate, num_channels, compression=quality_level)
+
+    @parameterized.expand(list(itertools.product(
+        [8000, 16000],
+        [1, 2],
+        list(range(9)),
+    )), name_func=name_func)
+    def test_flac(self, sample_rate, num_channels, compression_level):
+        """Run smoke test on flac format"""
+        self.run_smoke_test('flac', sample_rate, num_channels, compression=compression_level)
+
+
+@skipIfNoSox
+class SmokeTestFileObj(TorchaudioTestCase):
+    """Run smoke tests on various audio formats, using file-like objects
+
+    The purpose of this test suite is to verify that sox_io_backend functionalities do not exhibit
+    abnormal behaviors.
+
+    This test suite should be able to run without any additional tools (such as the sox command);
+    however, without such tools, the correctness of each function cannot be verified.
+    """
+    def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype='float32'):
+        duration = 1
+        num_frames = sample_rate * duration
+        original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)
+
+        fileobj = io.BytesIO()
+        # 1. run save
+        sox_io_backend.save(fileobj, original, sample_rate, compression=compression, format=ext)
+        # 2. run info
+        fileobj.seek(0)
+        info = sox_io_backend.info(fileobj, format=ext)
+        assert info.sample_rate == sample_rate
+        assert info.num_channels == num_channels
+        # 3.
run load + fileobj.seek(0) + loaded, sr = sox_io_backend.load(fileobj, normalize=False, format=ext) + assert sr == sample_rate + assert loaded.shape[0] == num_channels + + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2], + )), name_func=name_func) + def test_wav(self, dtype, sample_rate, num_channels): + """Run smoke test on wav format""" + self.run_smoke_test('wav', sample_rate, num_channels, dtype=dtype) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + [-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320], + ))) + @skipIfNoMP3 + def test_mp3(self, sample_rate, num_channels, bit_rate): + """Run smoke test on mp3 format""" + self.run_smoke_test('mp3', sample_rate, num_channels, compression=bit_rate) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + [-1, 0, 1, 2, 3, 3.6, 5, 10], + ))) + def test_vorbis(self, sample_rate, num_channels, quality_level): + """Run smoke test on vorbis format""" + self.run_smoke_test('vorbis', sample_rate, num_channels, compression=quality_level) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + list(range(9)), + )), name_func=name_func) + def test_flac(self, sample_rate, num_channels, compression_level): + """Run smoke test on flac format""" + self.run_smoke_test('flac', sample_rate, num_channels, compression=compression_level) diff --git a/test/torchaudio_unittest/backend/sox_io/torchscript_test.py b/test/torchaudio_unittest/backend/sox_io/torchscript_test.py new file mode 100644 index 0000000000000000000000000000000000000000..122f4bc0d0c0dfc23eb5901ec3162cd3ba1521fa --- /dev/null +++ b/test/torchaudio_unittest/backend/sox_io/torchscript_test.py @@ -0,0 +1,148 @@ +import itertools +from typing import Optional + +import torch +import torchaudio +from parameterized import parameterized + +from torchaudio_unittest.common_utils import ( + TempDirMixin, + TorchaudioTestCase, + skipIfNoExec, + skipIfNoSox, + get_wav_data, + save_wav, + load_wav, + sox_utils, + torch_script, +) +from .common import ( + name_func, + get_enc_params, +) + + +def py_info_func(filepath: str) -> torchaudio.backend.sox_io_backend.AudioMetaData: + return torchaudio.info(filepath) + + +def py_load_func(filepath: str, normalize: bool, channels_first: bool): + return torchaudio.load( + filepath, normalize=normalize, channels_first=channels_first) + + +def py_save_func( + filepath: str, + tensor: torch.Tensor, + sample_rate: int, + channels_first: bool = True, + compression: Optional[float] = None, + encoding: Optional[str] = None, + bits_per_sample: Optional[int] = None, +): + torchaudio.save( + filepath, tensor, sample_rate, channels_first, + compression, None, encoding, bits_per_sample) + + +@skipIfNoExec('sox') +@skipIfNoSox +class SoxIO(TempDirMixin, TorchaudioTestCase): + """TorchScript-ability Test suite for `sox_io_backend`""" + backend = 'sox_io' + + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2], + )), name_func=name_func) + def test_info_wav(self, dtype, sample_rate, num_channels): + """`sox_io_backend.info` is torchscript-able and returns the same result""" + audio_path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav') + data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate) + save_wav(audio_path, data, sample_rate) + + ts_info_func = torch_script(py_info_func) + + py_info = py_info_func(audio_path) + ts_info = 
ts_info_func(audio_path) + + assert py_info.sample_rate == ts_info.sample_rate + assert py_info.num_frames == ts_info.num_frames + assert py_info.num_channels == ts_info.num_channels + + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2], + [False, True], + [False, True], + )), name_func=name_func) + def test_load_wav(self, dtype, sample_rate, num_channels, normalize, channels_first): + """`sox_io_backend.load` is torchscript-able and returns the same result""" + audio_path = self.get_temp_path(f'test_load_{dtype}_{sample_rate}_{num_channels}_{normalize}.wav') + data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate) + save_wav(audio_path, data, sample_rate) + + ts_load_func = torch_script(py_load_func) + + py_data, py_sr = py_load_func( + audio_path, normalize=normalize, channels_first=channels_first) + ts_data, ts_sr = ts_load_func( + audio_path, normalize=normalize, channels_first=channels_first) + + self.assertEqual(py_sr, ts_sr) + self.assertEqual(py_data, ts_data) + + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2], + )), name_func=name_func) + def test_save_wav(self, dtype, sample_rate, num_channels): + ts_save_func = torch_script(py_save_func) + + expected = get_wav_data(dtype, num_channels, normalize=False) + py_path = self.get_temp_path(f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav') + ts_path = self.get_temp_path(f'test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav') + enc, bps = get_enc_params(dtype) + + py_save_func(py_path, expected, sample_rate, True, None, enc, bps) + ts_save_func(ts_path, expected, sample_rate, True, None, enc, bps) + + py_data, py_sr = load_wav(py_path, normalize=False) + ts_data, ts_sr = load_wav(ts_path, normalize=False) + + self.assertEqual(sample_rate, py_sr) + self.assertEqual(sample_rate, ts_sr) + self.assertEqual(expected, py_data) + self.assertEqual(expected, ts_data) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + list(range(9)), + )), name_func=name_func) + def test_save_flac(self, sample_rate, num_channels, compression_level): + ts_save_func = torch_script(py_save_func) + + expected = get_wav_data('float32', num_channels) + py_path = self.get_temp_path(f'test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac') + ts_path = self.get_temp_path(f'test_save_ts_{sample_rate}_{num_channels}_{compression_level}.flac') + + py_save_func(py_path, expected, sample_rate, True, compression_level, None, None) + ts_save_func(ts_path, expected, sample_rate, True, compression_level, None, None) + + # converting to 32 bit because flac file has 24 bit depth which scipy cannot handle. 
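+        # Both the Python-API output and the TorchScript output are converted the same way
+        # (via the sox CLI wrapper in sox_utils) before being loaded and compared to the original data below.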
+ py_path_wav = f'{py_path}.wav' + ts_path_wav = f'{ts_path}.wav' + sox_utils.convert_audio_file(py_path, py_path_wav, bit_depth=32) + sox_utils.convert_audio_file(ts_path, ts_path_wav, bit_depth=32) + + py_data, py_sr = load_wav(py_path_wav, normalize=True) + ts_data, ts_sr = load_wav(ts_path_wav, normalize=True) + + self.assertEqual(sample_rate, py_sr) + self.assertEqual(sample_rate, ts_sr) + self.assertEqual(expected, py_data) + self.assertEqual(expected, ts_data) diff --git a/test/torchaudio_unittest/backend/utils_test.py b/test/torchaudio_unittest/backend/utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c89f1e97884376733b563dbf0c163955e17d31a7 --- /dev/null +++ b/test/torchaudio_unittest/backend/utils_test.py @@ -0,0 +1,36 @@ +import torchaudio + +from torchaudio_unittest import common_utils + + +class BackendSwitchMixin: + """Test set/get_audio_backend works""" + backend = None + backend_module = None + + def test_switch(self): + torchaudio.set_audio_backend(self.backend) + if self.backend is None: + assert torchaudio.get_audio_backend() is None + else: + assert torchaudio.get_audio_backend() == self.backend + assert torchaudio.load == self.backend_module.load + assert torchaudio.save == self.backend_module.save + assert torchaudio.info == self.backend_module.info + + +class TestBackendSwitch_NoBackend(BackendSwitchMixin, common_utils.TorchaudioTestCase): + backend = None + backend_module = torchaudio.backend.no_backend + + +@common_utils.skipIfNoSox +class TestBackendSwitch_SoXIO(BackendSwitchMixin, common_utils.TorchaudioTestCase): + backend = 'sox_io' + backend_module = torchaudio.backend.sox_io_backend + + +@common_utils.skipIfNoModule('soundfile') +class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase): + backend = 'soundfile' + backend_module = torchaudio.backend.soundfile_backend diff --git a/test/torchaudio_unittest/common_utils/__init__.py b/test/torchaudio_unittest/common_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..48d53ba89f02d4d8feb6cceca826338c199cae06 --- /dev/null +++ b/test/torchaudio_unittest/common_utils/__init__.py @@ -0,0 +1,63 @@ +from .data_utils import ( + get_asset_path, + get_whitenoise, + get_sinusoid, + get_spectrogram, +) +from .backend_utils import ( + set_audio_backend, +) +from .case_utils import ( + TempDirMixin, + HttpServerMixin, + TestBaseMixin, + PytorchTestCase, + TorchaudioTestCase, + skipIfNoCuda, + skipIfNoExec, + skipIfNoModule, + skipIfNoKaldi, + skipIfNoSox, + skipIfRocm, + skipIfNoQengine, +) +from .wav_utils import ( + get_wav_data, + normalize_wav, + load_wav, + save_wav, +) +from .parameterized_utils import ( + load_params, + nested_params +) +from .func_utils import torch_script + + +__all__ = [ + 'get_asset_path', + 'get_whitenoise', + 'get_sinusoid', + 'get_spectrogram', + 'set_audio_backend', + 'TempDirMixin', + 'HttpServerMixin', + 'TestBaseMixin', + 'PytorchTestCase', + 'TorchaudioTestCase', + 'skipIfNoCuda', + 'skipIfNoExec', + 'skipIfNoModule', + 'skipIfNoKaldi', + 'skipIfNoSox', + 'skipIfNoSoxBackend', + 'skipIfRocm', + 'skipIfNoQengine', + 'get_wav_data', + 'normalize_wav', + 'load_wav', + 'save_wav', + 'load_params', + 'nested_params', + 'torch_script', +] diff --git a/test/torchaudio_unittest/common_utils/backend_utils.py b/test/torchaudio_unittest/common_utils/backend_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..84dd73ed2e01ce8ab3a4d99becec24db87cf9931 --- /dev/null +++ 
b/test/torchaudio_unittest/common_utils/backend_utils.py
@@ -0,0 +1,21 @@
+import unittest
+
+import torchaudio
+
+
+def set_audio_backend(backend):
+    """Set the audio backend, allowing the additional value 'default'"""
+    backends = torchaudio.list_audio_backends()
+    if backend == 'soundfile':
+        be = 'soundfile'
+    elif backend == 'default':
+        if 'sox_io' in backends:
+            be = 'sox_io'
+        elif 'soundfile' in backends:
+            be = 'soundfile'
+        else:
+            raise unittest.SkipTest('No default backend available')
+    else:
+        be = backend
+
+    torchaudio.set_audio_backend(be)
diff --git a/test/torchaudio_unittest/common_utils/case_utils.py b/test/torchaudio_unittest/common_utils/case_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..140291a8a275b992c7362f90549444a4c43483e3
--- /dev/null
+++ b/test/torchaudio_unittest/common_utils/case_utils.py
@@ -0,0 +1,123 @@
+import shutil
+import os.path
+import subprocess
+import tempfile
+import time
+import unittest
+
+import torch
+from torch.testing._internal.common_utils import TestCase as PytorchTestCase
+from torchaudio._internal.module_utils import (
+    is_module_available,
+    is_sox_available,
+    is_kaldi_available
+)
+
+from .backend_utils import set_audio_backend
+
+
+class TempDirMixin:
+    """Mixin to provide easy access to temp dir"""
+    temp_dir_ = None
+
+    @classmethod
+    def get_base_temp_dir(cls):
+        # If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of a temporary directory.
+        # This is handy for debugging.
+        key = 'TORCHAUDIO_TEST_TEMP_DIR'
+        if key in os.environ:
+            return os.environ[key]
+        if cls.temp_dir_ is None:
+            cls.temp_dir_ = tempfile.TemporaryDirectory()
+        return cls.temp_dir_.name
+
+    @classmethod
+    def tearDownClass(cls):
+        super().tearDownClass()
+        if cls.temp_dir_ is not None:
+            cls.temp_dir_.cleanup()
+            cls.temp_dir_ = None
+
+    def get_temp_path(self, *paths):
+        temp_dir = os.path.join(self.get_base_temp_dir(), self.id())
+        path = os.path.join(temp_dir, *paths)
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        return path
+
+
+class HttpServerMixin(TempDirMixin):
+    """Mixin that serves the temporary directory over HTTP
+
+    This class creates a temporary directory and serves it as an HTTP service.
+    The server stays up for the duration of all the tests defined in the subclass.
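+
+    Files written through ``get_temp_path`` can then be fetched via ``get_url``, since both
+    are namespaced by the test id.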
+ """ + _proc = None + _port = 8000 + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._proc = subprocess.Popen( + ['python', '-m', 'http.server', f'{cls._port}'], + cwd=cls.get_base_temp_dir(), + stderr=subprocess.DEVNULL) # Disable server-side error log because it is confusing + time.sleep(2.0) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + cls._proc.kill() + + def get_url(self, *route): + return f'http://localhost:{self._port}/{self.id()}/{"/".join(route)}' + + +class TestBaseMixin: + """Mixin to provide consistent way to define device/dtype/backend aware TestCase""" + dtype = None + device = None + backend = None + + def setUp(self): + super().setUp() + set_audio_backend(self.backend) + + @property + def complex_dtype(self): + if self.dtype in ['float32', 'float', torch.float, torch.float32]: + return torch.cfloat + if self.dtype in ['float64', 'double', torch.double, torch.float64]: + return torch.cdouble + raise ValueError(f'No corresponding complex dtype for {self.dtype}') + + +class TorchaudioTestCase(TestBaseMixin, PytorchTestCase): + pass + + +def skipIfNoExec(cmd): + return unittest.skipIf(shutil.which(cmd) is None, f'`{cmd}` is not available') + + +def skipIfNoModule(module, display_name=None): + display_name = display_name or module + return unittest.skipIf(not is_module_available(module), f'"{display_name}" is not available') + + +def skipIfNoCuda(test_item): + if torch.cuda.is_available(): + return test_item + force_cuda_test = os.environ.get('TORCHAUDIO_TEST_FORCE_CUDA', '0') + if force_cuda_test not in ['0', '1']: + raise ValueError('"TORCHAUDIO_TEST_FORCE_CUDA" must be either "0" or "1".') + if force_cuda_test == '1': + raise RuntimeError('"TORCHAUDIO_TEST_FORCE_CUDA" is set but CUDA is not available.') + return unittest.skip('CUDA is not available.')(test_item) +skipIfNoSox = unittest.skipIf(not is_sox_available(), reason='Sox not available') +skipIfNoKaldi = unittest.skipIf(not is_kaldi_available(), reason='Kaldi not available') +skipIfRocm = unittest.skipIf(os.getenv('TORCHAUDIO_TEST_WITH_ROCM', '0') == '1', + reason="test doesn't currently work on the ROCm stack") +skipIfNoQengine = unittest.skipIf( + 'fbgemm' not in torch.backends.quantized.supported_engines, + reason="`fbgemm` is not available." 
+)
diff --git a/test/torchaudio_unittest/common_utils/data_utils.py b/test/torchaudio_unittest/common_utils/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c97cc1ce3d1196f5335147a6c278e5e497003d4
--- /dev/null
+++ b/test/torchaudio_unittest/common_utils/data_utils.py
@@ -0,0 +1,155 @@
+import os.path
+from typing import Union, Optional
+
+import torch
+
+
+_TEST_DIR_PATH = os.path.realpath(
+    os.path.join(os.path.dirname(__file__), '..'))
+
+
+def get_asset_path(*paths):
+    """Return full path of a test asset"""
+    return os.path.join(_TEST_DIR_PATH, 'assets', *paths)
+
+
+def convert_tensor_encoding(
+    tensor: torch.Tensor,
+    dtype: torch.dtype,
+):
+    """Convert input tensor with values between -1 and 1 to integer encoding
+    Args:
+        tensor: input tensor, assumed between -1 and 1
+        dtype: desired output tensor dtype
+    Returns:
+        Tensor: the converted tensor, with the same shape as the input
+    """
+    if dtype == torch.int32:
+        tensor *= (tensor > 0) * 2147483647 + (tensor < 0) * 2147483648
+    if dtype == torch.int16:
+        tensor *= (tensor > 0) * 32767 + (tensor < 0) * 32768
+    if dtype == torch.uint8:
+        tensor *= (tensor > 0) * 127 + (tensor < 0) * 128
+        tensor += 128
+    tensor = tensor.to(dtype)
+    return tensor
+
+
+def get_whitenoise(
+    *,
+    sample_rate: int = 16000,
+    duration: float = 1,  # seconds
+    n_channels: int = 1,
+    seed: int = 0,
+    dtype: Union[str, torch.dtype] = "float32",
+    device: Union[str, torch.device] = "cpu",
+    channels_first=True,
+    scale_factor: float = 1,
+):
+    """Generate pseudo audio data with whitenoise
+    Args:
+        sample_rate: Sampling rate
+        duration: Length of the resulting Tensor in seconds.
+        n_channels: Number of channels
+        seed: Seed value used for random number generation.
+            Note that this function does not modify global random generator state.
+        dtype: Torch dtype
+        device: device
+        channels_first: whether first dimension is n_channels
+        scale_factor: scale the Tensor before clamping and quantization
+    Returns:
+        Tensor: shape of (n_channels, sample_rate * duration)
+    """
+    if isinstance(dtype, str):
+        dtype = getattr(torch, dtype)
+    if dtype not in [torch.float64, torch.float32, torch.int32, torch.int16, torch.uint8]:
+        raise NotImplementedError(f'dtype {dtype} is not supported.')
+    # According to the doc, forking rng on all CUDA devices is slow when there are many CUDA devices,
+    # so we only fork on CPU, generate values and move the data to the given device
+    with torch.random.fork_rng([]):
+        torch.random.manual_seed(seed)
+        tensor = torch.randn([n_channels, int(sample_rate * duration)],
+                             dtype=torch.float32, device='cpu')
+    tensor /= 2.0
+    tensor *= scale_factor
+    tensor.clamp_(-1.0, 1.0)
+    if not channels_first:
+        tensor = tensor.t()
+
+    tensor = tensor.to(device)
+
+    return convert_tensor_encoding(tensor, dtype)
+
+
+def get_sinusoid(
+    *,
+    frequency: float = 300,
+    sample_rate: int = 16000,
+    duration: float = 1,  # seconds
+    n_channels: int = 1,
+    dtype: Union[str, torch.dtype] = "float32",
+    device: Union[str, torch.device] = "cpu",
+    channels_first: bool = True,
+):
+    """Generate pseudo audio data with sine wave.
+
+    Args:
+        frequency: Frequency of sine wave
+        sample_rate: Sampling rate
+        duration: Length of the resulting Tensor in seconds.
+        n_channels: Number of channels
+        dtype: Torch dtype
+        device: device
+
+    Returns:
+        Tensor: shape of (n_channels, sample_rate * duration)
+    """
+    if isinstance(dtype, str):
+        dtype = getattr(torch, dtype)
+    pie2 = 2 * 3.141592653589793
+    end = pie2 * frequency * duration
+    theta = torch.linspace(0, end, int(sample_rate * duration), dtype=torch.float32, device=device)
+    tensor = torch.sin(theta, out=None).repeat([n_channels, 1])
+    if not channels_first:
+        tensor = tensor.t()
+    return convert_tensor_encoding(tensor, dtype)
+
+
+def get_spectrogram(
+    waveform,
+    *,
+    n_fft: int = 2048,
+    hop_length: Optional[int] = None,
+    win_length: Optional[int] = None,
+    window: Optional[torch.Tensor] = None,
+    center: bool = True,
+    pad_mode: str = 'reflect',
+    power: Optional[float] = None,
+):
+    """Generate a spectrogram of the given Tensor
+
+    Args:
+        n_fft: The number of FFT bins.
+        hop_length: Stride for sliding window. default: ``n_fft // 4``.
+        win_length: The size of window frame and STFT filter. default: ``n_fft``.
+        window: Window function. default: Hann window
+        center: Pad the input sequence if True. See ``torch.stft`` for details.
+        pad_mode: Padding method used when center is True. Default: "reflect".
+        power: If ``None``, the raw spectrogram with complex values is returned;
+            otherwise the norm of the spectrogram is returned.
+    """
+    hop_length = hop_length or n_fft // 4
+    win_length = win_length or n_fft
+    window = torch.hann_window(win_length, device=waveform.device) if window is None else window
+    spec = torch.stft(
+        waveform,
+        n_fft=n_fft,
+        hop_length=hop_length,
+        win_length=win_length,
+        center=center,
+        window=window,
+        pad_mode=pad_mode,
+        return_complex=True)
+    if power is not None:
+        spec = spec.abs() ** power
+    return spec
diff --git a/test/torchaudio_unittest/common_utils/func_utils.py b/test/torchaudio_unittest/common_utils/func_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8dd5dd43be1883c000c09ee85d03df5e8e036ca
--- /dev/null
+++ b/test/torchaudio_unittest/common_utils/func_utils.py
@@ -0,0 +1,10 @@
+import io
+import torch
+
+
+def torch_script(obj):
+    """TorchScript the given function or Module"""
+    buffer = io.BytesIO()
+    torch.jit.save(torch.jit.script(obj), buffer)
+    buffer.seek(0)
+    return torch.jit.load(buffer)
diff --git a/test/torchaudio_unittest/common_utils/kaldi_utils.py b/test/torchaudio_unittest/common_utils/kaldi_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..71053bd003f25b60ca6c7bb6ea62d9ed37ae3511
--- /dev/null
+++ b/test/torchaudio_unittest/common_utils/kaldi_utils.py
@@ -0,0 +1,38 @@
+import subprocess
+
+import torch
+
+
+def convert_args(**kwargs):
+    args = []
+    for key, value in kwargs.items():
+        if key == 'sample_rate':
+            key = 'sample_frequency'
+        key = '--' + key.replace('_', '-')
+        value = str(value).lower() if value in [True, False] else str(value)
+        args.append('%s=%s' % (key, value))
+    return args
+
+
+def run_kaldi(command, input_type, input_value):
+    """Run the provided Kaldi command, pass a tensor and get the resulting tensor
+
+    Args:
+        command (list of str): The command with arguments
+        input_type (str): 'ark' or 'scp'
+        input_value (Tensor for 'ark', string for 'scp'): The input to pass.
+            Must be a path to an audio file for 'scp'.
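+
+    Returns:
+        Tensor: the matrix that the Kaldi command writes to stdout, read back via ``kaldi_io.read_mat_ark``.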
+    """
+    import kaldi_io
+
+    key = 'foo'
+    process = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
+    if input_type == 'ark':
+        kaldi_io.write_mat(process.stdin, input_value.cpu().numpy(), key=key)
+    elif input_type == 'scp':
+        process.stdin.write(f'{key} {input_value}'.encode('utf8'))
+    else:
+        raise NotImplementedError('Unexpected type')
+    process.stdin.close()
+    result = dict(kaldi_io.read_mat_ark(process.stdout))['foo']
+    return torch.from_numpy(result.copy())  # the copy suppresses a torch warning
diff --git a/test/torchaudio_unittest/common_utils/parameterized_utils.py b/test/torchaudio_unittest/common_utils/parameterized_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..000dccaf120ef97bb2688bb7c796d488a1302b11
--- /dev/null
+++ b/test/torchaudio_unittest/common_utils/parameterized_utils.py
@@ -0,0 +1,53 @@
+import json
+from itertools import product
+
+from parameterized import param, parameterized
+
+from .data_utils import get_asset_path
+
+
+def load_params(*paths):
+    with open(get_asset_path(*paths), 'r') as file:
+        return [param(json.loads(line)) for line in file]
+
+
+def _name_func(func, _, params):
+    strs = []
+    for arg in params.args:
+        if isinstance(arg, tuple):
+            strs.append("_".join(str(a) for a in arg))
+        else:
+            strs.append(str(arg))
+    # sanitize the test name
+    name = "_".join(strs).replace(".", "_")
+    return f'{func.__name__}_{name}'
+
+
+def nested_params(*params_set):
+    """Generate the cartesian product of the given list of parameters.
+
+    Args:
+        params_set (list of parameters): Parameters. When using the ``parameterized.param`` class,
+            all the parameters have to be specified with that class, using keyword arguments only.
+    """
+    flatten = [p for params in params_set for p in params]
+
+    # Parameters to be nested are given as a list of plain objects
+    if all(not isinstance(p, param) for p in flatten):
+        args = list(product(*params_set))
+        return parameterized.expand(args, name_func=_name_func)
+
+    # Parameters to be nested are given as a list of `parameterized.param`
+    if not all(isinstance(p, param) for p in flatten):
+        raise TypeError(
+            "When using ``parameterized.param``, "
+            "all the parameters have to be of the ``param`` type.")
+    if any(p.args for p in flatten):
+        raise ValueError(
+            "When using ``parameterized.param``, "
+            "all the parameters have to be provided as keyword arguments."
+ ) + args = [param()] + for params in params_set: + args = [param(**x.kwargs, **y.kwargs) for x in args for y in params] + return parameterized.expand(args) diff --git a/test/torchaudio_unittest/common_utils/psd_utils.py b/test/torchaudio_unittest/common_utils/psd_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3ab3354672f336381ff3796a70da653b2a9d920e --- /dev/null +++ b/test/torchaudio_unittest/common_utils/psd_utils.py @@ -0,0 +1,27 @@ +from typing import Optional + +import numpy as np +import torch + + +def psd_numpy( + X: np.array, + mask: Optional[np.array], + multi_mask: bool = False, + normalize: bool = True, + eps: float = 1e-15 +) -> np.array: + X_conj = np.conj(X) + psd_X = np.einsum("...cft,...eft->...ftce", X, X_conj) + if mask is not None: + if multi_mask: + mask = mask.mean(axis=-3) + if normalize: + mask = mask / (mask.sum(axis=-1, keepdims=True) + eps) + psd = psd_X * mask[..., None, None] + else: + psd = psd_X + + psd = psd.sum(axis=-3) + + return torch.tensor(psd, dtype=torch.cdouble) diff --git a/test/torchaudio_unittest/common_utils/rnnt_utils.py b/test/torchaudio_unittest/common_utils/rnnt_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..94ea300ef7ec2f331c7e1677a99b8a64471dc4ab --- /dev/null +++ b/test/torchaudio_unittest/common_utils/rnnt_utils.py @@ -0,0 +1,603 @@ +import unittest +import random +import torch +import numpy as np +from torchaudio.functional import rnnt_loss + + +CPU_DEVICE = torch.device("cpu") + + +class _NumpyTransducer(torch.autograd.Function): + @staticmethod + def forward( + ctx, + log_probs, + logit_lengths, + target_lengths, + targets, + blank=-1, + ): + device = log_probs.device + log_probs = log_probs.cpu().data.numpy() + logit_lengths = logit_lengths.cpu().data.numpy() + target_lengths = target_lengths.cpu().data.numpy() + targets = targets.cpu().data.numpy() + + gradients, costs, _, _ = __class__.compute( + log_probs=log_probs, + logit_lengths=logit_lengths, + target_lengths=target_lengths, + targets=targets, + blank=blank, + ) + + costs = torch.FloatTensor(costs).to(device=device) + gradients = torch.FloatTensor(gradients).to(device=device) + ctx.grads = torch.autograd.Variable(gradients) + + return costs + + @staticmethod + def backward(ctx, grad_output): + grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads) + return ctx.grads.mul(grad_output), None, None, None, None, None, None, None, None + + @staticmethod + def compute_alpha_one_sequence(log_probs, targets, blank=-1): + max_T, max_U, D = log_probs.shape + alpha = np.zeros((max_T, max_U), dtype=np.float32) + for t in range(1, max_T): + alpha[t, 0] = alpha[t - 1, 0] + log_probs[t - 1, 0, blank] + + for u in range(1, max_U): + alpha[0, u] = alpha[0, u - 1] + log_probs[0, u - 1, targets[u - 1]] + + for t in range(1, max_T): + for u in range(1, max_U): + skip = alpha[t - 1, u] + log_probs[t - 1, u, blank] + emit = alpha[t, u - 1] + log_probs[t, u - 1, targets[u - 1]] + alpha[t, u] = np.logaddexp(skip, emit) + + cost = -(alpha[-1, -1] + log_probs[-1, -1, blank]) + return alpha, cost + + @staticmethod + def compute_beta_one_sequence(log_probs, targets, blank=-1): + max_T, max_U, D = log_probs.shape + beta = np.zeros((max_T, max_U), dtype=np.float32) + beta[-1, -1] = log_probs[-1, -1, blank] + + for t in reversed(range(max_T - 1)): + beta[t, -1] = beta[t + 1, -1] + log_probs[t, -1, blank] + + for u in reversed(range(max_U - 1)): + beta[-1, u] = beta[-1, u + 1] + log_probs[-1, u, targets[u]] + + for t in reversed(range(max_T - 
1)): + for u in reversed(range(max_U - 1)): + skip = beta[t + 1, u] + log_probs[t, u, blank] + emit = beta[t, u + 1] + log_probs[t, u, targets[u]] + beta[t, u] = np.logaddexp(skip, emit) + + cost = -beta[0, 0] + return beta, cost + + @staticmethod + def compute_gradients_one_sequence( + log_probs, alpha, beta, targets, blank=-1 + ): + max_T, max_U, D = log_probs.shape + gradients = np.full(log_probs.shape, float("-inf")) + cost = -beta[0, 0] + + gradients[-1, -1, blank] = alpha[-1, -1] + + gradients[:-1, :, blank] = alpha[:-1, :] + beta[1:, :] + + for u, l in enumerate(targets): + gradients[:, u, l] = alpha[:, u] + beta[:, u + 1] + + gradients = -(np.exp(gradients + log_probs + cost)) + return gradients + + @staticmethod + def compute( + log_probs, + logit_lengths, + target_lengths, + targets, + blank=-1, + ): + gradients = np.zeros_like(log_probs) + B_tgt, max_T, max_U, D = log_probs.shape + B_src = logit_lengths.shape[0] + + H = int(B_tgt / B_src) + + alphas = np.zeros((B_tgt, max_T, max_U)) + betas = np.zeros((B_tgt, max_T, max_U)) + betas.fill(float("-inf")) + alphas.fill(float("-inf")) + costs = np.zeros(B_tgt) + for b_tgt in range(B_tgt): + b_src = int(b_tgt / H) + T = int(logit_lengths[b_src]) + # NOTE: see https://arxiv.org/pdf/1211.3711.pdf Section 2.1 + U = int(target_lengths[b_tgt]) + 1 + + seq_log_probs = log_probs[b_tgt, :T, :U, :] + seq_targets = targets[b_tgt, : int(target_lengths[b_tgt])] + alpha, alpha_cost = __class__.compute_alpha_one_sequence( + log_probs=seq_log_probs, targets=seq_targets, blank=blank + ) + + beta, beta_cost = __class__.compute_beta_one_sequence( + log_probs=seq_log_probs, targets=seq_targets, blank=blank + ) + + seq_gradients = __class__.compute_gradients_one_sequence( + log_probs=seq_log_probs, + alpha=alpha, + beta=beta, + targets=seq_targets, + blank=blank, + ) + np.testing.assert_almost_equal(alpha_cost, beta_cost, decimal=2) + gradients[b_tgt, :T, :U, :] = seq_gradients + costs[b_tgt] = beta_cost + alphas[b_tgt, :T, :U] = alpha + betas[b_tgt, :T, :U] = beta + + return gradients, costs, alphas, betas + + +class NumpyTransducerLoss(torch.nn.Module): + def __init__(self, blank=-1): + super().__init__() + self.blank = blank + + def forward( + self, + logits, + logit_lengths, + target_lengths, + targets, + ): + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + return _NumpyTransducer.apply( + log_probs, + logit_lengths, + target_lengths, + targets, + self.blank, + ) + + +def compute_with_numpy_transducer(data): + costs = NumpyTransducerLoss( + blank=data["blank"], + )( + logits=data["logits"], + logit_lengths=data["logit_lengths"], + target_lengths=data["target_lengths"], + targets=data["targets"], + ) + + loss = torch.sum(costs) + loss.backward() + costs = costs.cpu() + gradients = data["logits"].saved_grad.cpu() + return costs, gradients + + +def compute_with_pytorch_transducer(data): + costs = rnnt_loss( + logits=data["logits"], + logit_lengths=data["logit_lengths"], + target_lengths=data["target_lengths"], + targets=data["targets"], + blank=data["blank"], + reduction="none", + ) + + loss = torch.sum(costs) + loss.backward() + costs = costs.cpu() + gradients = data["logits"].saved_grad.cpu() + return costs, gradients + + +def get_basic_data(device): + # Example provided + # in 6f73a2513dc784c59eec153a45f40bc528355b18 + # of https://github.com/HawkAaron/warp-transducer + + logits = torch.tensor( + [ + [ + [ + [0.1, 0.6, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.6, 0.1, 0.1], + [0.1, 0.1, 0.2, 0.8, 0.1], + ], + [ + [0.1, 0.6, 0.1, 0.1, 0.1], + [0.1, 
0.1, 0.2, 0.1, 0.1], + [0.7, 0.1, 0.2, 0.1, 0.1], + ], + ] + ], + dtype=torch.float32, + device=device, + ) + targets = torch.tensor([[1, 2]], dtype=torch.int, device=device) + logit_lengths = torch.tensor([2], dtype=torch.int, device=device) + target_lengths = torch.tensor([2], dtype=torch.int, device=device) + + logits.requires_grad_(True) + + return logits, targets, logit_lengths, target_lengths + + +def get_B1_T10_U3_D4_data( + random=False, + dtype=torch.float32, + device=CPU_DEVICE, +): + B, T, U, D = 2, 10, 3, 4 + + logits = torch.rand(B, T, U, D, dtype=dtype, device=device) + if not random: + logits.fill_(0.1) + logits.requires_grad_(True) + + def grad_hook(grad): + logits.saved_grad = grad.clone() + logits.register_hook(grad_hook) + + data = {} + data["logits"] = logits + data["logit_lengths"] = torch.tensor([10, 10], dtype=torch.int32, device=device) + data["target_lengths"] = torch.tensor([2, 2], dtype=torch.int32, device=device) + data["targets"] = torch.tensor([[1, 2], [1, 2]], dtype=torch.int32, device=device) + data["blank"] = 0 + + return data + + +def get_B1_T2_U3_D5_data(dtype=torch.float32, device=CPU_DEVICE): + logits = torch.tensor( + [ + 0.1, + 0.6, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.6, + 0.1, + 0.1, + 0.1, + 0.1, + 0.2, + 0.8, + 0.1, + 0.1, + 0.6, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.2, + 0.1, + 0.1, + 0.7, + 0.1, + 0.2, + 0.1, + 0.1, + ], + dtype=dtype, + device=device, + ).reshape(1, 2, 3, 5) + logits.requires_grad_(True) + + def grad_hook(grad): + logits.saved_grad = grad.clone() + logits.register_hook(grad_hook) + + targets = torch.tensor([[1, 2]], dtype=torch.int32, device=device) + logit_lengths = torch.tensor([2], dtype=torch.int32, device=device) + target_lengths = torch.tensor([2], dtype=torch.int32, device=device) + + blank = -1 + + ref_costs = torch.tensor([5.09566688538], dtype=dtype) + ref_gradients = torch.tensor( + [ + 0.17703132, + -0.39992708, + 0.17703132, + 0.17703132, + -0.13116692, + 0.12247062, + 0.12247062, + -0.181684, + 0.12247062, + -0.1857276, + 0.06269141, + 0.06269141, + 0.06928471, + 0.12624498, + -0.32091248, + 0.05456069, + -0.2182428, + 0.05456069, + 0.05456069, + 0.05456069, + 0.12073967, + 0.12073967, + -0.48295838, + 0.12073967, + 0.12073967, + 0.30741188, + 0.16871123, + 0.18645471, + 0.16871123, + -0.83128875, + ], + dtype=dtype, + ).reshape(1, 2, 3, 5) + + data = { + "logits": logits, + "targets": targets, + "logit_lengths": logit_lengths, + "target_lengths": target_lengths, + "blank": blank, + } + + return data, ref_costs, ref_gradients + + +def get_B2_T4_U3_D3_data(dtype=torch.float32, device=CPU_DEVICE): + # Test from D21322854 + logits = torch.tensor( + [ + 0.065357, + 0.787530, + 0.081592, + 0.529716, + 0.750675, + 0.754135, + 0.609764, + 0.868140, + 0.622532, + 0.668522, + 0.858039, + 0.164539, + 0.989780, + 0.944298, + 0.603168, + 0.946783, + 0.666203, + 0.286882, + 0.094184, + 0.366674, + 0.736168, + 0.166680, + 0.714154, + 0.399400, + 0.535982, + 0.291821, + 0.612642, + 0.324241, + 0.800764, + 0.524106, + 0.779195, + 0.183314, + 0.113745, + 0.240222, + 0.339470, + 0.134160, + 0.505562, + 0.051597, + 0.640290, + 0.430733, + 0.829473, + 0.177467, + 0.320700, + 0.042883, + 0.302803, + 0.675178, + 0.569537, + 0.558474, + 0.083132, + 0.060165, + 0.107958, + 0.748615, + 0.943918, + 0.486356, + 0.418199, + 0.652408, + 0.024243, + 0.134582, + 0.366342, + 0.295830, + 0.923670, + 0.689929, + 0.741898, + 0.250005, + 0.603430, + 0.987289, + 0.592606, + 0.884672, + 0.543450, + 0.660770, + 0.377128, + 0.358021, + ], + 
dtype=dtype, + device=device, + ).reshape(2, 4, 3, 3) + logits.requires_grad_(True) + + def grad_hook(grad): + logits.saved_grad = grad.clone() + logits.register_hook(grad_hook) + + targets = torch.tensor([[1, 2], [1, 1]], dtype=torch.int32, device=device) + logit_lengths = torch.tensor([4, 4], dtype=torch.int32, device=device) + target_lengths = torch.tensor([2, 2], dtype=torch.int32, device=device) + + blank = 0 + + ref_costs = torch.tensor([4.2806528590890736, 3.9384369822503591], dtype=dtype) + + ref_gradients = torch.tensor( + [ + -0.186844, + -0.062555, + 0.249399, + -0.203377, + 0.202399, + 0.000977, + -0.141016, + 0.079123, + 0.061893, + -0.011552, + -0.081280, + 0.092832, + -0.154257, + 0.229433, + -0.075176, + -0.246593, + 0.146405, + 0.100188, + -0.012918, + -0.061593, + 0.074512, + -0.055986, + 0.219831, + -0.163845, + -0.497627, + 0.209240, + 0.288387, + 0.013605, + -0.030220, + 0.016615, + 0.113925, + 0.062781, + -0.176706, + -0.667078, + 0.367659, + 0.299419, + -0.356344, + -0.055347, + 0.411691, + -0.096922, + 0.029459, + 0.067463, + -0.063518, + 0.027654, + 0.035863, + -0.154499, + -0.073942, + 0.228441, + -0.166790, + -0.000088, + 0.166878, + -0.172370, + 0.105565, + 0.066804, + 0.023875, + -0.118256, + 0.094381, + -0.104707, + -0.108934, + 0.213642, + -0.369844, + 0.180118, + 0.189726, + 0.025714, + -0.079462, + 0.053748, + 0.122328, + -0.238789, + 0.116460, + -0.598687, + 0.302203, + 0.296484, + ], + dtype=dtype, + ).reshape(2, 4, 3, 3) + + data = { + "logits": logits, + "targets": targets, + "logit_lengths": logit_lengths, + "target_lengths": target_lengths, + "blank": blank, + } + + return data, ref_costs, ref_gradients + + +def get_random_data( + max_B=8, + max_T=128, + max_U=32, + max_D=40, + blank=-1, + dtype=torch.float32, + device=CPU_DEVICE, + seed=None, +): + if seed is not None: + torch.manual_seed(seed=seed) + + if blank != -1: + raise ValueError("blank != -1 is not supported yet.") + + random.seed(0) + B = random.randint(1, max_B - 1) + T = random.randint(5, max_T - 1) + U = random.randint(5, max_U - 1) + D = random.randint(2, max_D - 1) + + logit_lengths = torch.randint(low=5, high=T + 1, size=(B,), dtype=torch.int32, device=device) + target_lengths = torch.randint(low=5, high=U + 1, size=(B,), dtype=torch.int32, device=device) + max_src_length = torch.max(logit_lengths) + max_tgt_length = torch.max(target_lengths) + + targets = torch.randint( + low=0, high=D - 1, size=(B, max_tgt_length), dtype=torch.int32, device=device + ) + logits = torch.rand( + size=(B, max_src_length, max_tgt_length + 1, D), + dtype=dtype, + device=device, + ).requires_grad_(True) + + def grad_hook(grad): + logits.saved_grad = grad.clone() + logits.register_hook(grad_hook) + + return { + "logits": logits, + "targets": targets, + "logit_lengths": logit_lengths, + "target_lengths": target_lengths, + "blank": blank, + } + + +def skipIfNoRNNT(test_item): + try: + torch.ops.torchaudio.rnnt_loss + return test_item + except RuntimeError: + return unittest.skip("torchaudio C++ extension is not compiled with RNN transducer loss") diff --git a/test/torchaudio_unittest/common_utils/sox_utils.py b/test/torchaudio_unittest/common_utils/sox_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e72b10e0675656aab5b0e929c745d7390e534d38 --- /dev/null +++ b/test/torchaudio_unittest/common_utils/sox_utils.py @@ -0,0 +1,106 @@ +import sys +import subprocess +import warnings + + +def get_encoding(dtype): + encodings = { + 'float32': 'floating-point', + 'int32': 'signed-integer', + 
'int16': 'signed-integer', + 'uint8': 'unsigned-integer', + } + return encodings[dtype] + + +def get_bit_depth(dtype): + bit_depths = { + 'float32': 32, + 'int32': 32, + 'int16': 16, + 'uint8': 8, + } + return bit_depths[dtype] + + +def gen_audio_file( + path, sample_rate, num_channels, + *, encoding=None, bit_depth=None, compression=None, attenuation=None, duration=1, comment_file=None, +): + """Generate synthetic audio file with `sox` command.""" + if path.endswith('.wav'): + warnings.warn('Use get_wav_data and save_wav to generate wav file for accurate result.') + command = [ + 'sox', + '-V3', # verbose + '--no-dither', # disable automatic dithering + '-R', + # -R is supposed to be repeatable, though the implementation looks suspicious + # and not setting the seed to a fixed value. + # https://fossies.org/dox/sox-14.4.2/sox_8c_source.html + # search "sox_globals.repeatable" + ] + if bit_depth is not None: + command += ['--bits', str(bit_depth)] + command += [ + '--rate', str(sample_rate), + '--null', # no input + '--channels', str(num_channels), + ] + if compression is not None: + command += ['--compression', str(compression)] + if bit_depth is not None: + command += ['--bits', str(bit_depth)] + if encoding is not None: + command += ['--encoding', str(encoding)] + if comment_file is not None: + command += ['--comment-file', str(comment_file)] + command += [ + str(path), + 'synth', str(duration), # synthesizes for the given duration [sec] + 'sawtooth', '1', + # saw tooth covers the both ends of value range, which is a good property for test. + # similar to linspace(-1., 1.) + # this introduces bigger boundary effect than sine when converted to mp3 + ] + if attenuation is not None: + command += ['vol', f'-{attenuation}dB'] + print(' '.join(command), file=sys.stderr) + subprocess.run(command, check=True) + + +def convert_audio_file( + src_path, dst_path, + *, encoding=None, bit_depth=None, compression=None): + """Convert audio file with `sox` command.""" + command = ['sox', '-V3', '--no-dither', '-R', str(src_path)] + if encoding is not None: + command += ['--encoding', str(encoding)] + if bit_depth is not None: + command += ['--bits', str(bit_depth)] + if compression is not None: + command += ['--compression', str(compression)] + command += [dst_path] + print(' '.join(command), file=sys.stderr) + subprocess.run(command, check=True) + + +def _flattern(effects): + if not effects: + return effects + if isinstance(effects[0], str): + return effects + return [item for sublist in effects for item in sublist] + + +def run_sox_effect(input_file, output_file, effect, *, output_sample_rate=None, output_bitdepth=None): + """Run sox effects""" + effect = _flattern(effect) + command = ['sox', '-V', '--no-dither', input_file] + if output_bitdepth: + command += ['--bits', str(output_bitdepth)] + command += [output_file] + effect + if output_sample_rate: + command += ['rate', str(output_sample_rate)] + print(' '.join(command)) + subprocess.run(command, check=True) diff --git a/test/torchaudio_unittest/common_utils/wav_utils.py b/test/torchaudio_unittest/common_utils/wav_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b4b7944805e8ccde7fbb0071d23b5943ee223dab --- /dev/null +++ b/test/torchaudio_unittest/common_utils/wav_utils.py @@ -0,0 +1,92 @@ +from typing import Optional + +import torch +import scipy.io.wavfile + + +def normalize_wav(tensor: torch.Tensor) -> torch.Tensor: + if tensor.dtype == torch.float32: + pass + elif tensor.dtype == torch.int32: + tensor = 
tensor.to(torch.float32) + tensor[tensor > 0] /= 2147483647. + tensor[tensor < 0] /= 2147483648. + elif tensor.dtype == torch.int16: + tensor = tensor.to(torch.float32) + tensor[tensor > 0] /= 32767. + tensor[tensor < 0] /= 32768. + elif tensor.dtype == torch.uint8: + tensor = tensor.to(torch.float32) - 128 + tensor[tensor > 0] /= 127. + tensor[tensor < 0] /= 128. + return tensor + + +def get_wav_data( + dtype: str, + num_channels: int, + *, + num_frames: Optional[int] = None, + normalize: bool = True, + channels_first: bool = True, +): + """Generate linear signal of the given dtype and num_channels + + Data range is + [-1.0, 1.0] for float32, + [-2147483648, 2147483647] for int32 + [-32768, 32767] for int16 + [0, 255] for uint8 + + num_frames allow to change the linear interpolation parameter. + Default values are 256 for uint8, else 1 << 16. + 1 << 16 as default is so that int16 value range is completely covered. + """ + dtype_ = getattr(torch, dtype) + + if num_frames is None: + if dtype == 'uint8': + num_frames = 256 + else: + num_frames = 1 << 16 + + if dtype == 'uint8': + base = torch.linspace(0, 255, num_frames, dtype=dtype_) + elif dtype == 'int8': + base = torch.linspace(-128, 127, num_frames, dtype=dtype_) + elif dtype == 'float32': + base = torch.linspace(-1., 1., num_frames, dtype=dtype_) + elif dtype == 'float64': + base = torch.linspace(-1., 1., num_frames, dtype=dtype_) + elif dtype == 'int32': + base = torch.linspace(-2147483648, 2147483647, num_frames, dtype=dtype_) + elif dtype == 'int16': + base = torch.linspace(-32768, 32767, num_frames, dtype=dtype_) + else: + raise NotImplementedError(f'Unsupported dtype {dtype}') + data = base.repeat([num_channels, 1]) + if not channels_first: + data = data.transpose(1, 0) + if normalize: + data = normalize_wav(data) + return data + + +def load_wav(path: str, normalize=True, channels_first=True) -> torch.Tensor: + """Load wav file without torchaudio""" + sample_rate, data = scipy.io.wavfile.read(path) + data = torch.from_numpy(data.copy()) + if data.ndim == 1: + data = data.unsqueeze(1) + if normalize: + data = normalize_wav(data) + if channels_first: + data = data.transpose(1, 0) + return data, sample_rate + + +def save_wav(path, data, sample_rate, channels_first=True): + """Save wav file without torchaudio""" + if channels_first: + data = data.transpose(1, 0) + scipy.io.wavfile.write(path, sample_rate, data.numpy()) diff --git a/test/torchaudio_unittest/compliance_kaldi_test.py b/test/torchaudio_unittest/compliance_kaldi_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f4e79d4228fbe17f7b2d8c1e0c32f2b2a33c62db --- /dev/null +++ b/test/torchaudio_unittest/compliance_kaldi_test.py @@ -0,0 +1,76 @@ +import torch +import torchaudio.compliance.kaldi as kaldi + +from torchaudio_unittest import common_utils + + +def extract_window(window, wave, f, frame_length, frame_shift, snip_edges): + # just a copy of ExtractWindow from feature-window.cc in python + def first_sample_of_frame(frame, window_size, window_shift, snip_edges): + if snip_edges: + return frame * window_shift + else: + midpoint_of_frame = frame * window_shift + window_shift // 2 + beginning_of_frame = midpoint_of_frame - window_size // 2 + return beginning_of_frame + + sample_offset = 0 + num_samples = sample_offset + wave.size(0) + start_sample = first_sample_of_frame(f, frame_length, frame_shift, snip_edges) + end_sample = start_sample + frame_length + + if snip_edges: + assert(start_sample >= sample_offset and end_sample <= num_samples) + else: + 
assert(sample_offset == 0 or start_sample >= sample_offset) + + wave_start = start_sample - sample_offset + wave_end = wave_start + frame_length + if wave_start >= 0 and wave_end <= wave.size(0): + window[f, :] = wave[wave_start:(wave_start + frame_length)] + else: + wave_dim = wave.size(0) + for s in range(frame_length): + s_in_wave = s + wave_start + while s_in_wave < 0 or s_in_wave >= wave_dim: + if s_in_wave < 0: + s_in_wave = - s_in_wave - 1 + else: + s_in_wave = 2 * wave_dim - 1 - s_in_wave + window[f, s] = wave[s_in_wave] + + +class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase): + + def _test_get_strided_helper(self, num_samples, window_size, window_shift, snip_edges): + waveform = torch.arange(num_samples).float() + output = kaldi._get_strided(waveform, window_size, window_shift, snip_edges) + + # from NumFrames in feature-window.cc + n = window_size + if snip_edges: + m = 0 if num_samples < window_size else 1 + (num_samples - window_size) // window_shift + else: + m = (num_samples + (window_shift // 2)) // window_shift + + self.assertTrue(output.dim() == 2) + self.assertTrue(output.shape[0] == m and output.shape[1] == n) + + window = torch.empty((m, window_size)) + + for r in range(m): + extract_window(window, waveform, r, window_size, window_shift, snip_edges) + self.assertEqual(window, output) + + def test_get_strided(self): + # generate any combination where 0 < window_size <= num_samples and + # 0 < window_shift. + for num_samples in range(1, 20): + for window_size in range(1, num_samples + 1): + for window_shift in range(1, 2 * num_samples + 1): + for snip_edges in range(0, 2): + self._test_get_strided_helper(num_samples, window_size, window_shift, snip_edges) + + def test_mfcc_empty(self): + # Passing in an empty tensor should result in an error + self.assertRaises(AssertionError, kaldi.mfcc, torch.empty(0)) diff --git a/test/torchaudio_unittest/datasets/__init__.py b/test/torchaudio_unittest/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/torchaudio_unittest/datasets/cmuarctic_test.py b/test/torchaudio_unittest/datasets/cmuarctic_test.py new file mode 100644 index 0000000000000000000000000000000000000000..10ff7668061a394745ae238398b8e79e99690ee8 --- /dev/null +++ b/test/torchaudio_unittest/datasets/cmuarctic_test.py @@ -0,0 +1,84 @@ +import os +from pathlib import Path + +from torchaudio.datasets import cmuarctic + +from torchaudio_unittest.common_utils import ( + TempDirMixin, + TorchaudioTestCase, + get_whitenoise, + save_wav, + normalize_wav, +) + + +def get_mock_dataset(root_dir): + """ + root_dir: directory to the mocked dataset + """ + mocked_data = [] + sample_rate = 16000 + transcript = "This is a test transcript." 
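+    # CMU ARCTIC stores prompts in etc/txt.done.data, one utterance per line in the form
+    # ( arctic_a0001 "transcript" ); the loop below writes the mock prompts in that layout.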
+ + base_dir = os.path.join(root_dir, "ARCTIC", "cmu_us_aew_arctic") + txt_dir = os.path.join(base_dir, "etc") + os.makedirs(txt_dir, exist_ok=True) + txt_file = os.path.join(txt_dir, "txt.done.data") + audio_dir = os.path.join(base_dir, "wav") + os.makedirs(audio_dir, exist_ok=True) + + seed = 42 + with open(txt_file, "w") as txt: + for c in ["a", "b"]: + for i in range(5): + utterance_id = f"arctic_{c}{i:04d}" + path = os.path.join(audio_dir, f"{utterance_id}.wav") + data = get_whitenoise( + sample_rate=sample_rate, + duration=3, + n_channels=1, + dtype="int16", + seed=seed, + ) + save_wav(path, data, sample_rate) + sample = ( + normalize_wav(data), + sample_rate, + transcript, + utterance_id.split("_")[1], + ) + mocked_data.append(sample) + txt.write(f'( {utterance_id} "{transcript}" )\n') + seed += 1 + return mocked_data + + +class TestCMUARCTIC(TempDirMixin, TorchaudioTestCase): + backend = "default" + + root_dir = None + samples = [] + + @classmethod + def setUpClass(cls): + cls.root_dir = cls.get_base_temp_dir() + cls.samples = get_mock_dataset(cls.root_dir) + + def _test_cmuarctic(self, dataset): + n_ite = 0 + for i, (waveform, sample_rate, transcript, utterance_id) in enumerate(dataset): + expected_sample = self.samples[i] + assert sample_rate == expected_sample[1] + assert transcript == expected_sample[2] + assert utterance_id == expected_sample[3] + self.assertEqual(expected_sample[0], waveform, atol=5e-5, rtol=1e-8) + n_ite += 1 + assert n_ite == len(self.samples) + + def test_cmuarctic_str(self): + dataset = cmuarctic.CMUARCTIC(self.root_dir) + self._test_cmuarctic(dataset) + + def test_cmuarctic_path(self): + dataset = cmuarctic.CMUARCTIC(Path(self.root_dir)) + self._test_cmuarctic(dataset) diff --git a/test/torchaudio_unittest/datasets/cmudict_test.py b/test/torchaudio_unittest/datasets/cmudict_test.py new file mode 100644 index 0000000000000000000000000000000000000000..da645346ba7c646531237c63f47d5e9fdccbd208 --- /dev/null +++ b/test/torchaudio_unittest/datasets/cmudict_test.py @@ -0,0 +1,218 @@ +import os +from pathlib import Path + +from torchaudio.datasets import CMUDict + +from torchaudio_unittest.common_utils import ( + TempDirMixin, + TorchaudioTestCase, +) + + +def get_mock_dataset(root_dir, return_punc=False): + """ + root_dir: directory to the mocked dataset + """ + header = [ + ";;; # CMUdict -- Major Version: 0.07", + ";;; ", + ";;; # $HeadURL$", + ] + + puncs = [ + "!EXCLAMATION-POINT EH2 K S K L AH0 M EY1 SH AH0 N P OY2 N T", + "\"CLOSE-QUOTE K L OW1 Z K W OW1 T", + "#HASH-MARK HH AE1 M AA2 R K", + "%PERCENT P ER0 S EH1 N T", + "&ERSAND AE1 M P ER0 S AE2 N D", + "'END-INNER-QUOTE EH1 N D IH1 N ER0 K W OW1 T", + "(BEGIN-PARENS B IH0 G IH1 N P ER0 EH1 N Z", + ")CLOSE-PAREN K L OW1 Z P ER0 EH1 N", + "+PLUS P L UH1 S", + ",COMMA K AA1 M AH0", + "--DASH D AE1 SH", + "!EXCLAMATION-POINT EH2 K S K L AH0 M EY1 SH AH0 N P OY2 N T", + "/SLASH S L AE1 SH", + ":COLON K OW1 L AH0 N", + ";SEMI-COLON S EH1 M IY0 K OW1 L AH0 N", + "?QUESTION-MARK K W EH1 S CH AH0 N M AA1 R K", + "{BRACE B R EY1 S", + "}CLOSE-BRACE K L OW1 Z B R EY1 S", + "...ELLIPSIS IH2 L IH1 P S IH0 S", + ] + + punc_outputs = [ + "!", + "\"", + "#", + "%", + "&", + "'", + "(", + ")", + "+", + ",", + "--", + "!", + "/", + ":", + ";", + "?", + "{", + "}", + "...", + ] + + words = [ + "3-D TH R IY1 D IY2", + "'BOUT B AW1 T", + "'CAUSE K AH0 Z", + "'TWAS T W AH1 Z", + "A AH0", + "B B IY1", + "C S IY1", + "D D IY1", + "E IY1", + "F EH1 F", + "G JH IY1", + "H EY1 CH", + "I AY1", + "J JH EY1", + "K K EY1", + "L EH1 
L", + "M EH1 M", + "N EH1 N", + "O OW1", + "P P IY1", + "Q K Y UW1", + "R AA1 R", + "S EH1 S", + "T T IY1", + "U Y UW1", + "V V IY1", + "X EH1 K S", + "Y W AY1", + "Z Z IY1", + ] + + mocked_symbols = [ + "AA1", + "AA2", + "AE1", + "AE2", + "AH0", + "AH1", + "AY1", + "B", + "CH", + "D", + "EH1", + "EH2", + "ER0", + "EY1", + "F", + "G", + "HH", + "IH0", + "IH1", + "IY0", + "IY1", + "IY2", + "JH", + "K", + "L", + "M", + "N", + "OW1", + "OY2", + "P", + "R", + "S", + "SH", + "T", + "TH", + "UH1", + "UW0", + "UW1", + "V", + "W", + "Y", + "Z", + ] + + dict_file = os.path.join(root_dir, "cmudict-0.7b") + symbol_file = os.path.join(root_dir, "cmudict-0.7b.symbols") + + with open(dict_file, "w") as fileobj: + for section in [header, puncs, words]: + for line in section: + fileobj.write(line) + fileobj.write("\n") + + with open(symbol_file, "w") as txt: + txt.write("\n".join(mocked_symbols)) + + mocked_data = [] + + if return_punc: + for i, ent in enumerate(puncs): + _, phones = ent.split(" ") + mocked_data.append((punc_outputs[i], phones.split(" "))) + + for ent in words: + word, phones = ent.split(" ") + mocked_data.append((word, phones.split(" "))) + + return mocked_data + + +class TestCMUDict(TempDirMixin, TorchaudioTestCase): + root_dir = None + root_punc_dir = None + samples = [] + punc_samples = [] + + @classmethod + def setUpClass(cls): + cls.root_dir = os.path.join(cls.get_base_temp_dir(), "normal") + os.mkdir(cls.root_dir) + cls.samples = get_mock_dataset(cls.root_dir) + cls.root_punc_dir = os.path.join(cls.get_base_temp_dir(), "punc") + os.mkdir(cls.root_punc_dir) + cls.punc_samples = get_mock_dataset(cls.root_punc_dir, return_punc=True) + + def _test_cmudict(self, dataset): + """Test if the dataset is reading the mocked data correctly.""" + n_item = 0 + for i, (word, phones) in enumerate(dataset): + expected_word, expected_phones = self.samples[i] + assert word == expected_word + assert phones == expected_phones + n_item += 1 + assert n_item == len(self.samples) + + def _test_punc_cmudict(self, dataset): + """Test if the dataset is reading the mocked data with punctuations correctly.""" + n_item = 0 + for i, (word, phones) in enumerate(dataset): + expected_word, expected_phones = self.punc_samples[i] + assert word == expected_word + assert phones == expected_phones + n_item += 1 + assert n_item == len(self.punc_samples) + + def test_cmuarctic_path_with_punctuation(self): + dataset = CMUDict(Path(self.root_punc_dir), exclude_punctuations=False) + self._test_punc_cmudict(dataset) + + def test_cmuarctic_str_with_punctuation(self): + dataset = CMUDict(self.root_punc_dir, exclude_punctuations=False) + self._test_punc_cmudict(dataset) + + def test_cmuarctic_path(self): + dataset = CMUDict(Path(self.root_punc_dir), exclude_punctuations=True) + self._test_cmudict(dataset) + + def test_cmuarctic_str(self): + dataset = CMUDict(self.root_punc_dir, exclude_punctuations=True) + self._test_cmudict(dataset) diff --git a/test/torchaudio_unittest/datasets/commonvoice_test.py b/test/torchaudio_unittest/datasets/commonvoice_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4d7c269f2a3713c6842fe92b794c69759dc0ae61 --- /dev/null +++ b/test/torchaudio_unittest/datasets/commonvoice_test.py @@ -0,0 +1,148 @@ +import csv +import os +from pathlib import Path +from typing import Tuple, Dict + +from torch import Tensor +from torchaudio_unittest.common_utils import ( + TempDirMixin, + TorchaudioTestCase, + get_whitenoise, + save_wav, + normalize_wav, +) + +from torchaudio.datasets import 
COMMONVOICE + +_ORIGINAL_EXT_AUDIO = COMMONVOICE._ext_audio +_SAMPLE_RATE = 48000 +_HEADERS = [u"client_ids", u"path", u"sentence", u"up_votes", u"down_votes", u"age", u"gender", u"accent"] +_EN_TRAIN_CSV_CONTENTS = [ + ["9d16c5d980247861130e0480e2719f448be73d86a496c36d01a477cbdecd8cfd1399403d7a77bf458d211a70711b2da0845c", + "common_voice_en_18885784.wav", + "He was accorded a State funeral, and was buried in Drayton and Toowoomba Cemetery.", "2", "0", "", "", + ""], + ["c82eb9291328620f06025a1f8112b909099e447e485e99236cb87df008650250e79fea5ca772061fb6a370830847b9c44d20", + "common_voice_en_556542.wav", "Once more into the breach", "2", "0", "thirties", "male", "us"], + ["f74d880c5ad4c5917f314a604d3fc4805159d255796fb9f8defca35333ecc002bdf53dc463503c12674ea840b21b4a507b7c", + "common_voice_en_18607573.wav", + "Caddy, show Miss Clare and Miss Summerson their rooms.", "2", "0", "twenties", "male", "canada"], +] + +_FR_TRAIN_CSV_CONTENTS = [ + [ + "a2e8e1e1cc74d08c92a53d7b9ff84e077eb90410edd85b8882f16fd037cecfcb6a19413c6c63ce6458cfea9579878fa91cef" + "18343441c601cae0597a4b0d3144", + "89e67e7682b36786a0b4b4022c4d42090c86edd96c78c12d30088e62522b8fe466ea4912e6a1055dfb91b296a0743e0a2bbe" + "16cebac98ee5349e3e8262cb9329", + "Or sur ce point nous n’avons aucune réponse de votre part.", "2", "0", "twenties", "male", "france"], + [ + "a2e8e1e1cc74d08c92a53d7b9ff84e077eb90410edd85b8882f16fd037cecfcb6a19413c6c63ce6458cfea9579878fa91cef18" + "343441c601cae0597a4b0d3144", + "87d71819a26179e93acfee149d0b21b7bf5e926e367d80b2b3792d45f46e04853a514945783ff764c1fc237b4eb0ee2b0a7a7" + "cbd395acbdfcfa9d76a6e199bbd", + "Monsieur de La Verpillière, laissez parler le ministre", "2", "0", "twenties", "male", "france"], + +] + + +def get_mock_dataset(root_dir, train_csv_contents, ext_audio) -> Tuple[Tensor, int, Dict[str, str]]: + """ + prepares mocked dataset + """ + mocked_data = [] + # Note: extension is changed to wav for the sake of test + # Note: the first content is missing values for `age`, `gender` and `accent` as in the original data. 
+ # Tsv file name difference does not mean different subset, testing as a whole dataset here + tsv_filename = os.path.join(root_dir, "train.tsv") + audio_base_path = os.path.join(root_dir, "clips") + os.makedirs(audio_base_path, exist_ok=True) + with open(tsv_filename, "w", newline='') as tsv: + writer = csv.writer(tsv, delimiter='\t') + writer.writerow(_HEADERS) + for i, content in enumerate(train_csv_contents): + content[2] = str(content[2].encode("utf-8")) + writer.writerow(content) + if not content[1].endswith(ext_audio): + audio_path = os.path.join(audio_base_path, content[1] + ext_audio) + else: + audio_path = os.path.join(audio_base_path, content[1]) + + data = get_whitenoise(sample_rate=_SAMPLE_RATE, duration=1, n_channels=1, seed=i, dtype='float32') + save_wav(audio_path, data, _SAMPLE_RATE) + # Append data entry + mocked_data.append((normalize_wav(data), _SAMPLE_RATE, dict(zip(_HEADERS, content)))) + return mocked_data + + +def get_mock_dataset_en(root_dir, ext_audio) -> Tuple[Tensor, int, Dict[str, str]]: + """ + prepares english mocked dataset + """ + return get_mock_dataset(root_dir, _EN_TRAIN_CSV_CONTENTS, ext_audio) + + +def get_mock_dataset_fr(root_dir, ext_audio) -> Tuple[Tensor, int, Dict[str, str]]: + """ + prepares french mocked dataset + """ + return get_mock_dataset(root_dir, _FR_TRAIN_CSV_CONTENTS, ext_audio) + + +class BaseTestCommonVoice(TempDirMixin): + root_dir = None + data = None + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.root_dir = cls.get_base_temp_dir() + COMMONVOICE._ext_audio = ".wav" + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + COMMONVOICE._ext_audio = _ORIGINAL_EXT_AUDIO + + def _test_commonvoice(self, dataset): + n_ite = 0 + for i, (waveform, sample_rate, dictionary) in enumerate(dataset): + expected_dictionary = self.data[i][2] + expected_data = self.data[i][0] + self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8) + assert sample_rate == _SAMPLE_RATE + assert dictionary == expected_dictionary + n_ite += 1 + assert n_ite == len(self.data) + + +class TestCommonVoiceEN(BaseTestCommonVoice, TorchaudioTestCase): + backend = 'default' + root_dir = None + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.data = get_mock_dataset_en(cls.root_dir, COMMONVOICE._ext_audio) + + def test_commonvoice_str(self): + dataset = COMMONVOICE(self.root_dir) + self._test_commonvoice(dataset) + + def test_commonvoice_path(self): + dataset = COMMONVOICE(Path(self.root_dir)) + self._test_commonvoice(dataset) + + +class TestCommonVoiceFR(BaseTestCommonVoice, TorchaudioTestCase): + backend = 'default' + root_dir = None + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.data = get_mock_dataset_fr(cls.root_dir, COMMONVOICE._ext_audio) + + def test_commonvoice_str(self): + dataset = COMMONVOICE(self.root_dir) + self._test_commonvoice(dataset) diff --git a/test/torchaudio_unittest/datasets/datasets_test.py b/test/torchaudio_unittest/datasets/datasets_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fda0804ea66a5522e01985f80f5d2849696569ed --- /dev/null +++ b/test/torchaudio_unittest/datasets/datasets_test.py @@ -0,0 +1,15 @@ +from torchaudio.datasets.vctk import VCTK + +from torchaudio_unittest.common_utils import ( + TorchaudioTestCase, + get_asset_path, +) + + +class TestDatasets(TorchaudioTestCase): + backend = 'default' + path = get_asset_path() + + def test_vctk(self): + data = VCTK(self.path) + data[0] diff --git 
a/test/torchaudio_unittest/datasets/gtzan_test.py b/test/torchaudio_unittest/datasets/gtzan_test.py new file mode 100644 index 0000000000000000000000000000000000000000..838292f55d4da71bcddc2e0ddbc9e77c89264f11 --- /dev/null +++ b/test/torchaudio_unittest/datasets/gtzan_test.py @@ -0,0 +1,127 @@ +import os +from pathlib import Path + +from torchaudio.datasets import gtzan + +from torchaudio_unittest.common_utils import ( + TempDirMixin, + TorchaudioTestCase, + get_whitenoise, + save_wav, + normalize_wav, +) + + +def get_mock_dataset(root_dir): + """ + root_dir: directory to the mocked dataset + """ + mocked_samples = [] + mocked_training = [] + mocked_validation = [] + mocked_testing = [] + sample_rate = 22050 + + seed = 0 + for genre in gtzan.gtzan_genres: + base_dir = os.path.join(root_dir, 'genres', genre) + os.makedirs(base_dir, exist_ok=True) + for i in range(100): + filename = f'{genre}.{i:05d}' + path = os.path.join(base_dir, f'{filename}.wav') + data = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1, dtype='int16', seed=seed) + save_wav(path, data, sample_rate) + sample = (normalize_wav(data), sample_rate, genre) + mocked_samples.append(sample) + if filename in gtzan.filtered_test: + mocked_testing.append(sample) + if filename in gtzan.filtered_train: + mocked_training.append(sample) + if filename in gtzan.filtered_valid: + mocked_validation.append(sample) + seed += 1 + return (mocked_samples, mocked_training, mocked_validation, mocked_testing) + + +class TestGTZAN(TempDirMixin, TorchaudioTestCase): + backend = 'default' + + root_dir = None + samples = [] + training = [] + validation = [] + testing = [] + + @classmethod + def setUpClass(cls): + cls.root_dir = cls.get_base_temp_dir() + mocked_data = get_mock_dataset(cls.root_dir) + cls.samples = mocked_data[0] + cls.training = mocked_data[1] + cls.validation = mocked_data[2] + cls.testing = mocked_data[3] + + def test_no_subset(self): + dataset = gtzan.GTZAN(self.root_dir) + + n_ite = 0 + for i, (waveform, sample_rate, label) in enumerate(dataset): + self.assertEqual(waveform, self.samples[i][0], atol=5e-5, rtol=1e-8) + assert sample_rate == self.samples[i][1] + assert label == self.samples[i][2] + n_ite += 1 + assert n_ite == len(self.samples) + + def _test_training(self, dataset): + n_ite = 0 + for i, (waveform, sample_rate, label) in enumerate(dataset): + self.assertEqual(waveform, self.training[i][0], atol=5e-5, rtol=1e-8) + assert sample_rate == self.training[i][1] + assert label == self.training[i][2] + n_ite += 1 + assert n_ite == len(self.training) + + def _test_validation(self, dataset): + n_ite = 0 + for i, (waveform, sample_rate, label) in enumerate(dataset): + self.assertEqual(waveform, self.validation[i][0], atol=5e-5, rtol=1e-8) + assert sample_rate == self.validation[i][1] + assert label == self.validation[i][2] + n_ite += 1 + assert n_ite == len(self.validation) + + def _test_testing(self, dataset): + n_ite = 0 + for i, (waveform, sample_rate, label) in enumerate(dataset): + self.assertEqual(waveform, self.testing[i][0], atol=5e-5, rtol=1e-8) + assert sample_rate == self.testing[i][1] + assert label == self.testing[i][2] + n_ite += 1 + assert n_ite == len(self.testing) + + def test_training_str(self): + train_dataset = gtzan.GTZAN(self.root_dir, subset='training') + self._test_training(train_dataset) + + def test_validation_str(self): + val_dataset = gtzan.GTZAN(self.root_dir, subset='validation') + self._test_validation(val_dataset) + + def test_testing_str(self): + test_dataset = 
gtzan.GTZAN(self.root_dir, subset='testing') + self._test_testing(test_dataset) + + def test_training_path(self): + root_dir = Path(self.root_dir) + train_dataset = gtzan.GTZAN(root_dir, subset='training') + self._test_training(train_dataset) + + def test_validation_path(self): + root_dir = Path(self.root_dir) + val_dataset = gtzan.GTZAN(root_dir, subset='validation') + self._test_validation(val_dataset) + + def test_testing_path(self): + root_dir = Path(self.root_dir) + test_dataset = gtzan.GTZAN(root_dir, subset='testing') + self._test_testing(test_dataset) diff --git a/test/torchaudio_unittest/datasets/librispeech_test.py b/test/torchaudio_unittest/datasets/librispeech_test.py new file mode 100644 index 0000000000000000000000000000000000000000..44e98c1f4750ff663a7be9f38550e554e5244af6 --- /dev/null +++ b/test/torchaudio_unittest/datasets/librispeech_test.py @@ -0,0 +1,128 @@ +import os +from pathlib import Path + +from torchaudio_unittest.common_utils import ( + TempDirMixin, + TorchaudioTestCase, + get_whitenoise, + save_wav, + normalize_wav, +) + +from torchaudio.datasets import librispeech + +# Used to generate a unique transcript for each dummy audio file +_NUMBERS = [ + 'ZERO', + 'ONE', + 'TWO', + 'THREE', + 'FOUR', + 'FIVE', + 'SIX', + 'SEVEN', + 'EIGHT', + 'NINE' +] + + +def get_mock_dataset(root_dir): + """ + root_dir: directory to the mocked dataset + """ + mocked_data = [] + dataset_dir = os.path.join( + root_dir, librispeech.FOLDER_IN_ARCHIVE, librispeech.URL + ) + os.makedirs(dataset_dir, exist_ok=True) + sample_rate = 16000 # 16kHz + seed = 0 + + for speaker_id in range(5): + speaker_path = os.path.join(dataset_dir, str(speaker_id)) + os.makedirs(speaker_path, exist_ok=True) + + for chapter_id in range(3): + chapter_path = os.path.join(speaker_path, str(chapter_id)) + os.makedirs(chapter_path, exist_ok=True) + trans_content = [] + + for utterance_id in range(10): + filename = f'{speaker_id}-{chapter_id}-{utterance_id:04d}.wav' + path = os.path.join(chapter_path, filename) + + transcript = ' '.join( + [_NUMBERS[x] for x in [speaker_id, chapter_id, utterance_id]] + ) + trans_content.append( + f'{speaker_id}-{chapter_id}-{utterance_id:04d} {transcript}' + ) + + data = get_whitenoise( + sample_rate=sample_rate, + duration=0.01, + n_channels=1, + dtype='float32', + seed=seed + ) + save_wav(path, data, sample_rate) + sample = ( + normalize_wav(data), + sample_rate, + transcript, + speaker_id, + chapter_id, + utterance_id + ) + mocked_data.append(sample) + + seed += 1 + + trans_filename = f'{speaker_id}-{chapter_id}.trans.txt' + trans_path = os.path.join(chapter_path, trans_filename) + with open(trans_path, 'w') as f: + f.write('\n'.join(trans_content)) + return mocked_data + + +class TestLibriSpeech(TempDirMixin, TorchaudioTestCase): + backend = 'default' + + root_dir = None + samples = [] + + @classmethod + def setUpClass(cls): + cls.root_dir = cls.get_base_temp_dir() + cls.samples = get_mock_dataset(cls.root_dir) + + @classmethod + def tearDownClass(cls): + # In case of test failure + librispeech.LIBRISPEECH._ext_audio = '.flac' + + def _test_librispeech(self, dataset): + num_samples = 0 + for i, ( + data, sample_rate, transcript, speaker_id, chapter_id, utterance_id + ) in enumerate(dataset): + self.assertEqual(data, self.samples[i][0], atol=5e-5, rtol=1e-8) + assert sample_rate == self.samples[i][1] + assert transcript == self.samples[i][2] + assert speaker_id == self.samples[i][3] + assert chapter_id == self.samples[i][4] + assert utterance_id == self.samples[i][5] + 
num_samples += 1 + + assert num_samples == len(self.samples) + librispeech.LIBRISPEECH._ext_audio = '.flac' + + def test_librispeech_str(self): + librispeech.LIBRISPEECH._ext_audio = '.wav' + dataset = librispeech.LIBRISPEECH(self.root_dir) + self._test_librispeech(dataset) + + def test_librispeech_path(self): + librispeech.LIBRISPEECH._ext_audio = '.wav' + dataset = librispeech.LIBRISPEECH(Path(self.root_dir)) + self._test_librispeech(dataset) diff --git a/test/torchaudio_unittest/datasets/libritts_test.py b/test/torchaudio_unittest/datasets/libritts_test.py new file mode 100644 index 0000000000000000000000000000000000000000..32e6dbcf738708c186b268f85f7aa8d8bf71da2b --- /dev/null +++ b/test/torchaudio_unittest/datasets/libritts_test.py @@ -0,0 +1,89 @@ +import os +from pathlib import Path + +from torchaudio_unittest.common_utils import ( + TempDirMixin, + TorchaudioTestCase, + get_whitenoise, + save_wav, + normalize_wav, +) + +from torchaudio.datasets.libritts import LIBRITTS + +_UTTERANCE_IDS = [ + [19, 198, '000000', '000000'], + [26, 495, '000004', '000000'], +] +_ORIGINAL_TEXT = 'this is the original text.' +_NORMALIZED_TEXT = 'this is the normalized text.' + + +def get_mock_dataset(root_dir): + """ + root_dir: directory to the mocked dataset + """ + mocked_data = [] + base_dir = os.path.join(root_dir, 'LibriTTS', 'train-clean-100') + for i, utterance_id in enumerate(_UTTERANCE_IDS): + filename = f'{"_".join(str(u) for u in utterance_id)}.wav' + file_dir = os.path.join(base_dir, str(utterance_id[0]), str(utterance_id[1])) + os.makedirs(file_dir, exist_ok=True) + path = os.path.join(file_dir, filename) + + data = get_whitenoise(sample_rate=24000, duration=2, n_channels=1, dtype='int16', seed=i) + save_wav(path, data, 24000) + mocked_data.append(normalize_wav(data)) + + original_text_filename = f'{"_".join(str(u) for u in utterance_id)}.original.txt' + path_original = os.path.join(file_dir, original_text_filename) + with open(path_original, 'w') as file_: + file_.write(_ORIGINAL_TEXT) + + normalized_text_filename = f'{"_".join(str(u) for u in utterance_id)}.normalized.txt' + path_normalized = os.path.join(file_dir, normalized_text_filename) + with open(path_normalized, 'w') as file_: + file_.write(_NORMALIZED_TEXT) + return mocked_data, _UTTERANCE_IDS, _ORIGINAL_TEXT, _NORMALIZED_TEXT + + +class TestLibriTTS(TempDirMixin, TorchaudioTestCase): + backend = 'default' + + root_dir = None + data = [] + _utterance_ids, _original_text, _normalized_text = [], [], [] + + @classmethod + def setUpClass(cls): + cls.root_dir = cls.get_base_temp_dir() + cls.data, cls._utterance_ids, cls._original_text, cls._normalized_text = get_mock_dataset(cls.root_dir) + + def _test_libritts(self, dataset): + n_ites = 0 + for i, (waveform, + sample_rate, + original_text, + normalized_text, + speaker_id, + chapter_id, + utterance_id) in enumerate(dataset): + expected_ids = self._utterance_ids[i] + expected_data = self.data[i] + self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8) + assert sample_rate == 24000 + assert speaker_id == expected_ids[0] + assert chapter_id == expected_ids[1] + assert original_text == self._original_text + assert normalized_text == self._normalized_text + assert utterance_id == f'{"_".join(str(u) for u in expected_ids[-4:])}' + n_ites += 1 + assert n_ites == len(self._utterance_ids) + + def test_libritts_str(self): + dataset = LIBRITTS(self.root_dir) + self._test_libritts(dataset) + + def test_libritts_path(self): + dataset = LIBRITTS(Path(self.root_dir)) + 
self._test_libritts(dataset)
diff --git a/test/torchaudio_unittest/datasets/ljspeech_test.py b/test/torchaudio_unittest/datasets/ljspeech_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cf834b226cf8011323404d85279a8d538e028c7
--- /dev/null
+++ b/test/torchaudio_unittest/datasets/ljspeech_test.py
@@ -0,0 +1,92 @@
+import csv
+import os
+from pathlib import Path
+
+from torchaudio_unittest.common_utils import (
+    TempDirMixin,
+    TorchaudioTestCase,
+    get_whitenoise,
+    normalize_wav,
+    save_wav,
+)
+
+from torchaudio.datasets import ljspeech
+
+_TRANSCRIPTS = [
+    "Test transcript 1",
+    "Test transcript 2",
+    "Test transcript 3",
+    "In 1465 Sweynheim and Pannartz began printing in the monastery of Subiaco near Rome,"
+]
+
+_NORMALIZED_TRANSCRIPT = [
+    "Test transcript one",
+    "Test transcript two",
+    "Test transcript three",
+    "In fourteen sixty-five Sweynheim and Pannartz began printing in the monastery of Subiaco near Rome,"
+]
+
+
+def get_mock_dataset(root_dir):
+    """
+    root_dir: path to the mocked dataset
+    """
+    mocked_data = []
+    base_dir = os.path.join(root_dir, "LJSpeech-1.1")
+    archive_dir = os.path.join(base_dir, "wavs")
+    os.makedirs(archive_dir, exist_ok=True)
+    metadata_path = os.path.join(base_dir, "metadata.csv")
+    sample_rate = 22050
+
+    with open(metadata_path, mode="w", newline='') as metadata_file:
+        metadata_writer = csv.writer(
+            metadata_file, delimiter="|", quoting=csv.QUOTE_NONE
+        )
+        for i, (transcript, normalized_transcript) in enumerate(
+            zip(_TRANSCRIPTS, _NORMALIZED_TRANSCRIPT)
+        ):
+            fileid = f'LJ001-{i:04d}'
+            metadata_writer.writerow([fileid, transcript, normalized_transcript])
+            filename = fileid + ".wav"
+            path = os.path.join(archive_dir, filename)
+            data = get_whitenoise(
+                sample_rate=sample_rate, duration=1, n_channels=1, dtype="int16", seed=i
+            )
+            save_wav(path, data, sample_rate)
+            mocked_data.append(normalize_wav(data))
+    return mocked_data, _TRANSCRIPTS, _NORMALIZED_TRANSCRIPT
+
+
+class TestLJSpeech(TempDirMixin, TorchaudioTestCase):
+    backend = "default"
+
+    root_dir = None
+    data, _transcripts, _normalized_transcript = [], [], []
+
+    @classmethod
+    def setUpClass(cls):
+        cls.root_dir = cls.get_base_temp_dir()
+        cls.data, cls._transcripts, cls._normalized_transcript = get_mock_dataset(cls.root_dir)
+
+    def _test_ljspeech(self, dataset):
+        n_ite = 0
+        for i, (waveform, sample_rate, transcript, normalized_transcript) in enumerate(
+            dataset
+        ):
+            expected_transcript = self._transcripts[i]
+            expected_normalized_transcript = self._normalized_transcript[i]
+            expected_data = self.data[i]
+            self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8)
+            assert sample_rate == 22050
+            assert transcript == expected_transcript
+            assert normalized_transcript == expected_normalized_transcript
+            n_ite += 1
+        assert n_ite == len(self.data)
+
+    def test_ljspeech_str(self):
+        dataset = ljspeech.LJSPEECH(self.root_dir)
+        self._test_ljspeech(dataset)
+
+    def test_ljspeech_path(self):
+        dataset = ljspeech.LJSPEECH(Path(self.root_dir))
+        self._test_ljspeech(dataset)
diff --git a/test/torchaudio_unittest/datasets/speechcommands_test.py b/test/torchaudio_unittest/datasets/speechcommands_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..19a352ee1677ff241129de62a93b2b3035e2e42f
--- /dev/null
+++ b/test/torchaudio_unittest/datasets/speechcommands_test.py
@@ -0,0 +1,161 @@
+import os
+from pathlib import Path
+
+from torchaudio_unittest.common_utils import (
+    TempDirMixin,
+    TorchaudioTestCase,
+    
get_whitenoise, + normalize_wav, + save_wav, +) + +from torchaudio.datasets import speechcommands + +_LABELS = [ + "bed", + "bird", + "cat", + "dog", + "down", + "eight", + "five", + "follow", + "forward", + "four", + "go", + "happy", + "house", + "learn", + "left", + "marvin", + "nine", + "no", + "off", + "on", + "one", + "right", + "seven", + "sheila", + "six", + "stop", + "three", + "tree", + "two", + "up", + "visual", + "wow", + "yes", + "zero", +] + + +def get_mock_dataset(dataset_dir): + """ + dataset_dir: directory to the mocked dataset + """ + mocked_samples = [] + mocked_train_samples = [] + mocked_valid_samples = [] + mocked_test_samples = [] + os.makedirs(dataset_dir, exist_ok=True) + sample_rate = 16000 # 16kHz sample rate + seed = 0 + valid_file = os.path.join(dataset_dir, "validation_list.txt") + test_file = os.path.join(dataset_dir, "testing_list.txt") + with open(valid_file, "w") as valid, open(test_file, "w") as test: + for label in _LABELS: + path = os.path.join(dataset_dir, label) + os.makedirs(path, exist_ok=True) + for j in range(6): + # generate hash ID for speaker + speaker = "{:08x}".format(j) + + for utterance in range(3): + filename = f"{speaker}{speechcommands.HASH_DIVIDER}{utterance}.wav" + file_path = os.path.join(path, filename) + seed += 1 + data = get_whitenoise( + sample_rate=sample_rate, + duration=0.01, + n_channels=1, + dtype="int16", + seed=seed, + ) + save_wav(file_path, data, sample_rate) + sample = ( + normalize_wav(data), + sample_rate, + label, + speaker, + utterance, + ) + mocked_samples.append(sample) + if j < 2: + mocked_train_samples.append(sample) + elif j < 4: + valid.write(f'{label}/{filename}\n') + mocked_valid_samples.append(sample) + elif j < 6: + test.write(f'{label}/{filename}\n') + mocked_test_samples.append(sample) + return mocked_samples, mocked_train_samples, mocked_valid_samples, mocked_test_samples + + +class TestSpeechCommands(TempDirMixin, TorchaudioTestCase): + backend = "default" + + root_dir = None + samples = [] + train_samples = [] + valid_samples = [] + test_samples = [] + + @classmethod + def setUpClass(cls): + cls.root_dir = cls.get_base_temp_dir() + dataset_dir = os.path.join( + cls.root_dir, speechcommands.FOLDER_IN_ARCHIVE, speechcommands.URL + ) + cls.samples, cls.train_samples, cls.valid_samples, cls.test_samples = get_mock_dataset(dataset_dir) + + def _testSpeechCommands(self, dataset, data_samples): + num_samples = 0 + for i, (data, sample_rate, label, speaker_id, utterance_number) in enumerate( + dataset + ): + self.assertEqual(data, data_samples[i][0], atol=5e-5, rtol=1e-8) + assert sample_rate == data_samples[i][1] + assert label == data_samples[i][2] + assert speaker_id == data_samples[i][3] + assert utterance_number == data_samples[i][4] + num_samples += 1 + + assert num_samples == len(data_samples) + + def testSpeechCommands_str(self): + dataset = speechcommands.SPEECHCOMMANDS(self.root_dir) + self._testSpeechCommands(dataset, self.samples) + + def testSpeechCommands_path(self): + dataset = speechcommands.SPEECHCOMMANDS(Path(self.root_dir)) + self._testSpeechCommands(dataset, self.samples) + + def testSpeechCommandsSubsetTrain(self): + dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="training") + self._testSpeechCommands(dataset, self.train_samples) + + def testSpeechCommandsSubsetValid(self): + dataset = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="validation") + self._testSpeechCommands(dataset, self.valid_samples) + + def testSpeechCommandsSubsetTest(self): + dataset = 
speechcommands.SPEECHCOMMANDS(self.root_dir, subset="testing") + self._testSpeechCommands(dataset, self.test_samples) + + def testSpeechCommandsSum(self): + dataset_all = speechcommands.SPEECHCOMMANDS(self.root_dir) + dataset_train = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="training") + dataset_valid = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="validation") + dataset_test = speechcommands.SPEECHCOMMANDS(self.root_dir, subset="testing") + + assert len(dataset_train) + len(dataset_valid) + len(dataset_test) == len(dataset_all) diff --git a/test/torchaudio_unittest/datasets/tedlium_test.py b/test/torchaudio_unittest/datasets/tedlium_test.py new file mode 100644 index 0000000000000000000000000000000000000000..00c3e1748e230fa6fbe4282509941a4d17bb18a2 --- /dev/null +++ b/test/torchaudio_unittest/datasets/tedlium_test.py @@ -0,0 +1,150 @@ +import os +import platform +from pathlib import Path + +from torchaudio_unittest.common_utils import ( + TempDirMixin, + TorchaudioTestCase, + get_whitenoise, + save_wav, + skipIfNoSox +) + +from torchaudio.datasets import tedlium + +# Used to generate a unique utterance for each dummy audio file +_UTTERANCES = [ + "AaronHuey_2010X 1 AaronHuey_2010X 0.0 2.0 script1\n", + "AaronHuey_2010X 1 AaronHuey_2010X 2.0 4.0 script2\n", + "AaronHuey_2010X 1 AaronHuey_2010X 4.0 6.0 script3\n", + "AaronHuey_2010X 1 AaronHuey_2010X 6.0 8.0 script4\n", + "AaronHuey_2010X 1 AaronHuey_2010X 8.0 10.0 script5\n", +] + +_PHONEME = [ + "a AH", + "a(2) EY", + "aachen AA K AH N", + "aad AE D", + "aaden EY D AH N", + "aadmi AE D M IY", + "aae EY EY", +] + + +def get_mock_dataset(dataset_dir): + """ + dataset_dir: directory of the mocked dataset + """ + mocked_samples = {} + os.makedirs(dataset_dir, exist_ok=True) + sample_rate = 16000 # 16kHz + seed = 0 + + for release in ["release1", "release2", "release3"]: + data = get_whitenoise(sample_rate=sample_rate, duration=10.00, n_channels=1, dtype="float32", seed=seed) + if release in ["release1", "release2"]: + release_dir = os.path.join( + dataset_dir, + tedlium._RELEASE_CONFIGS[release]["folder_in_archive"], + tedlium._RELEASE_CONFIGS[release]["subset"], + ) + else: + release_dir = os.path.join( + dataset_dir, + tedlium._RELEASE_CONFIGS[release]["folder_in_archive"], + tedlium._RELEASE_CONFIGS[release]["data_path"], + ) + os.makedirs(release_dir, exist_ok=True) + os.makedirs(os.path.join(release_dir, "stm"), exist_ok=True) # Subfolder for transcripts + os.makedirs(os.path.join(release_dir, "sph"), exist_ok=True) # Subfolder for audio files + filename = f"{release}.sph" + path = os.path.join(os.path.join(release_dir, "sph"), filename) + save_wav(path, data, sample_rate) + + trans_filename = f"{release}.stm" + trans_path = os.path.join(os.path.join(release_dir, "stm"), trans_filename) + with open(trans_path, "w") as f: + f.write("".join(_UTTERANCES)) + + dict_filename = f"{release}.dic" + dict_path = os.path.join(release_dir, dict_filename) + with open(dict_path, "w") as f: + f.write("\n".join(_PHONEME)) + + # Create a samples list to compare with + mocked_samples[release] = [] + for utterance in _UTTERANCES: + talk_id, _, speaker_id, start_time, end_time, identifier, transcript = utterance.split(" ", 6) + start_time = int(float(start_time)) * sample_rate + end_time = int(float(end_time)) * sample_rate + sample = ( + data[:, start_time:end_time], + sample_rate, + transcript, + talk_id, + speaker_id, + identifier, + ) + mocked_samples[release].append(sample) + seed += 1 + return mocked_samples + + +class 
Tedlium(TempDirMixin): + root_dir = None + samples = {} + + @classmethod + def setUpClass(cls): + cls.root_dir = cls.get_base_temp_dir() + cls.root_dir = dataset_dir = os.path.join(cls.root_dir, "tedlium") + cls.samples = get_mock_dataset(dataset_dir) + + def _test_tedlium(self, dataset, release): + num_samples = 0 + for i, (data, sample_rate, transcript, talk_id, speaker_id, identifier) in enumerate(dataset): + self.assertEqual(data, self.samples[release][i][0], atol=5e-5, rtol=1e-8) + assert sample_rate == self.samples[release][i][1] + assert transcript == self.samples[release][i][2] + assert talk_id == self.samples[release][i][3] + assert speaker_id == self.samples[release][i][4] + assert identifier == self.samples[release][i][5] + num_samples += 1 + + assert num_samples == len(self.samples[release]) + + dataset._dict_path = os.path.join(dataset._path, f"{release}.dic") + phoneme_dict = dataset.phoneme_dict + phoenemes = [f"{key} {' '.join(value)}" for key, value in phoneme_dict.items()] + assert phoenemes == _PHONEME + + def test_tedlium_release1_str(self): + release = "release1" + dataset = tedlium.TEDLIUM(self.root_dir, release=release) + self._test_tedlium(dataset, release) + + def test_tedlium_release1_path(self): + release = "release1" + dataset = tedlium.TEDLIUM(Path(self.root_dir), release=release) + self._test_tedlium(dataset, release) + + def test_tedlium_release2(self): + release = "release2" + dataset = tedlium.TEDLIUM(self.root_dir, release=release) + self._test_tedlium(dataset, release) + + def test_tedlium_release3(self): + release = "release3" + dataset = tedlium.TEDLIUM(self.root_dir, release=release) + self._test_tedlium(dataset, release) + + +class TestTedliumSoundfile(Tedlium, TorchaudioTestCase): + backend = "soundfile" + + +if platform.system() != "Windows": + @skipIfNoSox + class TestTedliumSoxIO(Tedlium, TorchaudioTestCase): + backend = "sox_io" diff --git a/test/torchaudio_unittest/datasets/utils_test.py b/test/torchaudio_unittest/datasets/utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8906991434df34799fb21733c7ca5cb2daba0390 --- /dev/null +++ b/test/torchaudio_unittest/datasets/utils_test.py @@ -0,0 +1,37 @@ +import torch +from torchaudio_unittest.common_utils import ( + TorchaudioTestCase, + TempDirMixin +) + +from torchaudio.datasets import utils as dataset_utils + + +class Dataset(torch.utils.data.Dataset): + def __getitem__(self, n): + sample_rate = 8000 + waveform = n * torch.ones(2, 256) + return waveform, sample_rate + + def __len__(self) -> int: + return 2 + + def __iter__(self): + for i in range(len(self)): + yield self[i] + + +class TestIterator(TorchaudioTestCase, TempDirMixin): + backend = 'default' + + def test_disckcache_iterator(self): + data = dataset_utils.diskcache_iterator(Dataset(), self.get_base_temp_dir()) + # Save + data[0] + # Load + data[0] + + def test_bg_iterator(self): + data = dataset_utils.bg_iterator(Dataset(), 5) + for _ in data: + pass diff --git a/test/torchaudio_unittest/datasets/vctk_test.py b/test/torchaudio_unittest/datasets/vctk_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4171c3c30719fddde9b18dc0c8e9ce1a8d3c6fe2 --- /dev/null +++ b/test/torchaudio_unittest/datasets/vctk_test.py @@ -0,0 +1,107 @@ +import os +from pathlib import Path + +from torchaudio.datasets import vctk + +from torchaudio_unittest.common_utils import ( + TempDirMixin, + TorchaudioTestCase, + get_whitenoise, + save_wav, + normalize_wav, +) + +# Used to generate a unique transcript for 
each dummy audio file +_TRANSCRIPT = [ + 'Please call Stella', + 'Ask her to bring these things', + 'with her from the store', + 'Six spoons of fresh snow peas, five thick slabs of blue cheese, and maybe a snack for her brother Bob', + 'We also need a small plastic snake and a big toy frog for the kids', + 'She can scoop these things into three red bags, and we will go meet her Wednesday at the train station', + 'When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow', + 'The rainbow is a division of white light into many beautiful colors', + 'These take the shape of a long round arch, with its path high above, and its two ends \ + apparently beyond the horizon', + 'There is, according to legend, a boiling pot of gold at one end' +] + + +def get_mock_dataset(root_dir): + """ + root_dir: root directory of the mocked data + """ + mocked_samples = [] + dataset_dir = os.path.join(root_dir, 'VCTK-Corpus-0.92') + os.makedirs(dataset_dir, exist_ok=True) + sample_rate = 48000 + seed = 0 + + for speaker in range(225, 230): + speaker_id = 'p' + str(speaker) + audio_dir = os.path.join(dataset_dir, 'wav48_silence_trimmed', speaker_id) + os.makedirs(audio_dir, exist_ok=True) + + file_dir = os.path.join(dataset_dir, 'txt', speaker_id) + os.makedirs(file_dir, exist_ok=True) + + for utterance_id in range(1, 11): + filename = f'{speaker_id}_{utterance_id:03d}_mic2' + audio_file_path = os.path.join(audio_dir, filename + '.wav') + + data = get_whitenoise( + sample_rate=sample_rate, + duration=0.01, + n_channels=1, + dtype='float32', + seed=seed + ) + save_wav(audio_file_path, data, sample_rate) + + txt_file_path = os.path.join(file_dir, filename[:-5] + '.txt') + transcript = _TRANSCRIPT[utterance_id - 1] + with open(txt_file_path, 'w') as f: + f.write(transcript) + + sample = ( + normalize_wav(data), + sample_rate, + transcript, + speaker_id, + utterance_id + ) + mocked_samples.append(sample) + seed += 1 + return mocked_samples + + +class TestVCTK(TempDirMixin, TorchaudioTestCase): + backend = 'default' + + root_dir = None + samples = [] + + @classmethod + def setUpClass(cls): + cls.root_dir = cls.get_base_temp_dir() + cls.samples = get_mock_dataset(cls.root_dir) + + def _test_vctk(self, dataset): + num_samples = 0 + for i, (data, sample_rate, transcript, speaker_id, utterance_id) in enumerate(dataset): + self.assertEqual(data, self.samples[i][0], atol=5e-5, rtol=1e-8) + assert sample_rate == self.samples[i][1] + assert transcript == self.samples[i][2] + assert speaker_id == self.samples[i][3] + assert int(utterance_id) == self.samples[i][4] + num_samples += 1 + + assert num_samples == len(self.samples) + + def test_vctk_str(self): + dataset = vctk.VCTK_092(self.root_dir, audio_ext=".wav") + self._test_vctk(dataset) + + def test_vctk_path(self): + dataset = vctk.VCTK_092(Path(self.root_dir), audio_ext=".wav") + self._test_vctk(dataset) diff --git a/test/torchaudio_unittest/datasets/yesno_test.py b/test/torchaudio_unittest/datasets/yesno_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2f4144985dad10f786fdb1362132ee4a89346822 --- /dev/null +++ b/test/torchaudio_unittest/datasets/yesno_test.py @@ -0,0 +1,67 @@ +import os +from pathlib import Path + +from torchaudio.datasets import yesno + +from torchaudio_unittest.common_utils import ( + TempDirMixin, + TorchaudioTestCase, + get_whitenoise, + save_wav, + normalize_wav, +) + + +def get_mock_data(root_dir, labels): + """ + root_dir: path + labels: list of labels + """ + mocked_data = [] + base_dir = 
os.path.join(root_dir, 'waves_yesno') + os.makedirs(base_dir, exist_ok=True) + for i, label in enumerate(labels): + filename = f'{"_".join(str(l) for l in label)}.wav' + path = os.path.join(base_dir, filename) + data = get_whitenoise(sample_rate=8000, duration=6, n_channels=1, dtype='int16', seed=i) + save_wav(path, data, 8000) + mocked_data.append(normalize_wav(data)) + return mocked_data + + +class TestYesNo(TempDirMixin, TorchaudioTestCase): + backend = 'default' + + root_dir = None + data = [] + labels = [ + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1], + [0, 1, 0, 1, 0, 1, 1, 0], + [1, 1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1], + ] + + @classmethod + def setUpClass(cls): + cls.root_dir = cls.get_base_temp_dir() + cls.data = get_mock_data(cls.root_dir, cls.labels) + + def _test_yesno(self, dataset): + n_ite = 0 + for i, (waveform, sample_rate, label) in enumerate(dataset): + expected_label = self.labels[i] + expected_data = self.data[i] + self.assertEqual(expected_data, waveform, atol=5e-5, rtol=1e-8) + assert sample_rate == 8000 + assert label == expected_label + n_ite += 1 + assert n_ite == len(self.data) + + def test_yesno_str(self): + dataset = yesno.YESNO(self.root_dir) + self._test_yesno(dataset) + + def test_yesno_path(self): + dataset = yesno.YESNO(Path(self.root_dir)) + self._test_yesno(dataset) diff --git a/test/torchaudio_unittest/example/__init__.py b/test/torchaudio_unittest/example/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e4fa819890dccebb490576142563f729179adc6 --- /dev/null +++ b/test/torchaudio_unittest/example/__init__.py @@ -0,0 +1,8 @@ +import os +import sys + + +sys.path.append( + os.path.join( + os.path.dirname(__file__), + '..', '..', '..', 'examples')) diff --git a/test/torchaudio_unittest/example/souce_sepration/__init__.py b/test/torchaudio_unittest/example/souce_sepration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/torchaudio_unittest/example/souce_sepration/metrics_test.py b/test/torchaudio_unittest/example/souce_sepration/metrics_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b7793b07ba2ad8188bd7c60f1db77f92a12f1956 --- /dev/null +++ b/test/torchaudio_unittest/example/souce_sepration/metrics_test.py @@ -0,0 +1,39 @@ +from itertools import product + +import torch +from torch.testing._internal.common_utils import TestCase +from parameterized import parameterized + +from . 
import sdr_reference +from source_separation.utils import metrics + + +class TestSDR(TestCase): + @parameterized.expand([(1, ), (2, ), (32, )]) + def test_sdr(self, batch_size): + """sdr produces the same result as the reference implementation""" + num_frames = 256 + + estimation = torch.rand(batch_size, num_frames) + origin = torch.rand(batch_size, num_frames) + + sdr_ref = sdr_reference.calc_sdr_torch(estimation, origin) + sdr = metrics.sdr(estimation.unsqueeze(1), origin.unsqueeze(1)).squeeze(1) + + self.assertEqual(sdr, sdr_ref) + + @parameterized.expand(list(product([1, 2, 32], [2, 3, 4, 5]))) + def test_sdr_pit(self, batch_size, num_sources): + """sdr_pit produces the same result as the reference implementation""" + num_frames = 256 + + estimation = torch.randn(batch_size, num_sources, num_frames) + origin = torch.randn(batch_size, num_sources, num_frames) + + estimation -= estimation.mean(axis=2, keepdim=True) + origin -= origin.mean(axis=2, keepdim=True) + + batch_sdr_ref = sdr_reference.batch_SDR_torch(estimation, origin) + batch_sdr = metrics.sdr_pit(estimation, origin) + + self.assertEqual(batch_sdr, batch_sdr_ref) diff --git a/test/torchaudio_unittest/example/souce_sepration/sdr_reference.py b/test/torchaudio_unittest/example/souce_sepration/sdr_reference.py new file mode 100644 index 0000000000000000000000000000000000000000..7652fab0e6e23f8c5b953faaa9c46b267d278845 --- /dev/null +++ b/test/torchaudio_unittest/example/souce_sepration/sdr_reference.py @@ -0,0 +1,98 @@ +"""Reference Implementation of SDR and PIT SDR. + +This module was taken from the following implementation + +https://github.com/naplab/Conv-TasNet/blob/e66d82a8f956a69749ec8a4ae382217faa097c5c/utility/sdr.py + +which was made available by Yi Luo under the following liscence, + +Creative Commons Attribution-NonCommercial-ShareAlike 3.0 United States License. + +The module was modified in the following manner; + - Remove the functions other than `calc_sdr_torch` and `batch_SDR_torch`, + - Remove the import statements required only for the removed functions. + - Add `# flake8: noqa` so as not to report any format issue on this module. + +The implementation of the retained functions and their formats are kept as-is. +""" + +# flake8: noqa + +import numpy as np +from itertools import permutations + +import torch + + +def calc_sdr_torch(estimation, origin, mask=None): + """ + batch-wise SDR caculation for one audio file on pytorch Variables. + estimation: (batch, nsample) + origin: (batch, nsample) + mask: optional, (batch, nsample), binary + """ + + if mask is not None: + origin = origin * mask + estimation = estimation * mask + + origin_power = torch.pow(origin, 2).sum(1, keepdim=True) + 1e-8 # (batch, 1) + + scale = torch.sum(origin*estimation, 1, keepdim=True) / origin_power # (batch, 1) + + est_true = scale * origin # (batch, nsample) + est_res = estimation - est_true # (batch, nsample) + + true_power = torch.pow(est_true, 2).sum(1) + res_power = torch.pow(est_res, 2).sum(1) + + return 10*torch.log10(true_power) - 10*torch.log10(res_power) # (batch, 1) + + +def batch_SDR_torch(estimation, origin, mask=None): + """ + batch-wise SDR caculation for multiple audio files. + estimation: (batch, nsource, nsample) + origin: (batch, nsource, nsample) + mask: optional, (batch, nsample), binary + """ + + batch_size_est, nsource_est, nsample_est = estimation.size() + batch_size_ori, nsource_ori, nsample_ori = origin.size() + + assert batch_size_est == batch_size_ori, "Estimation and original sources should have same shape." 
+ assert nsource_est == nsource_ori, "Estimation and original sources should have same shape." + assert nsample_est == nsample_ori, "Estimation and original sources should have same shape." + + assert nsource_est < nsample_est, "Axis 1 should be the number of sources, and axis 2 should be the signal." + + batch_size = batch_size_est + nsource = nsource_est + nsample = nsample_est + + # zero mean signals + estimation = estimation - torch.mean(estimation, 2, keepdim=True).expand_as(estimation) + origin = origin - torch.mean(origin, 2, keepdim=True).expand_as(estimation) + + # possible permutations + perm = list(set(permutations(np.arange(nsource)))) + + # pair-wise SDR + SDR = torch.zeros((batch_size, nsource, nsource)).type(estimation.type()) + for i in range(nsource): + for j in range(nsource): + SDR[:,i,j] = calc_sdr_torch(estimation[:,i], origin[:,j], mask) + + # choose the best permutation + SDR_max = [] + SDR_perm = [] + for permute in perm: + sdr = [] + for idx in range(len(permute)): + sdr.append(SDR[:,idx,permute[idx]].view(batch_size,-1)) + sdr = torch.sum(torch.cat(sdr, 1), 1) + SDR_perm.append(sdr.view(batch_size, 1)) + SDR_perm = torch.cat(SDR_perm, 1) + SDR_max, _ = torch.max(SDR_perm, dim=1) + + return SDR_max / nsource diff --git a/test/torchaudio_unittest/example/souce_sepration/wsj0mix_test.py b/test/torchaudio_unittest/example/souce_sepration/wsj0mix_test.py new file mode 100644 index 0000000000000000000000000000000000000000..46927b182f595b53f4eba73e8f9833db3d112ba6 --- /dev/null +++ b/test/torchaudio_unittest/example/souce_sepration/wsj0mix_test.py @@ -0,0 +1,111 @@ +import os + +from torchaudio_unittest.common_utils import ( + TempDirMixin, + TorchaudioTestCase, + get_whitenoise, + save_wav, + normalize_wav, +) + +from source_separation.utils.dataset import wsj0mix + + +_FILENAMES = [ + "012c0207_1.9952_01cc0202_-1.9952.wav", + "01co0302_1.63_014c020q_-1.63.wav", + "01do0316_0.24011_205a0104_-0.24011.wav", + "01lc020x_1.1301_027o030r_-1.1301.wav", + "01mc0202_0.34056_205o0106_-0.34056.wav", + "01nc020t_0.53821_018o030w_-0.53821.wav", + "01po030f_2.2136_40ko031a_-2.2136.wav", + "01ra010o_2.4098_403a010f_-2.4098.wav", + "01xo030b_0.22377_016o031a_-0.22377.wav", + "02ac020x_0.68566_01ec020b_-0.68566.wav", + "20co010m_0.82801_019c0212_-0.82801.wav", + "20da010u_1.2483_017c0211_-1.2483.wav", + "20oo010d_1.0631_01ic020s_-1.0631.wav", + "20sc0107_2.0222_20fo010h_-2.0222.wav", + "20tc010f_0.051456_404a0110_-0.051456.wav", + "407c0214_1.1712_02ca0113_-1.1712.wav", + "40ao030w_2.4697_20vc010a_-2.4697.wav", + "40pa0101_1.1087_40ea0107_-1.1087.wav", +] + + +def _mock_dataset(root_dir, num_speaker): + dirnames = ["mix"] + [f"s{i+1}" for i in range(num_speaker)] + for dirname in dirnames: + os.makedirs(os.path.join(root_dir, dirname), exist_ok=True) + + seed = 0 + sample_rate = 8000 + expected = [] + for filename in _FILENAMES: + mix = None + src = [] + for dirname in dirnames: + waveform = get_whitenoise( + sample_rate=8000, duration=1, n_channels=1, dtype="int16", seed=seed + ) + seed += 1 + + path = os.path.join(root_dir, dirname, filename) + save_wav(path, waveform, sample_rate) + waveform = normalize_wav(waveform) + + if dirname == "mix": + mix = waveform + else: + src.append(waveform) + expected.append((sample_rate, mix, src)) + return expected + + +class TestWSJ0Mix2(TempDirMixin, TorchaudioTestCase): + backend = "default" + root_dir = None + expected = None + + @classmethod + def setUpClass(cls): + cls.root_dir = cls.get_base_temp_dir() + cls.expected = 
_mock_dataset(cls.root_dir, 2) + + def test_wsj0mix(self): + dataset = wsj0mix.WSJ0Mix(self.root_dir, num_speakers=2, sample_rate=8000) + + n_ite = 0 + for i, sample in enumerate(dataset): + (_, sample_mix, sample_src) = sample + (_, expected_mix, expected_src) = self.expected[i] + self.assertEqual(sample_mix, expected_mix, atol=5e-5, rtol=1e-8) + self.assertEqual(sample_src[0], expected_src[0], atol=5e-5, rtol=1e-8) + self.assertEqual(sample_src[1], expected_src[1], atol=5e-5, rtol=1e-8) + n_ite += 1 + assert n_ite == len(self.expected) + + +class TestWSJ0Mix3(TempDirMixin, TorchaudioTestCase): + backend = "default" + root_dir = None + expected = None + + @classmethod + def setUpClass(cls): + cls.root_dir = cls.get_base_temp_dir() + cls.expected = _mock_dataset(cls.root_dir, 3) + + def test_wsj0mix(self): + dataset = wsj0mix.WSJ0Mix(self.root_dir, num_speakers=3, sample_rate=8000) + + n_ite = 0 + for i, sample in enumerate(dataset): + (_, sample_mix, sample_src) = sample + (_, expected_mix, expected_src) = self.expected[i] + self.assertEqual(sample_mix, expected_mix, atol=5e-5, rtol=1e-8) + self.assertEqual(sample_src[0], expected_src[0], atol=5e-5, rtol=1e-8) + self.assertEqual(sample_src[1], expected_src[1], atol=5e-5, rtol=1e-8) + self.assertEqual(sample_src[2], expected_src[2], atol=5e-5, rtol=1e-8) + n_ite += 1 + assert n_ite == len(self.expected) diff --git a/test/torchaudio_unittest/example/tacotron2/__init__.py b/test/torchaudio_unittest/example/tacotron2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/torchaudio_unittest/example/tacotron2/tacotron2_loss_cpu_test.py b/test/torchaudio_unittest/example/tacotron2/tacotron2_loss_cpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f4d2e38987401b687f293db17c3d943f88a9cef9 --- /dev/null +++ b/test/torchaudio_unittest/example/tacotron2/tacotron2_loss_cpu_test.py @@ -0,0 +1,23 @@ +import torch + +from .tacotron2_loss_impl import ( + Tacotron2LossShapeTests, + Tacotron2LossTorchscriptTests, + Tacotron2LossGradcheckTests, +) +from torchaudio_unittest.common_utils import PytorchTestCase + + +class TestTacotron2LossShapeFloat32CPU(Tacotron2LossShapeTests, PytorchTestCase): + dtype = torch.float32 + device = torch.device("cpu") + + +class TestTacotron2TorchsciptFloat32CPU(Tacotron2LossTorchscriptTests, PytorchTestCase): + dtype = torch.float32 + device = torch.device("cpu") + + +class TestTacotron2GradcheckFloat64CPU(Tacotron2LossGradcheckTests, PytorchTestCase): + dtype = torch.float64 # gradcheck needs a higher numerical accuracy + device = torch.device("cpu") diff --git a/test/torchaudio_unittest/example/tacotron2/tacotron2_loss_gpu_test.py b/test/torchaudio_unittest/example/tacotron2/tacotron2_loss_gpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9c1ae252c4dbbad497dfad456ce31516af34009a --- /dev/null +++ b/test/torchaudio_unittest/example/tacotron2/tacotron2_loss_gpu_test.py @@ -0,0 +1,26 @@ +import torch + +from .tacotron2_loss_impl import ( + Tacotron2LossShapeTests, + Tacotron2LossTorchscriptTests, + Tacotron2LossGradcheckTests, +) +from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase + + +@skipIfNoCuda +class TestTacotron2LossShapeFloat32CUDA(PytorchTestCase, Tacotron2LossShapeTests): + dtype = torch.float32 + device = torch.device("cuda") + + +@skipIfNoCuda +class TestTacotron2TorchsciptFloat32CUDA(PytorchTestCase, Tacotron2LossTorchscriptTests): + dtype = 
torch.float32 + device = torch.device("cuda") + + +@skipIfNoCuda +class TestTacotron2GradcheckFloat64CUDA(PytorchTestCase, Tacotron2LossGradcheckTests): + dtype = torch.float64 # gradcheck needs a higher numerical accuracy + device = torch.device("cuda") diff --git a/test/torchaudio_unittest/example/tacotron2/tacotron2_loss_impl.py b/test/torchaudio_unittest/example/tacotron2/tacotron2_loss_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..848126e69604afdfbff1328bf8220b4c334f6ffb --- /dev/null +++ b/test/torchaudio_unittest/example/tacotron2/tacotron2_loss_impl.py @@ -0,0 +1,111 @@ +import torch +from torch.autograd import gradcheck, gradgradcheck + +from pipeline_tacotron2.loss import Tacotron2Loss +from torchaudio_unittest.common_utils import ( + TestBaseMixin, + torch_script, +) + + +class Tacotron2LossInputMixin(TestBaseMixin): + + def _get_inputs(self, n_mel=80, n_batch=16, max_mel_specgram_length=300): + mel_specgram = torch.rand( + n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device + ) + mel_specgram_postnet = torch.rand( + n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device + ) + gate_out = torch.rand(n_batch, dtype=self.dtype, device=self.device) + truth_mel_specgram = torch.rand( + n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device + ) + truth_gate_out = torch.rand(n_batch, dtype=self.dtype, device=self.device) + + truth_mel_specgram.requires_grad = False + truth_gate_out.requires_grad = False + + return ( + mel_specgram, + mel_specgram_postnet, + gate_out, + truth_mel_specgram, + truth_gate_out, + ) + + +class Tacotron2LossShapeTests(Tacotron2LossInputMixin): + + def test_tacotron2_loss_shape(self): + """Validate the output shape of Tacotron2Loss.""" + n_batch = 16 + + ( + mel_specgram, + mel_specgram_postnet, + gate_out, + truth_mel_specgram, + truth_gate_out, + ) = self._get_inputs(n_batch=n_batch) + + mel_loss, mel_postnet_loss, gate_loss = Tacotron2Loss()( + (mel_specgram, mel_specgram_postnet, gate_out), + (truth_mel_specgram, truth_gate_out) + ) + + self.assertEqual(mel_loss.size(), torch.Size([])) + self.assertEqual(mel_postnet_loss.size(), torch.Size([])) + self.assertEqual(gate_loss.size(), torch.Size([])) + + +class Tacotron2LossTorchscriptTests(Tacotron2LossInputMixin): + + def _assert_torchscript_consistency(self, fn, tensors): + ts_func = torch_script(fn) + + output = fn(tensors[:3], tensors[3:]) + ts_output = ts_func(tensors[:3], tensors[3:]) + + self.assertEqual(ts_output, output) + + def test_tacotron2_loss_torchscript_consistency(self): + """Validate the torchscript consistency of Tacotron2Loss.""" + + loss_fn = Tacotron2Loss() + self._assert_torchscript_consistency(loss_fn, self._get_inputs()) + + +class Tacotron2LossGradcheckTests(Tacotron2LossInputMixin): + + def test_tacotron2_loss_gradcheck(self): + """Performing gradient check on Tacotron2Loss.""" + ( + mel_specgram, + mel_specgram_postnet, + gate_out, + truth_mel_specgram, + truth_gate_out, + ) = self._get_inputs() + + mel_specgram.requires_grad_(True) + mel_specgram_postnet.requires_grad_(True) + gate_out.requires_grad_(True) + + def _fn(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out): + loss_fn = Tacotron2Loss() + return loss_fn( + (mel_specgram, mel_specgram_postnet, gate_out), + (truth_mel_specgram, truth_gate_out), + ) + + gradcheck( + _fn, + (mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out), + fast_mode=True, + ) + 
gradgradcheck( + _fn, + (mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out), + fast_mode=True, + ) diff --git a/test/torchaudio_unittest/example/tacotron2/test_text_preprocessing.py b/test/torchaudio_unittest/example/tacotron2/test_text_preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..8da02de88ec8852fb2a580d3ea97b220a71b5d66 --- /dev/null +++ b/test/torchaudio_unittest/example/tacotron2/test_text_preprocessing.py @@ -0,0 +1,97 @@ +from parameterized import parameterized + +from torchaudio._internal.module_utils import is_module_available +from torchaudio_unittest.common_utils import TorchaudioTestCase, skipIfNoModule + +if is_module_available("unidecode") and is_module_available("inflect"): + from pipeline_tacotron2.text.text_preprocessing import text_to_sequence + from pipeline_tacotron2.text.numbers import ( + _remove_commas, + _expand_pounds, + _expand_dollars, + _expand_decimal_point, + _expand_ordinal, + _expand_number, + ) + + +@skipIfNoModule("unidecode") +@skipIfNoModule("inflect") +class TestTextPreprocessor(TorchaudioTestCase): + + @parameterized.expand( + [ + ["dr. Strange?", [15, 26, 14, 31, 26, 29, 11, 30, 31, 29, 12, 25, 18, 16, 10]], + ["ML, is fun.", [24, 23, 6, 11, 20, 30, 11, 17, 32, 25, 7]], + ["I love torchaudio!", [20, 11, 23, 26, 33, 16, 11, 31, 26, 29, 14, 19, 12, 32, 15, 20, 26, 2]], + # 'one thousand dollars, twenty cents' + ["$1,000.20", [26, 25, 16, 11, 31, 19, 26, 32, 30, 12, 25, 15, 11, 15, 26, 23, 23, + 12, 29, 30, 6, 11, 31, 34, 16, 25, 31, 36, 11, 14, 16, 25, 31, 30]], + ] + ) + def test_text_to_sequence(self, sent, seq): + + assert (text_to_sequence(sent) == seq) + + @parameterized.expand( + [ + ["He, she, and I have $1,000", "He, she, and I have $1000"], + ] + ) + def test_remove_commas(self, sent, truth): + + assert (_remove_commas(sent) == truth) + + @parameterized.expand( + [ + ["He, she, and I have £1000", "He, she, and I have 1000 pounds"], + ] + ) + def test_expand_pounds(self, sent, truth): + + assert (_expand_pounds(sent) == truth) + + @parameterized.expand( + [ + ["He, she, and I have $1000", "He, she, and I have 1000 dollars"], + ["He, she, and I have $3000.01", "He, she, and I have 3000 dollars, 1 cent"], + ["He has $500.20 and she has $1000.50.", + "He has 500 dollars, 20 cents and she has 1000 dollars, 50 cents."], + ] + ) + def test_expand_dollars(self, sent, truth): + + assert (_expand_dollars(sent) == truth) + + @parameterized.expand( + [ + ["1000.20", "1000 point 20"], + ["1000.1", "1000 point 1"], + ] + ) + def test_expand_decimal_point(self, sent, truth): + + assert (_expand_decimal_point(sent) == truth) + + @parameterized.expand( + [ + ["21st centry", "twenty-first centry"], + ["20th centry", "twentieth centry"], + ["2nd place.", "second place."], + ] + ) + def test_expand_ordinal(self, sent, truth): + + assert (_expand_ordinal(sent) == truth) + _expand_ordinal, + + @parameterized.expand( + [ + ["100020 dollars.", "one hundred thousand twenty dollars."], + ["1234567890!", "one billion, two hundred thirty-four million, " + "five hundred sixty-seven thousand, eight hundred ninety!"], + ] + ) + def test_expand_number(self, sent, truth): + + assert (_expand_number(sent) == truth) diff --git a/test/torchaudio_unittest/functional/__init__.py b/test/torchaudio_unittest/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/torchaudio_unittest/functional/autograd_cpu_test.py 
b/test/torchaudio_unittest/functional/autograd_cpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a34823a76f0a591bbac97d565856ad7433054ff5 --- /dev/null +++ b/test/torchaudio_unittest/functional/autograd_cpu_test.py @@ -0,0 +1,13 @@ +import torch +from .autograd_impl import Autograd, AutogradFloat32 +from torchaudio_unittest import common_utils + + +class TestAutogradLfilterCPU(Autograd, common_utils.PytorchTestCase): + dtype = torch.float64 + device = torch.device('cpu') + + +class TestAutogradRNNTCPU(AutogradFloat32, common_utils.PytorchTestCase): + dtype = torch.float32 + device = torch.device('cpu') diff --git a/test/torchaudio_unittest/functional/autograd_cuda_test.py b/test/torchaudio_unittest/functional/autograd_cuda_test.py new file mode 100644 index 0000000000000000000000000000000000000000..341c575d87aadc0f6cdc59c9daed3eb5d94e84c4 --- /dev/null +++ b/test/torchaudio_unittest/functional/autograd_cuda_test.py @@ -0,0 +1,15 @@ +import torch +from .autograd_impl import Autograd, AutogradFloat32 +from torchaudio_unittest import common_utils + + +@common_utils.skipIfNoCuda +class TestAutogradLfilterCUDA(Autograd, common_utils.PytorchTestCase): + dtype = torch.float64 + device = torch.device('cuda') + + +@common_utils.skipIfNoCuda +class TestAutogradRNNTCUDA(AutogradFloat32, common_utils.PytorchTestCase): + dtype = torch.float32 + device = torch.device('cuda') diff --git a/test/torchaudio_unittest/functional/autograd_impl.py b/test/torchaudio_unittest/functional/autograd_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..ecd27722f71a94005c5f093a8c7bddb7b10c3528 --- /dev/null +++ b/test/torchaudio_unittest/functional/autograd_impl.py @@ -0,0 +1,269 @@ +from typing import Callable, Tuple +from functools import partial +import torch +from parameterized import parameterized +from torch import Tensor +import torchaudio.functional as F +from torch.autograd import gradcheck, gradgradcheck +from torchaudio_unittest.common_utils import ( + TestBaseMixin, + get_whitenoise, + rnnt_utils, +) + + +class Autograd(TestBaseMixin): + def assert_grad( + self, + transform: Callable[..., Tensor], + inputs: Tuple[torch.Tensor], + *, + enable_all_grad: bool = True, + ): + inputs_ = [] + for i in inputs: + if torch.is_tensor(i): + i = i.to(dtype=self.dtype, device=self.device) + if enable_all_grad: + i.requires_grad = True + inputs_.append(i) + assert gradcheck(transform, inputs_) + assert gradgradcheck(transform, inputs_) + + def test_lfilter_x(self): + torch.random.manual_seed(2434) + x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2) + a = torch.tensor([0.7, 0.2, 0.6]) + b = torch.tensor([0.4, 0.2, 0.9]) + x.requires_grad = True + self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False) + + def test_lfilter_a(self): + torch.random.manual_seed(2434) + x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2) + a = torch.tensor([0.7, 0.2, 0.6]) + b = torch.tensor([0.4, 0.2, 0.9]) + a.requires_grad = True + self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False) + + def test_lfilter_b(self): + torch.random.manual_seed(2434) + x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2) + a = torch.tensor([0.7, 0.2, 0.6]) + b = torch.tensor([0.4, 0.2, 0.9]) + b.requires_grad = True + self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False) + + def test_lfilter_all_inputs(self): + torch.random.manual_seed(2434) + x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2) + a = torch.tensor([0.7, 0.2, 
0.6]) + b = torch.tensor([0.4, 0.2, 0.9]) + self.assert_grad(F.lfilter, (x, a, b)) + + def test_lfilter_filterbanks(self): + torch.random.manual_seed(2434) + x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3) + a = torch.tensor([[0.7, 0.2, 0.6], + [0.8, 0.2, 0.9]]) + b = torch.tensor([[0.4, 0.2, 0.9], + [0.7, 0.2, 0.6]]) + self.assert_grad(partial(F.lfilter, batching=False), (x, a, b)) + + def test_lfilter_batching(self): + torch.random.manual_seed(2434) + x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2) + a = torch.tensor([[0.7, 0.2, 0.6], + [0.8, 0.2, 0.9]]) + b = torch.tensor([[0.4, 0.2, 0.9], + [0.7, 0.2, 0.6]]) + self.assert_grad(F.lfilter, (x, a, b)) + + def test_filtfilt_a(self): + torch.random.manual_seed(2434) + x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2) + a = torch.tensor([0.7, 0.2, 0.6]) + b = torch.tensor([0.4, 0.2, 0.9]) + a.requires_grad = True + self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False) + + def test_filtfilt_b(self): + torch.random.manual_seed(2434) + x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2) + a = torch.tensor([0.7, 0.2, 0.6]) + b = torch.tensor([0.4, 0.2, 0.9]) + b.requires_grad = True + self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False) + + def test_filtfilt_all_inputs(self): + torch.random.manual_seed(2434) + x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2) + a = torch.tensor([0.7, 0.2, 0.6]) + b = torch.tensor([0.4, 0.2, 0.9]) + self.assert_grad(F.filtfilt, (x, a, b)) + + def test_filtfilt_batching(self): + torch.random.manual_seed(2434) + x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2) + a = torch.tensor([[0.7, 0.2, 0.6], + [0.8, 0.2, 0.9]]) + b = torch.tensor([[0.4, 0.2, 0.9], + [0.7, 0.2, 0.6]]) + self.assert_grad(F.filtfilt, (x, a, b)) + + def test_biquad(self): + torch.random.manual_seed(2434) + x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1) + a = torch.tensor([0.7, 0.2, 0.6]) + b = torch.tensor([0.4, 0.2, 0.9]) + self.assert_grad(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2])) + + @parameterized.expand([ + (800, 0.7, True), + (800, 0.7, False), + ]) + def test_band_biquad(self, central_freq, Q, noise): + torch.random.manual_seed(2434) + sr = 22050 + x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1) + central_freq = torch.tensor(central_freq) + Q = torch.tensor(Q) + self.assert_grad(F.band_biquad, (x, sr, central_freq, Q, noise)) + + @parameterized.expand([ + (800, 0.7, 10), + (800, 0.7, -10), + ]) + def test_bass_biquad(self, central_freq, Q, gain): + torch.random.manual_seed(2434) + sr = 22050 + x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1) + central_freq = torch.tensor(central_freq) + Q = torch.tensor(Q) + gain = torch.tensor(gain) + self.assert_grad(F.bass_biquad, (x, sr, gain, central_freq, Q)) + + @parameterized.expand([ + (3000, 0.7, 10), + (3000, 0.7, -10), + + ]) + def test_treble_biquad(self, central_freq, Q, gain): + torch.random.manual_seed(2434) + sr = 22050 + x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1) + central_freq = torch.tensor(central_freq) + Q = torch.tensor(Q) + gain = torch.tensor(gain) + self.assert_grad(F.treble_biquad, (x, sr, gain, central_freq, Q)) + + @parameterized.expand([ + (800, 0.7, ), + ]) + def test_allpass_biquad(self, central_freq, Q): + torch.random.manual_seed(2434) + sr = 22050 + x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1) + central_freq = torch.tensor(central_freq) + Q = 
torch.tensor(Q)
+        self.assert_grad(F.allpass_biquad, (x, sr, central_freq, Q))
+
+    @parameterized.expand([
+        (800, 0.7, ),
+    ])
+    def test_lowpass_biquad(self, cutoff_freq, Q):
+        torch.random.manual_seed(2434)
+        sr = 22050
+        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
+        cutoff_freq = torch.tensor(cutoff_freq)
+        Q = torch.tensor(Q)
+        self.assert_grad(F.lowpass_biquad, (x, sr, cutoff_freq, Q))
+
+    @parameterized.expand([
+        (800, 0.7, ),
+    ])
+    def test_highpass_biquad(self, cutoff_freq, Q):
+        torch.random.manual_seed(2434)
+        sr = 22050
+        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
+        cutoff_freq = torch.tensor(cutoff_freq)
+        Q = torch.tensor(Q)
+        self.assert_grad(F.highpass_biquad, (x, sr, cutoff_freq, Q))
+
+    @parameterized.expand([
+        (800, 0.7, True),
+        (800, 0.7, False),
+    ])
+    def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain):
+        torch.random.manual_seed(2434)
+        sr = 22050
+        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
+        central_freq = torch.tensor(central_freq)
+        Q = torch.tensor(Q)
+        self.assert_grad(F.bandpass_biquad, (x, sr, central_freq, Q, const_skirt_gain))
+
+    @parameterized.expand([
+        (800, 0.7, 10),
+        (800, 0.7, -10),
+    ])
+    def test_equalizer_biquad(self, central_freq, Q, gain):
+        torch.random.manual_seed(2434)
+        sr = 22050
+        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
+        central_freq = torch.tensor(central_freq)
+        Q = torch.tensor(Q)
+        gain = torch.tensor(gain)
+        self.assert_grad(F.equalizer_biquad, (x, sr, central_freq, gain, Q))
+
+    @parameterized.expand([
+        (800, 0.7, ),
+    ])
+    def test_bandreject_biquad(self, central_freq, Q):
+        torch.random.manual_seed(2434)
+        sr = 22050
+        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
+        central_freq = torch.tensor(central_freq)
+        Q = torch.tensor(Q)
+        self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))
+
+
+class AutogradFloat32(TestBaseMixin):
+    def assert_grad(
+            self,
+            transform: Callable[..., Tensor],
+            inputs: Tuple[torch.Tensor],
+            enable_all_grad: bool = True,
+    ):
+        inputs_ = []
+        for i in inputs:
+            if torch.is_tensor(i):
+                i = i.to(dtype=self.dtype, device=self.device)
+                if enable_all_grad:
+                    i.requires_grad = True
+            inputs_.append(i)
+        # gradcheck with float32 requires higher atol and epsilon;
+        # check the converted inputs, not the raw ones
+        assert gradcheck(transform, inputs_, eps=1e-3, atol=1e-3, nondet_tol=0.)
+ + @parameterized.expand([ + (rnnt_utils.get_B1_T10_U3_D4_data, ), + (rnnt_utils.get_B2_T4_U3_D3_data, ), + (rnnt_utils.get_B1_T2_U3_D5_data, ), + ]) + def test_rnnt_loss(self, data_func): + def get_data(data_func, device): + data = data_func() + if type(data) == tuple: + data = data[0] + return data + + data = get_data(data_func, self.device) + inputs = ( + data["logits"].to(torch.float32), # logits + data["targets"], # targets + data["logit_lengths"], # logit_lengths + data["target_lengths"], # target_lengths + data["blank"], # blank + -1, # clamp + ) + + self.assert_grad(F.rnnt_loss, inputs, enable_all_grad=False) diff --git a/test/torchaudio_unittest/functional/batch_consistency_test.py b/test/torchaudio_unittest/functional/batch_consistency_test.py new file mode 100644 index 0000000000000000000000000000000000000000..042bfe52a124cd38cbf01abd543a98e3ed172da8 --- /dev/null +++ b/test/torchaudio_unittest/functional/batch_consistency_test.py @@ -0,0 +1,249 @@ +"""Test numerical consistency among single input and batched input.""" +import itertools +import math + +from parameterized import parameterized, parameterized_class +import torch +import torchaudio.functional as F + +from torchaudio_unittest import common_utils + + +def _name_from_args(func, _, params): + """Return a parameterized test name, based on parameter values.""" + return "{}_{}".format( + func.__name__, + "_".join(str(arg) for arg in params.args)) + + +@parameterized_class([ + # Single-item batch isolates problems that come purely from adding a + # dimension (rather than processing multiple items) + {"batch_size": 1}, + {"batch_size": 3}, +]) +class TestFunctional(common_utils.TorchaudioTestCase): + """Test functions defined in `functional` module""" + backend = 'default' + + def assert_batch_consistency( + self, functional, batch, *args, atol=1e-8, rtol=1e-5, seed=42, + **kwargs): + n = batch.size(0) + + # Compute items separately, then batch the result + torch.random.manual_seed(seed) + items_input = batch.clone() + items_result = torch.stack([ + functional(items_input[i], *args, **kwargs) for i in range(n) + ]) + + # Batch the input and run + torch.random.manual_seed(seed) + batch_input = batch.clone() + batch_result = functional(batch_input, *args, **kwargs) + + self.assertEqual(items_input, batch_input, rtol=rtol, atol=atol) + self.assertEqual(items_result, batch_result, rtol=rtol, atol=atol) + + def test_griffinlim(self): + n_fft = 400 + ws = 400 + hop = 200 + window = torch.hann_window(ws) + power = 2 + momentum = 0.99 + n_iter = 32 + length = 1000 + torch.random.manual_seed(0) + batch = torch.rand(self.batch_size, 1, 201, 6) + self.assert_batch_consistency( + F.griffinlim, batch, window, n_fft, hop, ws, power, + n_iter, momentum, length, 0, atol=5e-5) + + @parameterized.expand(list(itertools.product( + [8000, 16000, 44100], + [1, 2], + )), name_func=_name_from_args) + def test_detect_pitch_frequency(self, sample_rate, n_channels): + # Use different frequencies to ensure each item in the batch returns a + # different answer. + torch.manual_seed(0) + frequencies = torch.randint(100, 1000, [self.batch_size]) + waveforms = torch.stack([ + common_utils.get_sinusoid( + frequency=frequency, sample_rate=sample_rate, + n_channels=n_channels, duration=5) + for frequency in frequencies + ]) + self.assert_batch_consistency( + F.detect_pitch_frequency, waveforms, sample_rate) + + def test_amplitude_to_DB(self): + torch.manual_seed(0) + spec = torch.rand(self.batch_size, 2, 100, 100) * 200 + + amplitude_mult = 20. 
+ amin = 1e-10 + ref = 1.0 + db_mult = math.log10(max(amin, ref)) + + # Test with & without a `top_db` clamp + self.assert_batch_consistency( + F.amplitude_to_DB, spec, amplitude_mult, + amin, db_mult, top_db=None) + self.assert_batch_consistency( + F.amplitude_to_DB, spec, amplitude_mult, + amin, db_mult, top_db=40.) + + def test_amplitude_to_DB_itemwise_clamps(self): + """Ensure that the clamps are separate for each spectrogram in a batch. + + The clamp was determined per-batch in a prior implementation, which + meant it was determined by the loudest item, thus items weren't + independent. See: + + https://github.com/pytorch/audio/issues/994 + + """ + amplitude_mult = 20. + amin = 1e-10 + ref = 1.0 + db_mult = math.log10(max(amin, ref)) + top_db = 20. + + # Make a batch of noise + torch.manual_seed(0) + spec = torch.rand([2, 2, 100, 100]) * 200 + # Make one item blow out the other + spec[0] += 50 + + batchwise_dbs = F.amplitude_to_DB(spec, amplitude_mult, amin, + db_mult, top_db=top_db) + itemwise_dbs = torch.stack([ + F.amplitude_to_DB(item, amplitude_mult, amin, + db_mult, top_db=top_db) + for item in spec + ]) + + self.assertEqual(batchwise_dbs, itemwise_dbs) + + def test_amplitude_to_DB_not_channelwise_clamps(self): + """Check that clamps are applied per-item, not per channel.""" + amplitude_mult = 20. + amin = 1e-10 + ref = 1.0 + db_mult = math.log10(max(amin, ref)) + top_db = 40. + + torch.manual_seed(0) + spec = torch.rand([1, 2, 100, 100]) * 200 + # Make one channel blow out the other + spec[:, 0] += 50 + + specwise_dbs = F.amplitude_to_DB(spec, amplitude_mult, amin, + db_mult, top_db=top_db) + channelwise_dbs = torch.stack([ + F.amplitude_to_DB(spec[:, i], amplitude_mult, amin, + db_mult, top_db=top_db) + for i in range(spec.size(-3)) + ]) + + # Just check channelwise gives a different answer. + difference = (specwise_dbs - channelwise_dbs).abs() + assert (difference >= 1e-5).any() + + def test_contrast(self): + torch.random.manual_seed(0) + waveforms = torch.rand(self.batch_size, 2, 100) - 0.5 + self.assert_batch_consistency( + F.contrast, waveforms, enhancement_amount=80.) 
+ + def test_dcshift(self): + torch.random.manual_seed(0) + waveforms = torch.rand(self.batch_size, 2, 100) - 0.5 + self.assert_batch_consistency( + F.dcshift, waveforms, shift=0.5, limiter_gain=0.05) + + def test_overdrive(self): + torch.random.manual_seed(0) + waveforms = torch.rand(self.batch_size, 2, 100) - 0.5 + self.assert_batch_consistency( + F.overdrive, waveforms, gain=45, colour=30) + + def test_phaser(self): + sample_rate = 44100 + n_channels = 2 + waveform = common_utils.get_whitenoise( + sample_rate=sample_rate, n_channels=self.batch_size * n_channels, + duration=1) + batch = waveform.view(self.batch_size, n_channels, waveform.size(-1)) + self.assert_batch_consistency(F.phaser, batch, sample_rate) + + def test_flanger(self): + torch.random.manual_seed(0) + waveforms = torch.rand(self.batch_size, 2, 100) - 0.5 + sample_rate = 44100 + self.assert_batch_consistency(F.flanger, waveforms, sample_rate) + + @parameterized.expand(list(itertools.product( + [True, False], # center + [True, False], # norm_vars + )), name_func=_name_from_args) + def test_sliding_window_cmn(self, center, norm_vars): + torch.manual_seed(0) + spectrogram = torch.rand(self.batch_size, 2, 1024, 1024) * 200 + self.assert_batch_consistency( + F.sliding_window_cmn, spectrogram, center=center, + norm_vars=norm_vars) + + @parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) + def test_resample_waveform(self, resampling_method): + num_channels = 3 + sr = 16000 + new_sr = sr // 2 + multi_sound = common_utils.get_whitenoise(sample_rate=sr, n_channels=num_channels, duration=0.5,) + + self.assert_batch_consistency( + F.resample, multi_sound, orig_freq=sr, new_freq=new_sr, + resampling_method=resampling_method, rtol=1e-4, atol=1e-7) + + @common_utils.skipIfNoKaldi + def test_compute_kaldi_pitch(self): + sample_rate = 44100 + n_channels = 2 + waveform = common_utils.get_whitenoise( + sample_rate=sample_rate, n_channels=self.batch_size * n_channels) + batch = waveform.view(self.batch_size, n_channels, waveform.size(-1)) + self.assert_batch_consistency( + F.compute_kaldi_pitch, batch, sample_rate=sample_rate) + + def test_lfilter(self): + signal_length = 2048 + torch.manual_seed(2434) + x = torch.randn(self.batch_size, signal_length) + a = torch.rand(self.batch_size, 3) + b = torch.rand(self.batch_size, 3) + + batchwise_output = F.lfilter(x, a, b, batching=True) + itemwise_output = torch.stack([ + F.lfilter(x[i], a[i], b[i]) + for i in range(self.batch_size) + ]) + + self.assertEqual(batchwise_output, itemwise_output) + + def test_filtfilt(self): + signal_length = 2048 + torch.manual_seed(2434) + x = torch.randn(self.batch_size, signal_length) + a = torch.rand(self.batch_size, 3) + b = torch.rand(self.batch_size, 3) + + batchwise_output = F.filtfilt(x, a, b) + itemwise_output = torch.stack([ + F.filtfilt(x[i], a[i], b[i]) + for i in range(self.batch_size) + ]) + + self.assertEqual(batchwise_output, itemwise_output) diff --git a/test/torchaudio_unittest/functional/functional_cpu_test.py b/test/torchaudio_unittest/functional/functional_cpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..520ff86b20c1b83cc015d1748d22ae086dbeb4fa --- /dev/null +++ b/test/torchaudio_unittest/functional/functional_cpu_test.py @@ -0,0 +1,63 @@ +import torch +import torchaudio.functional as F +import unittest +from parameterized import parameterized + +from torchaudio_unittest.common_utils import PytorchTestCase, TorchaudioTestCase, skipIfNoSox +from .functional_impl import Functional, FunctionalCPUOnly + + 
+class TestFunctionalFloat32(Functional, FunctionalCPUOnly, PytorchTestCase): + dtype = torch.float32 + device = torch.device('cpu') + + @unittest.expectedFailure + def test_lfilter_9th_order_filter_stability(self): + super().test_lfilter_9th_order_filter_stability() + + +class TestFunctionalFloat64(Functional, PytorchTestCase): + dtype = torch.float64 + device = torch.device('cpu') + + +@skipIfNoSox +class TestApplyCodec(TorchaudioTestCase): + backend = "sox_io" + + def _smoke_test(self, format, compression, check_num_frames): + """ + The purpose of this test suite is to verify that apply_codec functionalities do not exhibit + abnormal behaviors. + """ + torch.random.manual_seed(42) + sample_rate = 8000 + num_frames = 3 * sample_rate + num_channels = 2 + waveform = torch.rand(num_channels, num_frames) + + augmented = F.apply_codec(waveform, + sample_rate, + format, + True, + compression + ) + assert augmented.dtype == waveform.dtype + assert augmented.shape[0] == num_channels + if check_num_frames: + assert augmented.shape[1] == num_frames + + def test_wave(self): + self._smoke_test("wav", compression=None, check_num_frames=True) + + @parameterized.expand([(96,), (128,), (160,), (192,), (224,), (256,), (320,)]) + def test_mp3(self, compression): + self._smoke_test("mp3", compression, check_num_frames=False) + + @parameterized.expand([(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,)]) + def test_flac(self, compression): + self._smoke_test("flac", compression, check_num_frames=False) + + @parameterized.expand([(-1,), (0,), (1,), (2,), (3,), (3.6,), (5,), (10,)]) + def test_vorbis(self, compression): + self._smoke_test("vorbis", compression, check_num_frames=False) diff --git a/test/torchaudio_unittest/functional/functional_cuda_test.py b/test/torchaudio_unittest/functional/functional_cuda_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fd35547270ee81b7acf09ce29ba2267265ddfdb0 --- /dev/null +++ b/test/torchaudio_unittest/functional/functional_cuda_test.py @@ -0,0 +1,21 @@ +import torch +import unittest + +from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda +from .functional_impl import Functional + + +@skipIfNoCuda +class TestFunctionalFloat32(Functional, PytorchTestCase): + dtype = torch.float32 + device = torch.device('cuda') + + @unittest.expectedFailure + def test_lfilter_9th_order_filter_stability(self): + super().test_lfilter_9th_order_filter_stability() + + +@skipIfNoCuda +class TestLFilterFloat64(Functional, PytorchTestCase): + dtype = torch.float64 + device = torch.device('cuda') diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..c3bebbf0761aaed25bbce7afc8935fecdcf6adca --- /dev/null +++ b/test/torchaudio_unittest/functional/functional_impl.py @@ -0,0 +1,584 @@ +"""Test definition common to CPU and CUDA""" +import math +import itertools +import warnings + +import numpy as np +import torch +import torchaudio.functional as F +from parameterized import parameterized +from scipy import signal + +from torchaudio_unittest.common_utils import ( + TestBaseMixin, + get_sinusoid, + nested_params, + get_whitenoise, + rnnt_utils, +) + + +class Functional(TestBaseMixin): + def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None, + resampling_method="sinc_interpolation", atol=1e-1, rtol=1e-4): + # resample the signal and compare it to the ground truth + n_to_trim = 20 + 
sample_rate = 1000
+        new_sample_rate = sample_rate
+
+        if up_scale_factor is not None:
+            new_sample_rate = int(new_sample_rate * up_scale_factor)
+
+        if down_scale_factor is not None:
+            new_sample_rate = int(new_sample_rate / down_scale_factor)
+
+        duration = 5  # seconds
+        original_timestamps = torch.arange(0, duration, 1.0 / sample_rate)
+
+        sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0)
+        estimate = F.resample(sound, sample_rate, new_sample_rate,
+                              resampling_method=resampling_method).squeeze()
+
+        new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)]
+        ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps)
+
+        # trim the first/last n samples as these points have boundary effects
+        ground_truth = ground_truth[..., n_to_trim:-n_to_trim]
+        estimate = estimate[..., n_to_trim:-n_to_trim]
+
+        self.assertEqual(estimate, ground_truth, atol=atol, rtol=rtol)
+
+    def _test_costs_and_gradients(
+            self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2
+    ):
+        logits_shape = data["logits"].shape
+        costs, gradients = rnnt_utils.compute_with_pytorch_transducer(data=data)
+        self.assertEqual(costs, ref_costs, atol=atol, rtol=rtol)
+        self.assertEqual(logits_shape, gradients.shape)
+        self.assertEqual(gradients, ref_gradients, atol=atol, rtol=rtol)
+
+    def test_lfilter_simple(self):
+        """
+        Create a very basic signal, then apply a simple 3-sample delay filter
+        (b = [0, 0, 0, 1], a = [1, 0, 0, 0]).
+        The output should be the same as the input, but shifted by 3 samples.
+        """
+
+        torch.random.manual_seed(42)
+        waveform = torch.rand(2, 44100 * 1, dtype=self.dtype, device=self.device)
+        b_coeffs = torch.tensor([0, 0, 0, 1], dtype=self.dtype, device=self.device)
+        a_coeffs = torch.tensor([1, 0, 0, 0], dtype=self.dtype, device=self.device)
+        output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
+
+        self.assertEqual(output_waveform[:, 3:], waveform[:, 0:-3], atol=1e-5, rtol=1e-5)
+
+    def test_lfilter_clamp(self):
+        input_signal = torch.ones(1, 44100 * 1, dtype=self.dtype, device=self.device)
+        b_coeffs = torch.tensor([1, 0], dtype=self.dtype, device=self.device)
+        a_coeffs = torch.tensor([1, -0.95], dtype=self.dtype, device=self.device)
+        output_signal = F.lfilter(input_signal, a_coeffs, b_coeffs, clamp=True)
+        assert output_signal.max() <= 1
+        output_signal = F.lfilter(input_signal, a_coeffs, b_coeffs, clamp=False)
+        assert output_signal.max() > 1
+
+    @parameterized.expand([
+        ((44100,), (4,), (44100,)),
+        ((3, 44100), (4,), (3, 44100,)),
+        ((2, 3, 44100), (4,), (2, 3, 44100,)),
+        ((1, 2, 3, 44100), (4,), (1, 2, 3, 44100,)),
+        ((44100,), (2, 4), (2, 44100)),
+        ((3, 44100), (1, 4), (3, 1, 44100)),
+        ((1, 2, 44100), (3, 4), (1, 2, 3, 44100))
+    ])
+    def test_lfilter_shape(self, input_shape, coeff_shape, target_shape):
+        torch.random.manual_seed(42)
+        waveform = torch.rand(*input_shape, dtype=self.dtype, device=self.device)
+        b_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device)
+        a_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device)
+        output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs, batching=False)
+        assert input_shape == waveform.size()
+        assert target_shape == output_waveform.size()
+
+    def test_lfilter_9th_order_filter_stability(self):
+        """
+        Validate the precision of lfilter against the reference scipy implementation when using a high-order filter.
+        The reference implementation uses cascaded second-order filters and is therefore more numerically accurate.
+ """ + # create an impulse signal + x = torch.zeros(1024, dtype=self.dtype, device=self.device) + x[0] = 1 + + # get target impulse response + sos = signal.butter(9, 850, 'hp', fs=22050, output='sos') + y = torch.from_numpy(signal.sosfilt(sos, x.cpu().numpy())).to(self.dtype).to(self.device) + + # get lfilter coefficients + b, a = signal.butter(9, 850, 'hp', fs=22050, output='ba') + b, a = torch.from_numpy(b).to(self.dtype).to(self.device), torch.from_numpy( + a).to(self.dtype).to(self.device) + + # predict impulse response + yhat = F.lfilter(x, a, b, False) + self.assertEqual(yhat, y, atol=1e-4, rtol=1e-5) + + def test_filtfilt_simple(self): + """ + Check that, for an arbitrary signal, applying filtfilt with filter coefficients + corresponding to a pure delay filter imparts no time delay. + """ + waveform = get_whitenoise(sample_rate=8000, n_channels=2, dtype=self.dtype).to( + device=self.device + ) + b_coeffs = torch.tensor([0, 0, 0, 1], dtype=self.dtype, device=self.device) + a_coeffs = torch.tensor([1, 0, 0, 0], dtype=self.dtype, device=self.device) + padded_waveform = torch.cat( + (waveform, torch.zeros(2, 3, dtype=self.dtype, device=self.device)), axis=1 + ) + output_waveform = F.filtfilt(padded_waveform, a_coeffs, b_coeffs) + + self.assertEqual(output_waveform, padded_waveform, atol=1e-5, rtol=1e-5) + + def test_filtfilt_filter_sinusoid(self): + """ + Check that, for a signal comprising two sinusoids, applying filtfilt + with appropriate filter coefficients correctly removes the higher-frequency + sinusoid while imparting no time delay. + """ + T = 1.0 + samples = 1000 + + waveform_k0 = get_sinusoid( + frequency=5, sample_rate=samples // T, dtype=self.dtype, device=self.device + ).squeeze(0) + waveform_k1 = get_sinusoid( + frequency=200, + sample_rate=samples // T, + dtype=self.dtype, + device=self.device, + ).squeeze(0) + waveform = waveform_k0 + waveform_k1 + + # Transfer function numerator and denominator polynomial coefficients + # corresponding to 8th-order Butterworth filter with 100-cycle/T cutoff. + # Generated with + # >>> from scipy import signal + # >>> b_coeffs, a_coeffs = signal.butter(8, 0.2) + b_coeffs = torch.tensor( + [ + 2.39596441e-05, + 1.91677153e-04, + 6.70870035e-04, + 1.34174007e-03, + 1.67717509e-03, + 1.34174007e-03, + 6.70870035e-04, + 1.91677153e-04, + 2.39596441e-05, + ], + dtype=self.dtype, + device=self.device, + ) + a_coeffs = torch.tensor( + [ + 1.0, + -4.78451489, + 10.44504107, + -13.45771989, + 11.12933104, + -6.0252604, + 2.0792738, + -0.41721716, + 0.0372001, + ], + dtype=self.dtype, + device=self.device, + ) + + # Extend waveform in each direction, preserving periodicity. + padded_waveform = torch.cat((waveform[:-1], waveform, waveform[1:])) + + output_waveform = F.filtfilt(padded_waveform, a_coeffs, b_coeffs) + + # Remove padding from output waveform; confirm that result + # closely matches waveform_k0. 
+ self.assertEqual( + output_waveform[samples - 1: 2 * samples - 1], + waveform_k0, + atol=1e-3, + rtol=1e-3, + ) + + @parameterized.expand([(0., ), (1., ), (2., ), (3., )]) + def test_spectogram_grad_at_zero(self, power): + """The gradient of power spectrogram should not be nan but zero near x=0 + + https://github.com/pytorch/audio/issues/993 + """ + x = torch.zeros(1, 22050, requires_grad=True) + spec = F.spectrogram( + x, + pad=0, + window=None, + n_fft=2048, + hop_length=None, + win_length=None, + power=power, + normalized=False, + ) + spec.sum().backward() + assert not x.grad.isnan().sum() + + def test_compute_deltas_one_channel(self): + specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]], dtype=self.dtype, device=self.device) + expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]], dtype=self.dtype, device=self.device) + computed = F.compute_deltas(specgram, win_length=3) + self.assertEqual(computed, expected) + + def test_compute_deltas_two_channels(self): + specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0], + [1.0, 2.0, 3.0, 4.0]]], dtype=self.dtype, device=self.device) + expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5], + [0.5, 1.0, 1.0, 0.5]]], dtype=self.dtype, device=self.device) + computed = F.compute_deltas(specgram, win_length=3) + self.assertEqual(computed, expected) + + @parameterized.expand([(100,), (440,)]) + def test_detect_pitch_frequency_pitch(self, frequency): + sample_rate = 44100 + test_sine_waveform = get_sinusoid( + frequency=frequency, sample_rate=sample_rate, duration=5 + ) + + freq = F.detect_pitch_frequency(test_sine_waveform, sample_rate) + + threshold = 1 + s = ((freq - frequency).abs() > threshold).sum() + self.assertFalse(s) + + @parameterized.expand([([100, 100],), ([2, 100, 100],), ([2, 2, 100, 100],)]) + def test_amplitude_to_DB_reversible(self, shape): + """Round trip between amplitude and db should return the original for various shape + + This implicitly also tests `DB_to_amplitude`. + + """ + amplitude_mult = 20. + power_mult = 10. + amin = 1e-10 + ref = 1.0 + db_mult = math.log10(max(amin, ref)) + + torch.manual_seed(0) + spec = torch.rand(*shape, dtype=self.dtype, device=self.device) * 200 + + # Spectrogram amplitude -> DB -> amplitude + db = F.amplitude_to_DB(spec, amplitude_mult, amin, db_mult, top_db=None) + x2 = F.DB_to_amplitude(db, ref, 0.5) + + self.assertEqual(x2, spec, atol=5e-5, rtol=1e-5) + + # Spectrogram power -> DB -> power + db = F.amplitude_to_DB(spec, power_mult, amin, db_mult, top_db=None) + x2 = F.DB_to_amplitude(db, ref, 1.) + + self.assertEqual(x2, spec) + + @parameterized.expand([([100, 100],), ([2, 100, 100],), ([2, 2, 100, 100],)]) + def test_amplitude_to_DB_top_db_clamp(self, shape): + """Ensure values are properly clamped when `top_db` is supplied.""" + amplitude_mult = 20. + amin = 1e-10 + ref = 1.0 + db_mult = math.log10(max(amin, ref)) + top_db = 40. + + torch.manual_seed(0) + # A random tensor is used for increased entropy, but the max and min for + # each spectrogram still need to be predictable. The max determines the + # decibel cutoff, and the distance from the min must be large enough + # that it triggers a clamp. + spec = torch.rand(*shape, dtype=self.dtype, device=self.device) + # Ensure each spectrogram has a min of 0 and a max of 1. + spec -= spec.amin([-2, -1])[..., None, None] + spec /= spec.amax([-2, -1])[..., None, None] + # Expand the range to (0, 200) - wide enough to properly test clamping. 
+ spec *= 200 + + decibels = F.amplitude_to_DB(spec, amplitude_mult, amin, + db_mult, top_db=top_db) + # Ensure the clamp was applied + below_limit = decibels < 6.0205 + assert not below_limit.any(), ( + "{} decibel values were below the expected cutoff:\n{}".format( + below_limit.sum().item(), decibels + ) + ) + # Ensure it didn't over-clamp + close_to_limit = decibels < 6.0207 + assert close_to_limit.any(), ( + f"No values were close to the limit. Did it over-clamp?\n{decibels}" + ) + + @parameterized.expand( + list(itertools.product([(1, 2, 1025, 400, 2), (1025, 400, 2)], [1, 2, 0.7])) + ) + def test_complex_norm(self, shape, power): + torch.random.manual_seed(42) + complex_tensor = torch.randn(*shape, dtype=self.dtype, device=self.device) + expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2) + norm_tensor = F.complex_norm(complex_tensor, power) + self.assertEqual(norm_tensor, expected_norm_tensor, atol=1e-5, rtol=1e-5) + + @parameterized.expand( + list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0., 30.], [1, 2])) + ) + def test_mask_along_axis(self, shape, mask_param, mask_value, axis): + torch.random.manual_seed(42) + specgram = torch.randn(*shape, dtype=self.dtype, device=self.device) + mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis) + + other_axis = 1 if axis == 2 else 2 + + masked_columns = (mask_specgram == mask_value).sum(other_axis) + num_masked_columns = (masked_columns == mask_specgram.size(other_axis)).sum() + num_masked_columns = torch.div( + num_masked_columns, mask_specgram.size(0), rounding_mode='floor') + + assert mask_specgram.size() == specgram.size() + assert num_masked_columns < mask_param + + @parameterized.expand(list(itertools.product([100], [0., 30.], [2, 3]))) + def test_mask_along_axis_iid(self, mask_param, mask_value, axis): + torch.random.manual_seed(42) + specgrams = torch.randn(4, 2, 1025, 400, dtype=self.dtype, device=self.device) + + mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis) + + other_axis = 2 if axis == 3 else 3 + + masked_columns = (mask_specgrams == mask_value).sum(other_axis) + num_masked_columns = (masked_columns == mask_specgrams.size(other_axis)).sum(-1) + + assert mask_specgrams.size() == specgrams.size() + assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel() + + @parameterized.expand( + list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0., 30.], [1, 2])) + ) + def test_mask_along_axis_preserve(self, shape, mask_param, mask_value, axis): + """mask_along_axis should not alter original input Tensor + + Test is run 5 times to bound the probability of no masking occurring to 1e-10 + See https://github.com/pytorch/audio/issues/1478 + """ + torch.random.manual_seed(42) + for _ in range(5): + specgram = torch.randn(*shape, dtype=self.dtype, device=self.device) + specgram_copy = specgram.clone() + F.mask_along_axis(specgram, mask_param, mask_value, axis) + + self.assertEqual(specgram, specgram_copy) + + @parameterized.expand(list(itertools.product([100], [0., 30.], [2, 3]))) + def test_mask_along_axis_iid_preserve(self, mask_param, mask_value, axis): + """mask_along_axis_iid should not alter original input Tensor + + Test is run 5 times to bound the probability of no masking occurring to 1e-10 + See https://github.com/pytorch/audio/issues/1478 + """ + torch.random.manual_seed(42) + for _ in range(5): + specgrams = torch.randn(4, 2, 1025, 400, dtype=self.dtype, device=self.device) + specgrams_copy = specgrams.clone() + 
F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis) + + self.assertEqual(specgrams, specgrams_copy) + + @parameterized.expand(list(itertools.product( + ["sinc_interpolation", "kaiser_window"], + [16000, 44100], + ))) + def test_resample_identity(self, resampling_method, sample_rate): + waveform = get_whitenoise(sample_rate=sample_rate, duration=1) + + resampled = F.resample(waveform, sample_rate, sample_rate) + self.assertEqual(waveform, resampled) + + @parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) + def test_resample_waveform_upsample_size(self, resampling_method): + sr = 16000 + waveform = get_whitenoise(sample_rate=sr, duration=0.5,) + upsampled = F.resample(waveform, sr, sr * 2, resampling_method=resampling_method) + assert upsampled.size(-1) == waveform.size(-1) * 2 + + @parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) + def test_resample_waveform_downsample_size(self, resampling_method): + sr = 16000 + waveform = get_whitenoise(sample_rate=sr, duration=0.5,) + downsampled = F.resample(waveform, sr, sr // 2, resampling_method=resampling_method) + assert downsampled.size(-1) == waveform.size(-1) // 2 + + @parameterized.expand([("sinc_interpolation"), ("kaiser_window")]) + def test_resample_waveform_identity_size(self, resampling_method): + sr = 16000 + waveform = get_whitenoise(sample_rate=sr, duration=0.5,) + resampled = F.resample(waveform, sr, sr, resampling_method=resampling_method) + assert resampled.size(-1) == waveform.size(-1) + + @parameterized.expand(list(itertools.product( + ["sinc_interpolation", "kaiser_window"], + list(range(1, 20)), + ))) + def test_resample_waveform_downsample_accuracy(self, resampling_method, i): + self._test_resample_waveform_accuracy(down_scale_factor=i * 2, resampling_method=resampling_method) + + @parameterized.expand(list(itertools.product( + ["sinc_interpolation", "kaiser_window"], + list(range(1, 20)), + ))) + def test_resample_waveform_upsample_accuracy(self, resampling_method, i): + self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method) + + @nested_params( + [0.5, 1.01, 1.3], + [True, False], + ) + def test_phase_vocoder_shape(self, rate, test_pseudo_complex): + """Verify the output shape of phase vocoder""" + hop_length = 256 + num_freq = 1025 + num_frames = 400 + batch_size = 2 + + torch.random.manual_seed(42) + spec = torch.randn( + batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device) + if test_pseudo_complex: + spec = torch.view_as_real(spec) + + phase_advance = torch.linspace( + 0, + np.pi * hop_length, + num_freq, + dtype=self.dtype, device=self.device)[..., None] + + spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance) + + assert spec.dim() == spec_stretch.dim() + expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))]) + output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape + assert output_shape == expected_shape + + @parameterized.expand( + [ + # words + ["", "", 0], # equal + ["abc", "abc", 0], + ["ᑌᑎIᑕO", "ᑌᑎIᑕO", 0], + + ["abc", "", 3], # deletion + ["aa", "aaa", 1], + ["aaa", "aa", 1], + ["ᑌᑎI", "ᑌᑎIᑕO", 2], + + ["aaa", "aba", 1], # substitution + ["aba", "aaa", 1], + ["aba", " ", 3], + + ["abc", "bcd", 2], # mix deletion and substitution + ["0ᑌᑎI", "ᑌᑎIᑕO", 3], + + # sentences + [["hello", "", "Tᕮ᙭T"], ["hello", "", "Tᕮ᙭T"], 0], # equal + [[], [], 0], + + [["hello", "world"], ["hello", "world", "!"], 1], # 
deletion + [["hello", "world"], ["world"], 1], + [["hello", "world"], [], 2], + + [["Tᕮ᙭T", ], ["world"], 1], # substitution + [["Tᕮ᙭T", "XD"], ["world", "hello"], 2], + [["", "XD"], ["world", ""], 2], + ["aba", " ", 3], + + [["hello", "world"], ["world", "hello", "!"], 2], # mix deletion and substitution + [["Tᕮ᙭T", "world", "LOL", "XD"], ["world", "hello", "ʕ•́ᴥ•̀ʔっ"], 3], + ] + ) + def test_simple_case_edit_distance(self, seq1, seq2, distance): + assert F.edit_distance(seq1, seq2) == distance + assert F.edit_distance(seq2, seq1) == distance + + @nested_params( + [-4, -2, 0, 2, 4], + ) + def test_pitch_shift_shape(self, n_steps): + sample_rate = 16000 + torch.random.manual_seed(42) + waveform = torch.rand(2, 44100 * 1, dtype=self.dtype, device=self.device) + waveform_shift = F.pitch_shift(waveform, sample_rate, n_steps) + assert waveform.size() == waveform_shift.size() + + def test_rnnt_loss_basic_backward(self): + logits, targets, logit_lengths, target_lengths = rnnt_utils.get_basic_data(self.device) + loss = F.rnnt_loss(logits, targets, logit_lengths, target_lengths) + loss.backward() + + def test_rnnt_loss_basic_forward_no_grad(self): + """In early stage, calls to `rnnt_loss` resulted in segmentation fault when + `logits` have `requires_grad = False`. This test makes sure that this no longer + occurs and the functional call runs without error. + + See https://github.com/pytorch/audio/pull/1707 + """ + logits, targets, logit_lengths, target_lengths = rnnt_utils.get_basic_data(self.device) + logits.requires_grad_(False) + F.rnnt_loss(logits, targets, logit_lengths, target_lengths) + + @parameterized.expand([ + (rnnt_utils.get_B1_T2_U3_D5_data, torch.float32, 1e-6, 1e-2), + (rnnt_utils.get_B2_T4_U3_D3_data, torch.float32, 1e-6, 1e-2), + (rnnt_utils.get_B1_T2_U3_D5_data, torch.float16, 1e-3, 1e-2), + (rnnt_utils.get_B2_T4_U3_D3_data, torch.float16, 1e-3, 1e-2), + ]) + def test_rnnt_loss_costs_and_gradients(self, data_func, dtype, atol, rtol): + data, ref_costs, ref_gradients = data_func( + dtype=dtype, + device=self.device, + ) + self._test_costs_and_gradients( + data=data, + ref_costs=ref_costs, + ref_gradients=ref_gradients, + atol=atol, + rtol=rtol, + ) + + def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self): + seed = 777 + for i in range(5): + data = rnnt_utils.get_random_data(dtype=torch.float32, device=self.device, seed=(seed + i)) + ref_costs, ref_gradients = rnnt_utils.compute_with_numpy_transducer(data=data) + self._test_costs_and_gradients( + data=data, ref_costs=ref_costs, ref_gradients=ref_gradients + ) + + +class FunctionalCPUOnly(TestBaseMixin): + def test_melscale_fbanks_no_warning_high_n_freq(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + F.melscale_fbanks(288, 0, 8000, 128, 16000) + assert len(w) == 0 + + def test_melscale_fbanks_no_warning_low_n_mels(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + F.melscale_fbanks(201, 0, 8000, 89, 16000) + assert len(w) == 0 + + def test_melscale_fbanks_warning(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + F.melscale_fbanks(201, 0, 8000, 128, 16000) + assert len(w) == 1 diff --git a/test/torchaudio_unittest/functional/kaldi_compatibility_cpu_test.py b/test/torchaudio_unittest/functional/kaldi_compatibility_cpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..90a634746c0a5276138902e852e269fad258d7fe --- /dev/null +++ 
b/test/torchaudio_unittest/functional/kaldi_compatibility_cpu_test.py @@ -0,0 +1,19 @@ +import torch + +from torchaudio_unittest.common_utils import PytorchTestCase +from .kaldi_compatibility_test_impl import Kaldi, KaldiCPUOnly + + +class TestKaldiCPUOnly(KaldiCPUOnly, PytorchTestCase): + dtype = torch.float32 + device = torch.device('cpu') + + +class TestKaldiFloat32(Kaldi, PytorchTestCase): + dtype = torch.float32 + device = torch.device('cpu') + + +class TestKaldiFloat64(Kaldi, PytorchTestCase): + dtype = torch.float64 + device = torch.device('cpu') diff --git a/test/torchaudio_unittest/functional/kaldi_compatibility_cuda_test.py b/test/torchaudio_unittest/functional/kaldi_compatibility_cuda_test.py new file mode 100644 index 0000000000000000000000000000000000000000..47a1bea338cea30106cb203bf94648020b135c7a --- /dev/null +++ b/test/torchaudio_unittest/functional/kaldi_compatibility_cuda_test.py @@ -0,0 +1,16 @@ +import torch + +from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda +from .kaldi_compatibility_test_impl import Kaldi + + +@skipIfNoCuda +class TestKaldiFloat32(Kaldi, PytorchTestCase): + dtype = torch.float32 + device = torch.device('cuda') + + +@skipIfNoCuda +class TestKaldiFloat64(Kaldi, PytorchTestCase): + dtype = torch.float64 + device = torch.device('cuda') diff --git a/test/torchaudio_unittest/functional/kaldi_compatibility_test_impl.py b/test/torchaudio_unittest/functional/kaldi_compatibility_test_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..bad30afd9a2a9aced635f746319992e812e26959 --- /dev/null +++ b/test/torchaudio_unittest/functional/kaldi_compatibility_test_impl.py @@ -0,0 +1,60 @@ +from parameterized import parameterized +import torch +import torchaudio.functional as F + +from torchaudio_unittest.common_utils import ( + get_sinusoid, + load_params, + save_wav, + skipIfNoExec, + TempDirMixin, + TestBaseMixin, +) +from torchaudio_unittest.common_utils.kaldi_utils import ( + convert_args, + run_kaldi, +) + + +class Kaldi(TempDirMixin, TestBaseMixin): + def assert_equal(self, output, *, expected, rtol=None, atol=None): + expected = expected.to(dtype=self.dtype, device=self.device) + self.assertEqual(output, expected, rtol=rtol, atol=atol) + + @skipIfNoExec('apply-cmvn-sliding') + def test_sliding_window_cmn(self): + """sliding_window_cmn should be numerically compatible with apply-cmvn-sliding""" + kwargs = { + 'cmn_window': 600, + 'min_cmn_window': 100, + 'center': False, + 'norm_vars': False, + } + + tensor = torch.randn(40, 10, dtype=self.dtype, device=self.device) + result = F.sliding_window_cmn(tensor, **kwargs) + command = ['apply-cmvn-sliding'] + convert_args(**kwargs) + ['ark:-', 'ark:-'] + kaldi_result = run_kaldi(command, 'ark', tensor) + self.assert_equal(result, expected=kaldi_result) + + +class KaldiCPUOnly(TempDirMixin, TestBaseMixin): + def assert_equal(self, output, *, expected, rtol=None, atol=None): + expected = expected.to(dtype=self.dtype, device=self.device) + self.assertEqual(output, expected, rtol=rtol, atol=atol) + + @parameterized.expand(load_params('kaldi_test_pitch_args.jsonl')) + @skipIfNoExec('compute-kaldi-pitch-feats') + def test_pitch_feats(self, kwargs): + """compute_kaldi_pitch produces numerically compatible result with compute-kaldi-pitch-feats""" + sample_rate = kwargs['sample_rate'] + waveform = get_sinusoid(dtype='float32', sample_rate=sample_rate) + result = F.compute_kaldi_pitch(waveform[0], **kwargs) + + waveform = get_sinusoid(dtype='int16', sample_rate=sample_rate) + 
wave_file = self.get_temp_path('test.wav') + save_wav(wave_file, waveform, sample_rate) + + command = ['compute-kaldi-pitch-feats'] + convert_args(**kwargs) + ['scp:-', 'ark:-'] + kaldi_result = run_kaldi(command, 'scp', wave_file) + self.assert_equal(result, expected=kaldi_result) diff --git a/test/torchaudio_unittest/functional/librosa_compatibility_cpu_test.py b/test/torchaudio_unittest/functional/librosa_compatibility_cpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..dcfe6fcf4d58a3969a1f2efd54f12f19124a4c75 --- /dev/null +++ b/test/torchaudio_unittest/functional/librosa_compatibility_cpu_test.py @@ -0,0 +1,10 @@ +from torchaudio_unittest.common_utils import PytorchTestCase +from .librosa_compatibility_test_impl import Functional, FunctionalComplex + + +class TestFunctionalCPU(Functional, PytorchTestCase): + device = 'cpu' + + +class TestFunctionalComplexCPU(FunctionalComplex, PytorchTestCase): + device = 'cpu' diff --git a/test/torchaudio_unittest/functional/librosa_compatibility_cuda_test.py b/test/torchaudio_unittest/functional/librosa_compatibility_cuda_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9f42b4dc84bab0eb2129e73b05cf95ecd6a435f0 --- /dev/null +++ b/test/torchaudio_unittest/functional/librosa_compatibility_cuda_test.py @@ -0,0 +1,12 @@ +from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda +from .librosa_compatibility_test_impl import Functional, FunctionalComplex + + +@skipIfNoCuda +class TestFunctionalCUDA(Functional, PytorchTestCase): + device = 'cuda' + + +@skipIfNoCuda +class TestFunctionalComplexCUDA(FunctionalComplex, PytorchTestCase): + device = 'cuda' diff --git a/test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py b/test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..a63f0da9d4cc9acafd61b439a437bf639b58bf3e --- /dev/null +++ b/test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py @@ -0,0 +1,161 @@ +import unittest +from distutils.version import StrictVersion + +import torch +from parameterized import param + +import torchaudio.functional as F +from torchaudio._internal.module_utils import is_module_available + +LIBROSA_AVAILABLE = is_module_available('librosa') + +if LIBROSA_AVAILABLE: + import numpy as np + import librosa + + +from torchaudio_unittest.common_utils import ( + TestBaseMixin, + nested_params, + get_whitenoise, + get_spectrogram, +) + + +@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available") +class Functional(TestBaseMixin): + """Test suite for functions in `functional` module.""" + dtype = torch.float64 + + @nested_params([0, 0.99]) + def test_griffinlim(self, momentum): + # FFT params + n_fft = 400 + win_length = n_fft + hop_length = n_fft // 4 + window = torch.hann_window(win_length, device=self.device) + power = 1 + # GriffinLim params + n_iter = 8 + + waveform = get_whitenoise(device=self.device, dtype=self.dtype) + specgram = get_spectrogram( + waveform, n_fft=n_fft, hop_length=hop_length, power=power, + win_length=win_length, window=window) + + result = F.griffinlim( + specgram, + window=window, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + power=power, + n_iter=n_iter, + momentum=momentum, + length=waveform.size(1), + rand_init=False) + expected = librosa.griffinlim( + specgram[0].cpu().numpy(), + n_iter=n_iter, + hop_length=hop_length, + momentum=momentum, + init=None, + length=waveform.size(1))[None, ...] 
+ self.assertEqual(result, torch.from_numpy(expected), atol=5e-5, rtol=1e-07) + + @nested_params( + [ + param(), + param(n_mels=128, sample_rate=44100), + param(n_mels=128, fmin=2000.0, fmax=5000.0), + param(n_mels=56, fmin=100.0, fmax=9000.0), + param(n_mels=56, fmin=800.0, fmax=900.0), + param(n_mels=56, fmin=1900.0, fmax=900.0), + param(n_mels=10, fmin=1900.0, fmax=900.0), + ], + [param(norm=n) for n in [None, 'slaney']], + [param(mel_scale=s) for s in ['htk', 'slaney']], + ) + def test_create_mel_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, + fmin=0.0, fmax=8000.0, norm=None, mel_scale="htk"): + if (norm == "slaney" and StrictVersion(librosa.__version__) < StrictVersion("0.7.2")): + self.skipTest('Test is known to fail with older versions of librosa.') + if self.device != 'cpu': + self.skipTest('No need to run this test on CUDA') + + expected = librosa.filters.mel( + sr=sample_rate, + n_fft=n_fft, + n_mels=n_mels, + fmax=fmax, + fmin=fmin, + htk=mel_scale == "htk", + norm=norm).T + result = F.melscale_fbanks( + sample_rate=sample_rate, + n_mels=n_mels, + f_max=fmax, + f_min=fmin, + n_freqs=(n_fft // 2 + 1), + norm=norm, + mel_scale=mel_scale) + self.assertEqual(result, torch.from_numpy(expected), atol=7e-5, rtol=1.3e-6) + + def test_amplitude_to_DB_power(self): + amin = 1e-10 + db_multiplier = 0.0 + top_db = 80.0 + multiplier = 10.0 + + spec = get_spectrogram(get_whitenoise(device=self.device, dtype=self.dtype), power=2) + result = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db) + expected = librosa.core.power_to_db(spec[0].cpu().numpy())[None, ...] + self.assertEqual(result, torch.from_numpy(expected)) + + def test_amplitude_to_DB(self): + amin = 1e-10 + db_multiplier = 0.0 + top_db = 80.0 + multiplier = 20.0 + + spec = get_spectrogram(get_whitenoise(device=self.device, dtype=self.dtype), power=1) + result = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db) + expected = librosa.core.amplitude_to_db(spec[0].cpu().numpy())[None, ...] + self.assertEqual(result, torch.from_numpy(expected)) + + +@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available") +class FunctionalComplex(TestBaseMixin): + @nested_params( + [0.5, 1.01, 1.3], + [True, False], + ) + def test_phase_vocoder(self, rate, test_pseudo_complex): + hop_length = 256 + num_freq = 1025 + num_frames = 400 + torch.random.manual_seed(42) + + # Due to cummulative sum, numerical error in using torch.float32 will + # result in bottom right values of the stretched sectrogram to not + # match with librosa. 
+ spec = torch.randn(num_freq, num_frames, device=self.device, dtype=torch.complex128) + phase_advance = torch.linspace( + 0, + np.pi * hop_length, + num_freq, + device=self.device, + dtype=torch.float64)[..., None] + + stretched = F.phase_vocoder( + torch.view_as_real(spec) if test_pseudo_complex else spec, + rate=rate, phase_advance=phase_advance) + + expected_stretched = librosa.phase_vocoder( + spec.cpu().numpy(), + rate=rate, + hop_length=hop_length) + + self.assertEqual( + torch.view_as_complex(stretched) if test_pseudo_complex else stretched, + torch.from_numpy(expected_stretched)) diff --git a/test/torchaudio_unittest/functional/sox_compatibility_test.py b/test/torchaudio_unittest/functional/sox_compatibility_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fe9744f22a00ace7e29b7caa7b5c53d7d5854b16 --- /dev/null +++ b/test/torchaudio_unittest/functional/sox_compatibility_test.py @@ -0,0 +1,299 @@ +import torch +import torchaudio.functional as F + +from torchaudio_unittest.common_utils import ( + skipIfNoSox, + skipIfNoExec, + TempDirMixin, + TorchaudioTestCase, + get_asset_path, + sox_utils, + load_wav, + save_wav, + get_whitenoise, +) + + +@skipIfNoSox +@skipIfNoExec('sox') +class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): + def run_sox_effect(self, input_file, effect): + output_file = self.get_temp_path('expected.wav') + sox_utils.run_sox_effect(input_file, output_file, [str(e) for e in effect]) + return load_wav(output_file) + + def assert_sox_effect(self, result, input_path, effects, atol=1e-04, rtol=1e-5): + expected, _ = self.run_sox_effect(input_path, effects) + self.assertEqual(result, expected, atol=atol, rtol=rtol) + + def get_whitenoise(self, sample_rate=8000): + noise = get_whitenoise( + sample_rate=sample_rate, duration=3, scale_factor=0.9, + ) + path = self.get_temp_path("whitenoise.wav") + save_wav(path, noise, sample_rate) + return noise, path + + def test_gain(self): + path = get_asset_path('steam-train-whistle-daniel_simon.wav') + data, _ = load_wav(path) + result = F.gain(data, 3) + self.assert_sox_effect(result, path, ['gain', 3]) + + def test_dither(self): + path = get_asset_path('steam-train-whistle-daniel_simon.wav') + data, _ = load_wav(path) + result = F.dither(data) + self.assert_sox_effect(result, path, ['dither']) + + def test_dither_noise(self): + path = get_asset_path('steam-train-whistle-daniel_simon.wav') + data, _ = load_wav(path) + result = F.dither(data, noise_shaping=True) + self.assert_sox_effect(result, path, ['dither', '-s'], atol=1.5e-4) + + def test_lowpass(self): + cutoff_freq = 3000 + sample_rate = 8000 + + data, path = self.get_whitenoise(sample_rate) + result = F.lowpass_biquad(data, sample_rate, cutoff_freq) + self.assert_sox_effect(result, path, ['lowpass', cutoff_freq], atol=1.5e-4) + + def test_highpass(self): + cutoff_freq = 2000 + sample_rate = 8000 + + data, path = self.get_whitenoise(sample_rate) + result = F.highpass_biquad(data, sample_rate, cutoff_freq) + self.assert_sox_effect(result, path, ['highpass', cutoff_freq], atol=1.5e-4) + + def test_allpass(self): + central_freq = 1000 + q = 0.707 + sample_rate = 8000 + + data, path = self.get_whitenoise(sample_rate) + result = F.allpass_biquad(data, sample_rate, central_freq, q) + self.assert_sox_effect(result, path, ['allpass', central_freq, f'{q}q']) + + def test_bandpass_with_csg(self): + central_freq = 1000 + q = 0.707 + const_skirt_gain = True + sample_rate = 8000 + + data, path = self.get_whitenoise(sample_rate) + result = 
F.bandpass_biquad(data, sample_rate, central_freq, q, const_skirt_gain) + self.assert_sox_effect(result, path, ['bandpass', '-c', central_freq, f'{q}q']) + + def test_bandpass_without_csg(self): + central_freq = 1000 + q = 0.707 + const_skirt_gain = False + sample_rate = 8000 + + data, path = self.get_whitenoise(sample_rate) + result = F.bandpass_biquad(data, sample_rate, central_freq, q, const_skirt_gain) + self.assert_sox_effect(result, path, ['bandpass', central_freq, f'{q}q']) + + def test_bandreject(self): + central_freq = 1000 + q = 0.707 + sample_rate = 8000 + + data, path = self.get_whitenoise(sample_rate) + result = F.bandreject_biquad(data, sample_rate, central_freq, q) + self.assert_sox_effect(result, path, ['bandreject', central_freq, f'{q}q']) + + def test_band_with_noise(self): + central_freq = 1000 + q = 0.707 + noise = True + sample_rate = 8000 + + data, path = self.get_whitenoise(sample_rate) + result = F.band_biquad(data, sample_rate, central_freq, q, noise) + self.assert_sox_effect(result, path, ['band', '-n', central_freq, f'{q}q']) + + def test_band_without_noise(self): + central_freq = 1000 + q = 0.707 + noise = False + sample_rate = 8000 + + data, path = self.get_whitenoise(sample_rate) + result = F.band_biquad(data, sample_rate, central_freq, q, noise) + self.assert_sox_effect(result, path, ['band', central_freq, f'{q}q']) + + def test_treble(self): + central_freq = 1000 + q = 0.707 + gain = 40 + sample_rate = 8000 + + data, path = self.get_whitenoise(sample_rate) + result = F.treble_biquad(data, sample_rate, gain, central_freq, q) + self.assert_sox_effect(result, path, ['treble', gain, central_freq, f'{q}q']) + + def test_bass(self): + central_freq = 1000 + q = 0.707 + gain = 40 + sample_rate = 8000 + + data, path = self.get_whitenoise(sample_rate) + result = F.bass_biquad(data, sample_rate, gain, central_freq, q) + self.assert_sox_effect(result, path, ['bass', gain, central_freq, f'{q}q'], atol=1.5e-4) + + def test_deemph(self): + sample_rate = 44100 + data, path = self.get_whitenoise(sample_rate) + result = F.deemph_biquad(data, sample_rate) + self.assert_sox_effect(result, path, ['deemph']) + + def test_riaa(self): + sample_rate = 44100 + data, path = self.get_whitenoise(sample_rate) + result = F.riaa_biquad(data, sample_rate) + self.assert_sox_effect(result, path, ['riaa']) + + def test_contrast(self): + enhancement_amount = 80. 
+ + data, path = self.get_whitenoise() + result = F.contrast(data, enhancement_amount) + self.assert_sox_effect(result, path, ['contrast', enhancement_amount]) + + def test_dcshift_with_limiter(self): + shift = 0.5 + limiter_gain = 0.05 + + data, path = self.get_whitenoise() + result = F.dcshift(data, shift, limiter_gain) + self.assert_sox_effect(result, path, ['dcshift', shift, limiter_gain]) + + def test_dcshift_without_limiter(self): + shift = 0.6 + + data, path = self.get_whitenoise() + result = F.dcshift(data, shift) + self.assert_sox_effect(result, path, ['dcshift', shift]) + + def test_overdrive(self): + gain = 30 + colour = 40 + + data, path = self.get_whitenoise() + result = F.overdrive(data, gain, colour) + self.assert_sox_effect(result, path, ['overdrive', gain, colour]) + + def test_phaser_sine(self): + gain_in = 0.5 + gain_out = 0.8 + delay_ms = 2.0 + decay = 0.4 + speed = 0.5 + sample_rate = 8000 + + data, path = self.get_whitenoise(sample_rate) + result = F.phaser(data, sample_rate, gain_in, gain_out, delay_ms, decay, speed, sinusoidal=True) + self.assert_sox_effect(result, path, ['phaser', gain_in, gain_out, delay_ms, decay, speed, '-s']) + + def test_phaser_triangle(self): + gain_in = 0.5 + gain_out = 0.8 + delay_ms = 2.0 + decay = 0.4 + speed = 0.5 + sample_rate = 8000 + + data, path = self.get_whitenoise(sample_rate) + result = F.phaser(data, sample_rate, gain_in, gain_out, delay_ms, decay, speed, sinusoidal=False) + self.assert_sox_effect(result, path, ['phaser', gain_in, gain_out, delay_ms, decay, speed, '-t']) + + def test_flanger_triangle_linear(self): + delay = 0.6 + depth = 0.87 + regen = 3.0 + width = 0.9 + speed = 0.5 + phase = 30 + sample_rate = 8000 + + data, path = self.get_whitenoise(sample_rate) + result = F.flanger( + data, sample_rate, delay, depth, regen, width, speed, phase, + modulation='triangular', interpolation='linear') + self.assert_sox_effect( + result, path, ['flanger', delay, depth, regen, width, speed, 'triangle', phase, 'linear']) + + def test_flanger_triangle_quad(self): + delay = 0.8 + depth = 0.88 + regen = 3.0 + width = 0.4 + speed = 0.5 + phase = 40 + sample_rate = 8000 + + data, path = self.get_whitenoise(sample_rate) + result = F.flanger( + data, sample_rate, delay, depth, regen, width, speed, phase, + modulation='triangular', interpolation='quadratic') + self.assert_sox_effect( + result, path, ['flanger', delay, depth, regen, width, speed, 'triangle', phase, 'quadratic']) + + def test_flanger_sine_linear(self): + delay = 0.8 + depth = 0.88 + regen = 3.0 + width = 0.23 + speed = 1.3 + phase = 60 + sample_rate = 8000 + + data, path = self.get_whitenoise(sample_rate) + result = F.flanger( + data, sample_rate, delay, depth, regen, width, speed, phase, + modulation='sinusoidal', interpolation='linear') + self.assert_sox_effect( + result, path, ['flanger', delay, depth, regen, width, speed, 'sine', phase, 'linear']) + + def test_flanger_sine_quad(self): + delay = 0.9 + depth = 0.9 + regen = 4.0 + width = 0.23 + speed = 1.3 + phase = 25 + sample_rate = 8000 + + data, path = self.get_whitenoise(sample_rate) + result = F.flanger( + data, sample_rate, delay, depth, regen, width, speed, phase, + modulation='sinusoidal', interpolation='quadratic') + self.assert_sox_effect( + result, path, ['flanger', delay, depth, regen, width, speed, 'sine', phase, 'quadratic']) + + def test_equalizer(self): + center_freq = 300 + q = 0.707 + gain = 1 + sample_rate = 8000 + + data, path = self.get_whitenoise(sample_rate) + result = F.equalizer_biquad(data, 
sample_rate, center_freq, gain, q) + self.assert_sox_effect(result, path, ['equalizer', center_freq, q, gain]) + + def test_perf_biquad_filtering(self): + b0 = 0.4 + b1 = 0.2 + b2 = 0.9 + a0 = 0.7 + a1 = 0.2 + a2 = 0.6 + + data, path = self.get_whitenoise() + result = F.lfilter(data, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])) + self.assert_sox_effect(result, path, ['biquad', b0, b1, b2, a0, a1, a2]) diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py b/test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2971067dbe522fb9308be26f480412c4ea4db8d6 --- /dev/null +++ b/test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py @@ -0,0 +1,14 @@ +import torch + +from torchaudio_unittest.common_utils import PytorchTestCase +from .torchscript_consistency_impl import Functional, FunctionalFloat32Only + + +class TestFunctionalFloat32(Functional, FunctionalFloat32Only, PytorchTestCase): + dtype = torch.float32 + device = torch.device('cpu') + + +class TestFunctionalFloat64(Functional, PytorchTestCase): + dtype = torch.float64 + device = torch.device('cpu') diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py b/test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py new file mode 100644 index 0000000000000000000000000000000000000000..30e2b3969940900d8b3e38823fab8a10ee79313c --- /dev/null +++ b/test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py @@ -0,0 +1,16 @@ +import torch + +from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase +from .torchscript_consistency_impl import Functional, FunctionalFloat32Only + + +@skipIfNoCuda +class TestFunctionalFloat32(Functional, FunctionalFloat32Only, PytorchTestCase): + dtype = torch.float32 + device = torch.device('cuda') + + +@skipIfNoCuda +class TestFunctionalFloat64(Functional, PytorchTestCase): + dtype = torch.float64 + device = torch.device('cuda') diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..4097b20e34da27f6f3661ea393a630c9395558f1 --- /dev/null +++ b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py @@ -0,0 +1,718 @@ +"""Test suites for jit-ability and its numerical compatibility""" +import unittest + +import torch +import torchaudio.functional as F +from parameterized import parameterized + +from torchaudio_unittest import common_utils +from torchaudio_unittest.common_utils import ( + TempDirMixin, + TestBaseMixin, + skipIfRocm, + torch_script, +) + + +class Functional(TempDirMixin, TestBaseMixin): + """Implements test for `functional` module that are performed for different devices""" + def _assert_consistency(self, func, tensor, shape_only=False): + tensor = tensor.to(device=self.device, dtype=self.dtype) + ts_func = torch_script(func) + + torch.random.manual_seed(40) + output = func(tensor) + + torch.random.manual_seed(40) + ts_output = ts_func(tensor) + + if shape_only: + ts_output = ts_output.shape + output = output.shape + self.assertEqual(ts_output, output) + + def _assert_consistency_complex(self, func, tensor, test_pseudo_complex=False): + assert tensor.is_complex() + tensor = tensor.to(device=self.device, dtype=self.complex_dtype) + ts_func = torch_script(func) + + if test_pseudo_complex: + tensor = 
torch.view_as_real(tensor) + + torch.random.manual_seed(40) + output = func(tensor) + + torch.random.manual_seed(40) + ts_output = ts_func(tensor) + + self.assertEqual(ts_output, output) + + def test_spectrogram_complex(self): + def func(tensor): + n_fft = 400 + ws = 400 + hop = 200 + pad = 0 + window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype) + power = None + normalize = False + return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize) + + tensor = common_utils.get_whitenoise() + self._assert_consistency(func, tensor) + + def test_spectrogram_real(self): + def func(tensor): + n_fft = 400 + ws = 400 + hop = 200 + pad = 0 + window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype) + power = 2. + normalize = False + return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize, return_complex=False) + + tensor = common_utils.get_whitenoise() + self._assert_consistency(func, tensor) + + def test_inverse_spectrogram_complex(self): + def func(tensor): + length = 400 + n_fft = 400 + hop = 200 + ws = 400 + pad = 0 + window = torch.hann_window(ws, device=tensor.device, dtype=torch.float64) + normalize = False + return F.inverse_spectrogram(tensor, length, pad, window, n_fft, hop, ws, normalize) + + waveform = common_utils.get_whitenoise(sample_rate=8000, duration=0.05) + tensor = common_utils.get_spectrogram(waveform, n_fft=400, hop_length=200) + self._assert_consistency_complex(func, tensor) + + def test_inverse_spectrogram_real(self): + def func(tensor): + length = 400 + n_fft = 400 + hop = 200 + ws = 400 + pad = 0 + window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype) + normalize = False + return F.inverse_spectrogram(tensor, length, pad, window, n_fft, hop, ws, normalize) + + waveform = common_utils.get_whitenoise(sample_rate=8000, duration=0.05) + tensor = common_utils.get_spectrogram(waveform, n_fft=400, hop_length=200) + tensor = torch.view_as_real(tensor) + self._assert_consistency(func, tensor) + + @skipIfRocm + def test_griffinlim(self): + def func(tensor): + n_fft = 400 + ws = 400 + hop = 200 + window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype) + power = 2. 
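# [Editorial note] Every test in this class funnels through `_assert_consistency`, which compares eager and TorchScript execution of a small wrapper function. A minimal sketch of that pattern, assuming `torch_script` is essentially a thin wrapper around `torch.jit.script`:
import torch
import torchaudio.functional as F

def apply_gain(waveform: torch.Tensor) -> torch.Tensor:
    # wrap the functional call so it can be scripted as a standalone function
    return F.gain(waveform, 2.0)

scripted = torch.jit.script(apply_gain)
x = torch.rand(1, 1000)
torch.testing.assert_close(apply_gain(x), scripted(x))  # eager and scripted outputs should match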
+ momentum = 0.99 + n_iter = 32 + length = 1000 + rand_int = False + return F.griffinlim(tensor, window, n_fft, hop, ws, power, n_iter, momentum, length, rand_int) + + tensor = torch.rand((1, 201, 6)) + self._assert_consistency(func, tensor) + + def test_compute_deltas(self): + def func(tensor): + win_length = 2 * 7 + 1 + return F.compute_deltas(tensor, win_length=win_length) + + channel = 13 + n_mfcc = channel * 3 + time = 1021 + tensor = torch.randn(channel, n_mfcc, time) + self._assert_consistency(func, tensor) + + def test_detect_pitch_frequency(self): + waveform = common_utils.get_sinusoid(sample_rate=44100) + + def func(tensor): + sample_rate = 44100 + return F.detect_pitch_frequency(tensor, sample_rate) + + self._assert_consistency(func, waveform) + + def test_melscale_fbanks(self): + if self.device != torch.device('cpu'): + raise unittest.SkipTest('No need to perform test on device other than CPU') + + def func(_): + n_stft = 100 + f_min = 0.0 + f_max = 20.0 + n_mels = 10 + sample_rate = 16000 + norm = "slaney" + return F.melscale_fbanks(n_stft, f_min, f_max, n_mels, sample_rate, norm) + + dummy = torch.zeros(1, 1) + self._assert_consistency(func, dummy) + + def test_linear_fbanks(self): + if self.device != torch.device('cpu'): + raise unittest.SkipTest('No need to perform test on device other than CPU') + + def func(_): + n_stft = 100 + f_min = 0.0 + f_max = 20.0 + n_filter = 10 + sample_rate = 16000 + return F.linear_fbanks(n_stft, f_min, f_max, n_filter, sample_rate) + + dummy = torch.zeros(1, 1) + self._assert_consistency(func, dummy) + + def test_amplitude_to_DB(self): + def func(tensor): + multiplier = 10.0 + amin = 1e-10 + db_multiplier = 0.0 + top_db = 80.0 + return F.amplitude_to_DB(tensor, multiplier, amin, db_multiplier, top_db) + + tensor = torch.rand((6, 201)) + self._assert_consistency(func, tensor) + + def test_DB_to_amplitude(self): + def func(tensor): + ref = 1. + power = 1. + return F.DB_to_amplitude(tensor, ref, power) + + tensor = torch.rand((1, 100)) + self._assert_consistency(func, tensor) + + def test_create_dct(self): + if self.device != torch.device('cpu'): + raise unittest.SkipTest('No need to perform test on device other than CPU') + + def func(_): + n_mfcc = 40 + n_mels = 128 + norm = "ortho" + return F.create_dct(n_mfcc, n_mels, norm) + + dummy = torch.zeros(1, 1) + self._assert_consistency(func, dummy) + + def test_mu_law_encoding(self): + def func(tensor): + qc = 256 + return F.mu_law_encoding(tensor, qc) + + waveform = common_utils.get_whitenoise() + self._assert_consistency(func, waveform) + + def test_mu_law_decoding(self): + def func(tensor): + qc = 256 + return F.mu_law_decoding(tensor, qc) + + tensor = torch.rand((1, 10)) + self._assert_consistency(func, tensor) + + def test_complex_norm(self): + def func(tensor): + power = 2. + return F.complex_norm(tensor, power) + + tensor = torch.randn(1, 2, 1025, 400, 2) + self._assert_consistency(func, tensor) + + def test_mask_along_axis(self): + def func(tensor): + mask_param = 100 + mask_value = 30. + axis = 2 + return F.mask_along_axis(tensor, mask_param, mask_value, axis) + + tensor = torch.randn(2, 1025, 400) + self._assert_consistency(func, tensor) + + def test_mask_along_axis_iid(self): + def func(tensor): + mask_param = 100 + mask_value = 30. 
+ axis = 2 + return F.mask_along_axis_iid(tensor, mask_param, mask_value, axis) + + tensor = torch.randn(4, 2, 1025, 400) + self._assert_consistency(func, tensor) + + def test_gain(self): + def func(tensor): + gainDB = 2.0 + return F.gain(tensor, gainDB) + + tensor = torch.rand((1, 1000)) + self._assert_consistency(func, tensor) + + def test_dither_TPDF(self): + def func(tensor): + return F.dither(tensor, 'TPDF') + + tensor = common_utils.get_whitenoise(n_channels=2) + self._assert_consistency(func, tensor, shape_only=True) + + def test_dither_RPDF(self): + def func(tensor): + return F.dither(tensor, 'RPDF') + + tensor = common_utils.get_whitenoise(n_channels=2) + self._assert_consistency(func, tensor, shape_only=True) + + def test_dither_GPDF(self): + def func(tensor): + return F.dither(tensor, 'GPDF') + + tensor = common_utils.get_whitenoise(n_channels=2) + self._assert_consistency(func, tensor, shape_only=True) + + def test_dither_noise_shaping(self): + def func(tensor): + return F.dither(tensor, noise_shaping=True) + + tensor = common_utils.get_whitenoise(n_channels=2) + self._assert_consistency(func, tensor) + + def test_lfilter(self): + if self.dtype == torch.float64: + raise unittest.SkipTest("This test is known to fail for float64") + + waveform = common_utils.get_whitenoise() + + def func(tensor): + # Design an IIR lowpass filter using scipy.signal filter design + # https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign + # + # Example + # >>> from scipy.signal import iirdesign + # >>> b, a = iirdesign(0.2, 0.3, 1, 60) + b_coeffs = torch.tensor( + [ + 0.00299893, + -0.0051152, + 0.00841964, + -0.00747802, + 0.00841964, + -0.0051152, + 0.00299893, + ], + device=tensor.device, + dtype=tensor.dtype, + ) + a_coeffs = torch.tensor( + [ + 1.0, + -4.8155751, + 10.2217618, + -12.14481273, + 8.49018171, + -3.3066882, + 0.56088705, + ], + device=tensor.device, + dtype=tensor.dtype, + ) + return F.lfilter(tensor, a_coeffs, b_coeffs) + + self._assert_consistency(func, waveform) + + def test_filtfilt(self): + def func(tensor): + torch.manual_seed(296) + b_coeffs = torch.rand(4, device=tensor.device, dtype=tensor.dtype) + a_coeffs = torch.rand(4, device=tensor.device, dtype=tensor.dtype) + return F.filtfilt(tensor, a_coeffs, b_coeffs) + + waveform = common_utils.get_whitenoise(sample_rate=8000) + self._assert_consistency(func, waveform) + + def test_lowpass(self): + if self.dtype == torch.float64: + raise unittest.SkipTest("This test is known to fail for float64") + + waveform = common_utils.get_whitenoise(sample_rate=44100) + + def func(tensor): + sample_rate = 44100 + cutoff_freq = 3000. + return F.lowpass_biquad(tensor, sample_rate, cutoff_freq) + + self._assert_consistency(func, waveform) + + def test_highpass(self): + if self.dtype == torch.float64: + raise unittest.SkipTest("This test is known to fail for float64") + + waveform = common_utils.get_whitenoise(sample_rate=44100) + + def func(tensor): + sample_rate = 44100 + cutoff_freq = 2000. + return F.highpass_biquad(tensor, sample_rate, cutoff_freq) + + self._assert_consistency(func, waveform) + + def test_allpass(self): + if self.dtype == torch.float64: + raise unittest.SkipTest("This test is known to fail for float64") + + waveform = common_utils.get_whitenoise(sample_rate=44100) + + def func(tensor): + sample_rate = 44100 + central_freq = 1000. 
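# [Editorial note] The iirdesign comment above only documents where the hard-coded lfilter coefficients came from. A hedged sketch of cross-checking F.lfilter against scipy.signal.lfilter with freshly designed coefficients (scipy is an extra dependency not used by this test file, and the `clamp` keyword and tolerances are assumptions):
import torch
import torchaudio.functional as F
from scipy.signal import iirdesign, lfilter

b, a = iirdesign(0.2, 0.3, 1, 60)  # same design call as in the comment above
x = torch.rand(1, 8000, dtype=torch.float64)
hyp = F.lfilter(x, torch.tensor(a), torch.tensor(b), clamp=False)  # disable output clamping for the comparison
ref = torch.from_numpy(lfilter(b, a, x.numpy()))
torch.testing.assert_close(hyp, ref, atol=1e-6, rtol=1e-6)  # tolerances are illustrative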
+ q = 0.707 + return F.allpass_biquad(tensor, sample_rate, central_freq, q) + + self._assert_consistency(func, waveform) + + def test_bandpass_with_csg(self): + if self.dtype == torch.float64: + raise unittest.SkipTest("This test is known to fail for float64") + + waveform = common_utils.get_whitenoise(sample_rate=44100) + + def func(tensor): + sample_rate = 44100 + central_freq = 1000. + q = 0.707 + const_skirt_gain = True + return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain) + + self._assert_consistency(func, waveform) + + def test_bandpass_without_csg(self): + if self.dtype == torch.float64: + raise unittest.SkipTest("This test is known to fail for float64") + + waveform = common_utils.get_whitenoise(sample_rate=44100) + + def func(tensor): + sample_rate = 44100 + central_freq = 1000. + q = 0.707 + const_skirt_gain = False + return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain) + + self._assert_consistency(func, waveform) + + def test_bandreject(self): + if self.dtype == torch.float64: + raise unittest.SkipTest("This test is known to fail for float64") + + waveform = common_utils.get_whitenoise(sample_rate=44100) + + def func(tensor): + sample_rate = 44100 + central_freq = 1000. + q = 0.707 + return F.bandreject_biquad(tensor, sample_rate, central_freq, q) + + self._assert_consistency(func, waveform) + + def test_band_with_noise(self): + if self.dtype == torch.float64: + raise unittest.SkipTest("This test is known to fail for float64") + + waveform = common_utils.get_whitenoise(sample_rate=44100) + + def func(tensor): + sample_rate = 44100 + central_freq = 1000. + q = 0.707 + noise = True + return F.band_biquad(tensor, sample_rate, central_freq, q, noise) + + self._assert_consistency(func, waveform) + + def test_band_without_noise(self): + if self.dtype == torch.float64: + raise unittest.SkipTest("This test is known to fail for float64") + + waveform = common_utils.get_whitenoise(sample_rate=44100) + + def func(tensor): + sample_rate = 44100 + central_freq = 1000. + q = 0.707 + noise = False + return F.band_biquad(tensor, sample_rate, central_freq, q, noise) + + self._assert_consistency(func, waveform) + + def test_treble(self): + if self.dtype == torch.float64: + raise unittest.SkipTest("This test is known to fail for float64") + + waveform = common_utils.get_whitenoise(sample_rate=44100) + + def func(tensor): + sample_rate = 44100 + gain = 40. + central_freq = 1000. + q = 0.707 + return F.treble_biquad(tensor, sample_rate, gain, central_freq, q) + + self._assert_consistency(func, waveform) + + def test_bass(self): + if self.dtype == torch.float64: + raise unittest.SkipTest("This test is known to fail for float64") + + waveform = common_utils.get_whitenoise(sample_rate=44100) + + def func(tensor): + sample_rate = 44100 + gain = 40. + central_freq = 1000.
+ q = 0.707 + return F.bass_biquad(tensor, sample_rate, gain, central_freq, q) + + self._assert_consistency(func, waveform) + + def test_deemph(self): + if self.dtype == torch.float64: + raise unittest.SkipTest("This test is known to fail for float64") + + waveform = common_utils.get_whitenoise(sample_rate=44100) + + def func(tensor): + sample_rate = 44100 + return F.deemph_biquad(tensor, sample_rate) + + self._assert_consistency(func, waveform) + + def test_riaa(self): + if self.dtype == torch.float64: + raise unittest.SkipTest("This test is known to fail for float64") + + waveform = common_utils.get_whitenoise(sample_rate=44100) + + def func(tensor): + sample_rate = 44100 + return F.riaa_biquad(tensor, sample_rate) + + self._assert_consistency(func, waveform) + + def test_equalizer(self): + if self.dtype == torch.float64: + raise unittest.SkipTest("This test is known to fail for float64") + + waveform = common_utils.get_whitenoise(sample_rate=44100) + + def func(tensor): + sample_rate = 44100 + center_freq = 300. + gain = 1. + q = 0.707 + return F.equalizer_biquad(tensor, sample_rate, center_freq, gain, q) + + self._assert_consistency(func, waveform) + + def test_perf_biquad_filtering(self): + if self.dtype == torch.float64: + raise unittest.SkipTest("This test is known to fail for float64") + + waveform = common_utils.get_whitenoise() + + def func(tensor): + a = torch.tensor([0.7, 0.2, 0.6], device=tensor.device, dtype=tensor.dtype) + b = torch.tensor([0.4, 0.2, 0.9], device=tensor.device, dtype=tensor.dtype) + return F.lfilter(tensor, a, b) + + self._assert_consistency(func, waveform) + + def test_sliding_window_cmn(self): + def func(tensor): + cmn_window = 600 + min_cmn_window = 100 + center = False + norm_vars = False + a = torch.tensor( + [ + [ + -1.915875792503357, + 1.147700309753418 + ], + [ + 1.8242558240890503, + 1.3869990110397339 + ] + ], + device=tensor.device, + dtype=tensor.dtype + ) + return F.sliding_window_cmn(a, cmn_window, min_cmn_window, center, norm_vars) + b = torch.tensor( + [ + [ + -1.8701, + -0.1196 + ], + [ + 1.8701, + 0.1196 + ] + ] + ) + self._assert_consistency(func, b) + + def test_contrast(self): + waveform = common_utils.get_whitenoise() + + def func(tensor): + enhancement_amount = 80. + return F.contrast(tensor, enhancement_amount) + + self._assert_consistency(func, waveform) + + def test_dcshift(self): + waveform = common_utils.get_whitenoise() + + def func(tensor): + shift = 0.5 + limiter_gain = 0.05 + return F.dcshift(tensor, shift, limiter_gain) + + self._assert_consistency(func, waveform) + + def test_overdrive(self): + waveform = common_utils.get_whitenoise() + + def func(tensor): + gain = 30. + colour = 50. + return F.overdrive(tensor, gain, colour) + + self._assert_consistency(func, waveform) + + def test_phaser(self): + waveform = common_utils.get_whitenoise(sample_rate=44100) + + def func(tensor): + gain_in = 0.5 + gain_out = 0.8 + delay_ms = 2.0 + decay = 0.4 + speed = 0.5 + sample_rate = 44100 + return F.phaser(tensor, sample_rate, gain_in, gain_out, delay_ms, decay, speed, sinusoidal=True) + + self._assert_consistency(func, waveform) + + def test_flanger(self): + torch.random.manual_seed(40) + waveform = torch.rand(2, 100) - 0.5 + + def func(tensor): + delay = 0.8 + depth = 0.88 + regen = 3.0 + width = 0.23 + speed = 1.3 + phase = 60. 
+ sample_rate = 44100 + return F.flanger(tensor, sample_rate, delay, depth, regen, width, speed, + phase, modulation='sinusoidal', interpolation='linear') + + self._assert_consistency(func, waveform) + + def test_spectral_centroid(self): + + def func(tensor): + sample_rate = 44100 + n_fft = 400 + ws = 400 + hop = 200 + pad = 0 + window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype) + return F.spectral_centroid(tensor, sample_rate, pad, window, n_fft, hop, ws) + + tensor = common_utils.get_whitenoise(sample_rate=44100) + self._assert_consistency(func, tensor) + + @common_utils.skipIfNoKaldi + def test_compute_kaldi_pitch(self): + if self.dtype != torch.float32 or self.device != torch.device('cpu'): + raise unittest.SkipTest("Only float32, cpu is supported.") + + def func(tensor): + sample_rate: float = 44100. + return F.compute_kaldi_pitch(tensor, sample_rate) + + tensor = common_utils.get_whitenoise(sample_rate=44100) + self._assert_consistency(func, tensor) + + def test_resample_sinc(self): + def func(tensor): + sr1, sr2 = 16000, 8000 + return F.resample(tensor, sr1, sr2, resampling_method="sinc_interpolation") + + tensor = common_utils.get_whitenoise(sample_rate=16000) + self._assert_consistency(func, tensor) + + def test_resample_kaiser(self): + def func(tensor): + sr1, sr2 = 16000, 8000 + return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window") + + def func_beta(tensor): + sr1, sr2 = 16000, 8000 + beta = 6. + return F.resample(tensor, sr1, sr2, resampling_method="kaiser_window", beta=beta) + + tensor = common_utils.get_whitenoise(sample_rate=16000) + self._assert_consistency(func, tensor) + self._assert_consistency(func_beta, tensor) + + @parameterized.expand([(True, ), (False, )]) + def test_phase_vocoder(self, test_paseudo_complex): + def func(tensor): + is_complex = tensor.is_complex() + + n_freq = tensor.size(-2 if is_complex else -3) + rate = 0.5 + hop_length = 256 + phase_advance = torch.linspace( + 0, + 3.14 * hop_length, + n_freq, + dtype=(torch.real(tensor) if is_complex else tensor).dtype, + device=tensor.device, + )[..., None] + return F.phase_vocoder(tensor, rate, phase_advance) + + tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2)) + self._assert_consistency_complex(func, tensor, test_paseudo_complex) + + +class FunctionalFloat32Only(TestBaseMixin): + def test_rnnt_loss(self): + def func(tensor): + targets = torch.tensor([[1, 2]], device=tensor.device, dtype=torch.int32) + logit_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32) + target_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32) + return F.rnnt_loss(tensor, targets, logit_lengths, target_lengths) + + logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.6, 0.1, 0.1], + [0.1, 0.1, 0.2, 0.8, 0.1]], + [[0.1, 0.6, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.2, 0.1, 0.1], + [0.7, 0.1, 0.2, 0.1, 0.1]]]]) + tensor = logits.to(device=self.device, dtype=torch.float32) + self._assert_consistency(func, tensor) diff --git a/test/torchaudio_unittest/kaldi_io_test.py b/test/torchaudio_unittest/kaldi_io_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4ceeeebbc5de895a1ab3a54d66767839327aae5e --- /dev/null +++ b/test/torchaudio_unittest/kaldi_io_test.py @@ -0,0 +1,33 @@ +import torch +import torchaudio.kaldi_io as kio + +from torchaudio_unittest import common_utils + + +class Test_KaldiIO(common_utils.TorchaudioTestCase): + data1 = [[1, 2, 3], [11, 12, 13], [21, 22, 23]] + data2 = [[31, 32, 33], [41, 42, 43], [51, 52, 53]] + 
+ def _test_helper(self, file_name, expected_data, fn, expected_dtype): + """ Takes a file_name to the input data and a function fn to extract the + data. It compares the extracted data to the expected_data. The expected_dtype + will be used to check that the extracted data is of the right type. + """ + test_filepath = common_utils.get_asset_path(file_name) + expected_output = {'key' + str(idx + 1): torch.tensor(val, dtype=expected_dtype) + for idx, val in enumerate(expected_data)} + + for key, vec in fn(test_filepath): + self.assertTrue(key in expected_output) + self.assertTrue(isinstance(vec, torch.Tensor)) + self.assertEqual(vec.dtype, expected_dtype) + self.assertTrue(torch.all(torch.eq(vec, expected_output[key]))) + + def test_read_vec_int_ark(self): + self._test_helper("vec_int.ark", self.data1, kio.read_vec_int_ark, torch.int32) + + def test_read_vec_flt_ark(self): + self._test_helper("vec_flt.ark", self.data1, kio.read_vec_flt_ark, torch.float32) + + def test_read_mat_ark(self): + self._test_helper("mat.ark", [self.data1, self.data2], kio.read_mat_ark, torch.float32) diff --git a/test/torchaudio_unittest/models/__init__.py b/test/torchaudio_unittest/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/torchaudio_unittest/models/models_test.py b/test/torchaudio_unittest/models/models_test.py new file mode 100644 index 0000000000000000000000000000000000000000..18f3085e53ebd5bc84b7c348e1e638a99848d9c1 --- /dev/null +++ b/test/torchaudio_unittest/models/models_test.py @@ -0,0 +1,248 @@ +import itertools +from collections import namedtuple + +import torch +from parameterized import parameterized +from torchaudio.models import ConvTasNet, DeepSpeech, Wav2Letter, WaveRNN +from torchaudio.models.wavernn import MelResNet, UpsampleNetwork +from torchaudio_unittest import common_utils +from torchaudio_unittest.common_utils import torch_script + + +class TestWav2Letter(common_utils.TorchaudioTestCase): + + def test_waveform(self): + batch_size = 2 + num_features = 1 + num_classes = 40 + input_length = 320 + + model = Wav2Letter(num_classes=num_classes, num_features=num_features) + + x = torch.rand(batch_size, num_features, input_length) + out = model(x) + + assert out.size() == (batch_size, num_classes, 2) + + def test_mfcc(self): + batch_size = 2 + num_features = 13 + num_classes = 40 + input_length = 2 + + model = Wav2Letter(num_classes=num_classes, input_type="mfcc", num_features=num_features) + + x = torch.rand(batch_size, num_features, input_length) + out = model(x) + + assert out.size() == (batch_size, num_classes, 2) + + +class TestMelResNet(common_utils.TorchaudioTestCase): + + def test_waveform(self): + """Validate the output dimensions of a MelResNet block. + """ + + n_batch = 2 + n_time = 200 + n_freq = 100 + n_output = 128 + n_res_block = 10 + n_hidden = 128 + kernel_size = 5 + + model = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size) + + x = torch.rand(n_batch, n_freq, n_time) + out = model(x) + + assert out.size() == (n_batch, n_output, n_time - kernel_size + 1) + + +class TestUpsampleNetwork(common_utils.TorchaudioTestCase): + + def test_waveform(self): + """Validate the output dimensions of a UpsampleNetwork block. 
+ """ + + upsample_scales = [5, 5, 8] + n_batch = 2 + n_time = 200 + n_freq = 100 + n_output = 256 + n_res_block = 10 + n_hidden = 128 + kernel_size = 5 + + total_scale = 1 + for upsample_scale in upsample_scales: + total_scale *= upsample_scale + + model = UpsampleNetwork(upsample_scales, + n_res_block, + n_freq, + n_hidden, + n_output, + kernel_size) + + x = torch.rand(n_batch, n_freq, n_time) + out1, out2 = model(x) + + assert out1.size() == (n_batch, n_freq, total_scale * (n_time - kernel_size + 1)) + assert out2.size() == (n_batch, n_output, total_scale * (n_time - kernel_size + 1)) + + +class TestWaveRNN(common_utils.TorchaudioTestCase): + + def test_waveform(self): + """Validate the output dimensions of a WaveRNN model. + """ + + upsample_scales = [5, 5, 8] + n_rnn = 512 + n_fc = 512 + n_classes = 512 + hop_length = 200 + n_batch = 2 + n_time = 200 + n_freq = 100 + n_output = 256 + n_res_block = 10 + n_hidden = 128 + kernel_size = 5 + + model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block, + n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output) + + x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1)) + mels = torch.rand(n_batch, 1, n_freq, n_time) + out = model(x, mels) + + assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), n_classes) + + def test_infer_waveform(self): + """Validate the output dimensions of a WaveRNN model's infer method. + """ + + upsample_scales = [5, 5, 8] + n_rnn = 128 + n_fc = 128 + n_classes = 128 + hop_length = 200 + n_batch = 2 + n_time = 50 + n_freq = 25 + n_output = 64 + n_res_block = 2 + n_hidden = 32 + kernel_size = 5 + + model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block, + n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output) + + x = torch.rand(n_batch, n_freq, n_time) + lengths = torch.tensor([n_time, n_time // 2]) + out, waveform_lengths = model.infer(x, lengths) + + assert out.size() == (n_batch, 1, hop_length * n_time) + assert waveform_lengths[0] == hop_length * n_time + assert waveform_lengths[1] == hop_length * n_time // 2 + + def test_torchscript_infer(self): + """Scripted model outputs the same as eager mode""" + + upsample_scales = [5, 5, 8] + n_rnn = 128 + n_fc = 128 + n_classes = 128 + hop_length = 200 + n_batch = 2 + n_time = 50 + n_freq = 25 + n_output = 64 + n_res_block = 2 + n_hidden = 32 + kernel_size = 5 + + model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block, + n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output) + model.eval() + x = torch.rand(n_batch, n_freq, n_time) + torch.random.manual_seed(0) + out_eager = model.infer(x) + torch.random.manual_seed(0) + out_script = torch_script(model).infer(x) + self.assertEqual(out_eager, out_script) + + +_ConvTasNetParams = namedtuple( + '_ConvTasNetParams', + [ + 'enc_num_feats', + 'enc_kernel_size', + 'msk_num_feats', + 'msk_num_hidden_feats', + 'msk_kernel_size', + 'msk_num_layers', + 'msk_num_stacks', + ] +) + + +class TestConvTasNet(common_utils.TorchaudioTestCase): + @parameterized.expand(list(itertools.product( + [2, 3], + [ + _ConvTasNetParams(128, 40, 128, 256, 3, 7, 2), + _ConvTasNetParams(256, 40, 128, 256, 3, 7, 2), + _ConvTasNetParams(512, 40, 128, 256, 3, 7, 2), + _ConvTasNetParams(512, 40, 128, 256, 3, 7, 2), + _ConvTasNetParams(512, 40, 128, 512, 3, 7, 2), + _ConvTasNetParams(512, 40, 128, 512, 3, 7, 2), + _ConvTasNetParams(512, 40, 256, 256, 3, 7, 2), + _ConvTasNetParams(512, 40, 256, 512, 3, 7, 2), + _ConvTasNetParams(512, 40, 256, 512, 3, 7, 2), + _ConvTasNetParams(512, 40, 128, 512, 3, 
6, 4), + _ConvTasNetParams(512, 40, 128, 512, 3, 4, 6), + _ConvTasNetParams(512, 40, 128, 512, 3, 8, 3), + _ConvTasNetParams(512, 32, 128, 512, 3, 8, 3), + _ConvTasNetParams(512, 16, 128, 512, 3, 8, 3), + ], + ))) + def test_paper_configuration(self, num_sources, model_params): + """ConvTasNet model works on the valid configurations in the paper""" + batch_size = 32 + num_frames = 8000 + + model = ConvTasNet( + num_sources=num_sources, + enc_kernel_size=model_params.enc_kernel_size, + enc_num_feats=model_params.enc_num_feats, + msk_kernel_size=model_params.msk_kernel_size, + msk_num_feats=model_params.msk_num_feats, + msk_num_hidden_feats=model_params.msk_num_hidden_feats, + msk_num_layers=model_params.msk_num_layers, + msk_num_stacks=model_params.msk_num_stacks, + ) + tensor = torch.rand(batch_size, 1, num_frames) + output = model(tensor) + + assert output.shape == (batch_size, num_sources, num_frames) + + +class TestDeepSpeech(common_utils.TorchaudioTestCase): + + def test_deepspeech(self): + n_batch = 2 + n_feature = 1 + n_channel = 1 + n_class = 40 + n_time = 320 + + model = DeepSpeech(n_feature=n_feature, n_class=n_class) + + x = torch.rand(n_batch, n_channel, n_time, n_feature) + out = model(x) + + assert out.size() == (n_batch, n_time, n_class) diff --git a/test/torchaudio_unittest/models/tacotron2/__init__.py b/test/torchaudio_unittest/models/tacotron2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/torchaudio_unittest/models/tacotron2/model_test_cpu_test.py b/test/torchaudio_unittest/models/tacotron2/model_test_cpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..612699b6e1deb9ae60f8fdfcff95c2231f88f186 --- /dev/null +++ b/test/torchaudio_unittest/models/tacotron2/model_test_cpu_test.py @@ -0,0 +1,23 @@ +import torch + +from torchaudio_unittest.common_utils import PytorchTestCase +from .model_test_impl import ( + Tacotron2EncoderTests, + Tacotron2DecoderTests, + Tacotron2Tests, +) + + +class TestTacotron2EncoderFloat32CPU(Tacotron2EncoderTests, PytorchTestCase): + dtype = torch.float32 + device = torch.device("cpu") + + +class TestTacotron2DecoderFloat32CPU(Tacotron2DecoderTests, PytorchTestCase): + dtype = torch.float32 + device = torch.device("cpu") + + +class TestTacotron2Float32CPU(Tacotron2Tests, PytorchTestCase): + dtype = torch.float32 + device = torch.device("cpu") diff --git a/test/torchaudio_unittest/models/tacotron2/model_test_gpu_test.py b/test/torchaudio_unittest/models/tacotron2/model_test_gpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7a6fdd114216d56520d85207488306cb2b1cdf23 --- /dev/null +++ b/test/torchaudio_unittest/models/tacotron2/model_test_gpu_test.py @@ -0,0 +1,26 @@ +import torch + +from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase +from .model_test_impl import ( + Tacotron2EncoderTests, + Tacotron2DecoderTests, + Tacotron2Tests, +) + + +@skipIfNoCuda +class TestTacotron2EncoderFloat32CUDA(Tacotron2EncoderTests, PytorchTestCase): + dtype = torch.float32 + device = torch.device("cuda") + + +@skipIfNoCuda +class TestTacotron2DecoderFloat32CUDA(Tacotron2DecoderTests, PytorchTestCase): + dtype = torch.float32 + device = torch.device("cuda") + + +@skipIfNoCuda +class TestTacotron2Float32CUDA(Tacotron2Tests, PytorchTestCase): + dtype = torch.float32 + device = torch.device("cuda") diff --git a/test/torchaudio_unittest/models/tacotron2/model_test_impl.py 
b/test/torchaudio_unittest/models/tacotron2/model_test_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..fb43a1403646c9e7ea9b7e9ed5e2d5ce7ff1fc77 --- /dev/null +++ b/test/torchaudio_unittest/models/tacotron2/model_test_impl.py @@ -0,0 +1,381 @@ +from typing import Tuple +import torch +from torch import Tensor +from torchaudio.models import Tacotron2 +from torchaudio.models.tacotron2 import _Encoder, _Decoder +from torchaudio_unittest.common_utils import TestBaseMixin, torch_script + + +class Tacotron2InferenceWrapper(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, text: Tensor, text_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + return self.model.infer(text, text_lengths) + + +class Tacotron2DecoderInferenceWrapper(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, memory: Tensor, memory_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + return self.model.infer(memory, memory_lengths) + + +class TorchscriptConsistencyMixin(TestBaseMixin): + r"""Mixin to provide easy access to torchscript consistency assertions""" + + def _assert_torchscript_consistency(self, model, tensors): + ts_func = torch_script(model) + + torch.random.manual_seed(40) + output = model(*tensors) + + torch.random.manual_seed(40) + ts_output = ts_func(*tensors) + + self.assertEqual(ts_output, output) + + +class Tacotron2EncoderTests(TorchscriptConsistencyMixin): + + def test_tacotron2_torchscript_consistency(self): + r"""Validate the torchscript consistency of an Encoder.""" + n_batch, n_seq, encoder_embedding_dim = 16, 64, 512 + model = _Encoder(encoder_embedding_dim=encoder_embedding_dim, + encoder_n_convolution=3, + encoder_kernel_size=5).to(self.device).eval() + + x = torch.rand( + n_batch, encoder_embedding_dim, n_seq, device=self.device, dtype=self.dtype + ) + input_lengths = ( + torch.ones(n_batch, device=self.device, dtype=torch.int32) * n_seq + ) + + self._assert_torchscript_consistency(model, (x, input_lengths)) + + def test_encoder_output_shape(self): + r"""Feed tensors with specific shape to Tacotron2 Encoder and validate + that it outputs with a tensor with expected shape.
+ """ + n_batch, n_seq, encoder_embedding_dim = 16, 64, 512 + model = _Encoder(encoder_embedding_dim=encoder_embedding_dim, + encoder_n_convolution=3, + encoder_kernel_size=5).to(self.device).eval() + + x = torch.rand( + n_batch, encoder_embedding_dim, n_seq, device=self.device, dtype=self.dtype + ) + input_lengths = ( + torch.ones(n_batch, device=self.device, dtype=torch.int32) * n_seq + ) + out = model(x, input_lengths) + + assert out.size() == (n_batch, n_seq, encoder_embedding_dim) + + +def _get_decoder_model(n_mels=80, encoder_embedding_dim=512, + decoder_max_step=2000, gate_threshold=0.5): + model = _Decoder( + n_mels=n_mels, + n_frames_per_step=1, + encoder_embedding_dim=encoder_embedding_dim, + decoder_rnn_dim=1024, + decoder_max_step=decoder_max_step, + decoder_dropout=0.1, + decoder_early_stopping=True, + attention_rnn_dim=1024, + attention_hidden_dim=128, + attention_location_n_filter=32, + attention_location_kernel_size=31, + attention_dropout=0.1, + prenet_dim=256, + gate_threshold=gate_threshold, + ) + return model + + +class Tacotron2DecoderTests(TorchscriptConsistencyMixin): + + def test_decoder_torchscript_consistency(self): + r"""Validate the torchscript consistency of a Decoder.""" + n_batch = 16 + n_mels = 80 + n_seq = 200 + encoder_embedding_dim = 256 + n_time_steps = 150 + + model = _get_decoder_model(n_mels=n_mels, encoder_embedding_dim=encoder_embedding_dim) + model = model.to(self.device).eval() + + memory = torch.rand( + n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device + ) + decoder_inputs = torch.rand( + n_batch, n_mels, n_time_steps, dtype=self.dtype, device=self.device + ) + memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device) + + self._assert_torchscript_consistency( + model, (memory, decoder_inputs, memory_lengths) + ) + + def test_decoder_output_shape(self): + r"""Feed tensors with specific shape to Tacotron2 Decoder and validate + that it outputs with a tensor with expected shape. 
+ """ + n_batch = 16 + n_mels = 80 + n_seq = 200 + encoder_embedding_dim = 256 + n_time_steps = 150 + + model = _get_decoder_model(n_mels=n_mels, encoder_embedding_dim=encoder_embedding_dim) + model = model.to(self.device).eval() + + memory = torch.rand( + n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device + ) + decoder_inputs = torch.rand( + n_batch, n_mels, n_time_steps, dtype=self.dtype, device=self.device + ) + memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device) + + mel_specgram, gate_outputs, alignments = model( + memory, decoder_inputs, memory_lengths + ) + + assert mel_specgram.size() == (n_batch, n_mels, n_time_steps) + assert gate_outputs.size() == (n_batch, n_time_steps) + assert alignments.size() == (n_batch, n_time_steps, n_seq) + + def test_decoder_inference_torchscript_consistency(self): + r"""Validate the torchscript consistency of a Decoder.""" + n_batch = 16 + n_mels = 80 + n_seq = 200 + encoder_embedding_dim = 256 + decoder_max_step = 300 # make inference more efficient + gate_threshold = 0.505 # make inference more efficient + + model = _get_decoder_model( + n_mels=n_mels, + encoder_embedding_dim=encoder_embedding_dim, + decoder_max_step=decoder_max_step, + gate_threshold=gate_threshold, + ) + model = model.to(self.device).eval() + + memory = torch.rand( + n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device + ) + memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device) + + model_wrapper = Tacotron2DecoderInferenceWrapper(model) + + self._assert_torchscript_consistency(model_wrapper, (memory, memory_lengths)) + + def test_decoder_inference_output_shape(self): + r"""Validate the torchscript consistency of a Decoder.""" + n_batch = 16 + n_mels = 80 + n_seq = 200 + encoder_embedding_dim = 256 + decoder_max_step = 300 # make inference more efficient + gate_threshold = 0.505 # if set to 0.5, the model will only run one step + + model = _get_decoder_model( + n_mels=n_mels, + encoder_embedding_dim=encoder_embedding_dim, + decoder_max_step=decoder_max_step, + gate_threshold=gate_threshold, + ) + model = model.to(self.device).eval() + + memory = torch.rand( + n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device + ) + memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device) + + mel_specgram, mel_specgram_lengths, gate_outputs, alignments = model.infer( + memory, memory_lengths + ) + + assert len(mel_specgram.size()) == 3 + assert mel_specgram.size()[:-1] == (n_batch, n_mels, ) + assert mel_specgram.size()[2] == mel_specgram_lengths.max().item() + assert len(mel_specgram_lengths.size()) == 1 + assert mel_specgram_lengths.size()[0] == n_batch + assert mel_specgram_lengths.max().item() <= model.decoder_max_step + assert len(gate_outputs.size()) == 2 + assert gate_outputs.size()[0] == n_batch + assert gate_outputs.size()[1] == mel_specgram_lengths.max().item() + assert len(alignments.size()) == 2 + assert alignments.size()[0] == n_seq + assert alignments.size()[1] == mel_specgram_lengths.max().item() * n_batch + + +def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5): + return Tacotron2( + mask_padding=False, + n_mels=n_mels, + n_symbol=148, + n_frames_per_step=1, + symbol_embedding_dim=512, + encoder_embedding_dim=512, + encoder_n_convolution=3, + encoder_kernel_size=5, + decoder_rnn_dim=1024, + decoder_max_step=decoder_max_step, + decoder_dropout=0.1, + decoder_early_stopping=True, + attention_rnn_dim=1024, + 
attention_hidden_dim=128, + attention_location_n_filter=32, + attention_location_kernel_size=31, + attention_dropout=0.1, + prenet_dim=256, + postnet_n_convolution=5, + postnet_kernel_size=5, + postnet_embedding_dim=512, + gate_threshold=gate_threshold, + ) + + +class Tacotron2Tests(TorchscriptConsistencyMixin): + + def _get_inputs( + self, n_mels: int, n_batch: int, max_mel_specgram_length: int, max_text_length: int + ): + text = torch.randint( + 0, 148, (n_batch, max_text_length), dtype=torch.int32, device=self.device + ) + text_lengths = max_text_length * torch.ones( + (n_batch,), dtype=torch.int32, device=self.device + ) + mel_specgram = torch.rand( + n_batch, + n_mels, + max_mel_specgram_length, + dtype=self.dtype, + device=self.device, + ) + mel_specgram_lengths = max_mel_specgram_length * torch.ones( + (n_batch,), dtype=torch.int32, device=self.device + ) + return text, text_lengths, mel_specgram, mel_specgram_lengths + + def test_tacotron2_torchscript_consistency(self): + r"""Validate the torchscript consistency of a Tacotron2.""" + n_batch = 16 + n_mels = 80 + max_mel_specgram_length = 300 + max_text_length = 100 + + model = _get_tacotron2_model(n_mels).to(self.device).eval() + inputs = self._get_inputs( + n_mels, n_batch, max_mel_specgram_length, max_text_length + ) + + self._assert_torchscript_consistency(model, inputs) + + def test_tacotron2_output_shape(self): + r"""Feed tensors with specific shape to Tacotron2 and validate + that it outputs with a tensor with expected shape. + """ + n_batch = 16 + n_mels = 80 + max_mel_specgram_length = 300 + max_text_length = 100 + + model = _get_tacotron2_model(n_mels).to(self.device).eval() + inputs = self._get_inputs( + n_mels, n_batch, max_mel_specgram_length, max_text_length + ) + mel_out, mel_out_postnet, gate_outputs, alignments = model(*inputs) + + assert mel_out.size() == (n_batch, n_mels, max_mel_specgram_length) + assert mel_out_postnet.size() == (n_batch, n_mels, max_mel_specgram_length) + assert gate_outputs.size() == (n_batch, max_mel_specgram_length) + assert alignments.size() == (n_batch, max_mel_specgram_length, max_text_length) + + def test_tacotron2_backward(self): + r"""Make sure calling the backward function on Tacotron2's outputs does + not error out. 
Following: + https://github.com/pytorch/vision/blob/23b8760374a5aaed53c6e5fc83a7e83dbe3b85df/test/test_models.py#L255 + """ + n_batch = 16 + n_mels = 80 + max_mel_specgram_length = 300 + max_text_length = 100 + + model = _get_tacotron2_model(n_mels).to(self.device) + inputs = self._get_inputs( + n_mels, n_batch, max_mel_specgram_length, max_text_length + ) + mel_out, mel_out_postnet, gate_outputs, _ = model(*inputs) + + mel_out.sum().backward(retain_graph=True) + mel_out_postnet.sum().backward(retain_graph=True) + gate_outputs.sum().backward() + + def _get_inference_inputs(self, n_batch: int, max_text_length: int): + text = torch.randint( + 0, 148, (n_batch, max_text_length), dtype=torch.int32, device=self.device + ) + text_lengths = max_text_length * torch.ones( + (n_batch,), dtype=torch.int32, device=self.device + ) + return text, text_lengths + + def test_tacotron2_inference_torchscript_consistency(self): + r"""Validate the torchscript consistency of Tacotron2 inference function.""" + n_batch = 16 + n_mels = 40 + max_text_length = 100 + decoder_max_step = 200 # make inference more efficient + gate_threshold = 0.51 # if set to 0.5, the model will only run one step + + model = _get_tacotron2_model( + n_mels, decoder_max_step=decoder_max_step, gate_threshold=gate_threshold + ).to(self.device).eval() + inputs = self._get_inference_inputs(n_batch, max_text_length) + + model_wrapper = Tacotron2InferenceWrapper(model) + + self._assert_torchscript_consistency(model_wrapper, inputs) + + def test_tacotron2_inference_output_shape(self): + r"""Feed tensors with specific shape to Tacotron2 inference function and validate + that it outputs with a tensor with expected shape. + """ + n_batch = 16 + n_mels = 40 + max_text_length = 100 + decoder_max_step = 200 # make inference more efficient + gate_threshold = 0.51 # if set to 0.5, the model will only run one step + + model = _get_tacotron2_model( + n_mels, decoder_max_step=decoder_max_step, gate_threshold=gate_threshold + ).to(self.device).eval() + inputs = self._get_inference_inputs(n_batch, max_text_length) + + mel_out, mel_specgram_lengths, alignments = model.infer(*inputs) + + # There is no guarantee on exactly what max_mel_specgram_length should be + # We only know that it should be smaller than model.decoder.decoder_max_step + assert len(mel_out.size()) == 3 + assert mel_out.size()[:2] == (n_batch, n_mels, ) + assert mel_out.size()[2] == mel_specgram_lengths.max().item() + assert len(mel_specgram_lengths.size()) == 1 + assert mel_specgram_lengths.size()[0] == n_batch + assert mel_specgram_lengths.max().item() <= model.decoder.decoder_max_step + assert len(alignments.size()) == 3 + assert alignments.size()[0] == n_batch + assert alignments.size()[1] == mel_specgram_lengths.max().item() + assert alignments.size()[2] == max_text_length diff --git a/test/torchaudio_unittest/models/wav2vec2/__init__.py b/test/torchaudio_unittest/models/wav2vec2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/torchaudio_unittest/models/wav2vec2/fairseq_integration_test.py b/test/torchaudio_unittest/models/wav2vec2/fairseq_integration_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3e3b869a303c80563ddc473b1078866064b9674f --- /dev/null +++ b/test/torchaudio_unittest/models/wav2vec2/fairseq_integration_test.py @@ -0,0 +1,240 @@ +import json + +import torch +from torchaudio.models.wav2vec2 import ( + wav2vec2_base, + wav2vec2_large, + 
wav2vec2_large_lv60k, + hubert_base, + hubert_large, + hubert_xlarge, +) +from torchaudio.models.wav2vec2.utils import ( + import_fairseq_model, +) +from parameterized import parameterized + +from torchaudio_unittest.common_utils import ( + get_asset_path, + skipIfNoModule, + TorchaudioTestCase, +) + + +def _load_config(*paths): + with open(f'{get_asset_path("wav2vec2", "fairseq", *paths)}.json', 'r') as file_: + return json.load(file_) + + +def _name_func(testcase_func, i, param): + return f'{testcase_func.__name__}_{i}_{param[0][1].__name__}' + + +# Pretraining models +WAV2VEC2_BASE = _load_config('wav2vec_small') +WAV2VEC2_LARGE = _load_config('libri960_big') +WAV2VEC2_LARGE_LV60K = _load_config('wav2vec_vox_new') +WAV2VEC2_XLSR_53_56K = _load_config('xlsr_53_56k') +HUBERT_BASE = _load_config('hubert_base_ls960') +HUBERT_LARGE_LL60K = _load_config('hubert_large_ll60k') +HUBERT_XLARGE_LL60K = _load_config('hubert_xtralarge_ll60k') +# Finetuning models +WAV2VEC2_BASE_960H = _load_config('wav2vec_small_960h') +WAV2VEC2_LARGE_960H = _load_config('wav2vec_large_960h') +WAV2VEC2_LARGE_LV60K_960H = _load_config('wav2vec_large_lv60k_960h') +WAV2VEC2_LARGE_LV60K_SELF_960H = _load_config('wav2vec_large_lv60k_self_960h') +HUBERT_LARGE = _load_config('hubert_large_ll60k_finetune_ls960') +HUBERT_XLARGE = _load_config('hubert_xtralarge_ll60k_finetune_ls960') + + +# Config and corresponding factory functions +WAV2VEC2_PRETRAINING_CONFIGS = parameterized.expand([ + (WAV2VEC2_BASE, wav2vec2_base), + (WAV2VEC2_LARGE, wav2vec2_large), + (WAV2VEC2_LARGE_LV60K, wav2vec2_large_lv60k), + (WAV2VEC2_XLSR_53_56K, wav2vec2_large_lv60k), +], name_func=_name_func) +HUBERT_PRETRAINING_CONFIGS = parameterized.expand([ + (HUBERT_BASE, hubert_base), + (HUBERT_LARGE_LL60K, hubert_large), + (HUBERT_XLARGE_LL60K, hubert_xlarge), +], name_func=_name_func) +ALL_PRETRAINING_CONFIGS = parameterized.expand([ + (WAV2VEC2_BASE, wav2vec2_base), + (WAV2VEC2_LARGE, wav2vec2_large), + (WAV2VEC2_LARGE_LV60K, wav2vec2_large_lv60k), + (WAV2VEC2_XLSR_53_56K, wav2vec2_large_lv60k), + (HUBERT_BASE, hubert_base), + (HUBERT_LARGE_LL60K, hubert_large), + (HUBERT_XLARGE_LL60K, hubert_xlarge), +], name_func=_name_func) +FINETUNING_CONFIGS = parameterized.expand([ + (WAV2VEC2_BASE_960H, wav2vec2_base), + (WAV2VEC2_LARGE_960H, wav2vec2_large), + (WAV2VEC2_LARGE_LV60K_960H, wav2vec2_large_lv60k), + (WAV2VEC2_LARGE_LV60K_SELF_960H, wav2vec2_large_lv60k), + (HUBERT_LARGE, hubert_large), + (HUBERT_XLARGE, hubert_xlarge), +], name_func=_name_func) + + +@skipIfNoModule('fairseq') +class TestFairseqIntegration(TorchaudioTestCase): + """Test the process of importing the models from fairseq. + + Test methods in this test suite check the following things: + 1. Models loaded with fairseq can be imported. + 2. The same model can be recreated without fairseq.
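# [Editorial note] The factory functions listed above are also usable on their own, without fairseq installed. A minimal sketch (the aux_num_out value and input shapes are arbitrary examples mirroring the tests below):
import torch
from torchaudio.models import wav2vec2_base

model = wav2vec2_base(aux_num_out=28).eval()
waveforms = torch.randn(3, 1024)
lengths = torch.tensor([1024, 512, 256])
emissions, out_lengths = model(waveforms, lengths)  # per-frame emissions plus valid lengths per sample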
+ """ + def _get_model(self, config, num_out=None): + import copy + from omegaconf import OmegaConf + from fairseq.models.wav2vec.wav2vec2 import ( + Wav2Vec2Config, + Wav2Vec2Model, + ) + from fairseq.models.wav2vec.wav2vec2_asr import ( + Wav2VecEncoder, + Wav2Vec2CtcConfig, + ) + from fairseq.models.hubert.hubert_asr import ( + HubertCtcConfig, + HubertEncoder, + ) + from fairseq.models.hubert.hubert import ( + HubertModel, + HubertConfig, + ) + from fairseq.tasks.hubert_pretraining import HubertPretrainingConfig + + if config['_name'] == 'wav2vec_ctc': + config = copy.deepcopy(config) + config['w2v_args'] = OmegaConf.create(config['w2v_args']) + return Wav2VecEncoder(Wav2Vec2CtcConfig(**config), num_out) + if config['_name'] == 'wav2vec2': + return Wav2Vec2Model(Wav2Vec2Config(**config)) + if config['_name'] == 'hubert_ctc': + config = copy.deepcopy(config) + config['w2v_args'] = OmegaConf.create(config['w2v_args']) + ctc_cfg = HubertCtcConfig(**config) + return HubertEncoder(ctc_cfg, tgt_dict=range(num_out)) + if config['_name'] == 'hubert': + dicts = [list(range(i)) for i in config['num_classes']] + return HubertModel( + HubertConfig(**config['model']), + HubertPretrainingConfig(**config['task']), + dicts, + ) + raise ValueError(f'Unexpected configuration: {config["_name"]}') + + @WAV2VEC2_PRETRAINING_CONFIGS + def test_import_wave2vec2_pretraining_model(self, config, _): + """Wav2vec2 pretraining models from fairseq can be imported and yields the same results""" + batch_size, num_frames = 3, 1024 + + torch.manual_seed(0) + original = self._get_model(config).eval() + imported = import_fairseq_model(original).eval() + + x = torch.randn(batch_size, num_frames) + hyp, _ = imported.extract_features(x) + refs = original.extract_features(x, padding_mask=torch.zeros_like(x), layer=-1) + for i, (ref, _) in enumerate(refs['layer_results']): + self.assertEqual(hyp[i], ref.transpose(0, 1)) + + @HUBERT_PRETRAINING_CONFIGS + def test_import_hubert_pretraining_model(self, config, factory_func): + """HuBERT pretraining models from fairseq can be imported and yields the same results""" + batch_size, num_frames = 3, 1024 + + torch.manual_seed(0) + original = self._get_model(config).eval() + imported = import_fairseq_model(original).eval() + + x = torch.randn(batch_size, num_frames) + mask = torch.zeros_like(x) + hyp, _ = imported.extract_features(x) + + # check the last layer + ref, _ = original.extract_features(x, padding_mask=mask, output_layer=len(original.encoder.layers)) + atol = 3.0e-05 if factory_func is hubert_xlarge else 1.0e-5 + self.assertEqual(hyp[-1], ref, atol=atol, rtol=1.3e-6) + + # check the first layer + ref, _ = original.extract_features(x, padding_mask=mask, output_layer=1) + self.assertEqual(hyp[0], ref) + + @ALL_PRETRAINING_CONFIGS + def test_recreate_pretraining_model(self, config, factory_func): + """Imported pretraining models can be recreated via a factory function without fairseq.""" + batch_size, num_frames = 3, 1024 + + original = self._get_model(config).eval() + imported = import_fairseq_model(original).eval() + + reloaded = factory_func() + reloaded.load_state_dict(imported.state_dict()) + reloaded.eval() + + x = torch.randn(batch_size, num_frames) + lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ]) + # Without mask + ref, _ = imported(x) + hyp, _ = reloaded(x) + self.assertEqual(ref, hyp) + + # With mask + ref, ref_lengths = imported(x, lengths) + hyp, hyp_lengths = reloaded(x, lengths) + self.assertEqual(ref, hyp) + self.assertEqual(ref_lengths, 
hyp_lengths) + + @FINETUNING_CONFIGS + def test_import_finetuning_model(self, config, _): + """Finetuned wav2vec2 models from fairseq can be imported and yield the same results""" + num_out = 28 + batch_size, num_frames = 3, 1024 + + original = self._get_model(config, num_out).eval() + imported = import_fairseq_model(original).eval() + + # Without mask + x = torch.randn(batch_size, num_frames) + ref = original(x, torch.zeros_like(x))['encoder_out'].transpose(0, 1) + hyp, _ = imported(x) + self.assertEqual(ref, hyp) + + # With mask + lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ]) + mask = torch.arange(num_frames).expand(batch_size, num_frames) >= lengths[:, None] + ref = original(x, mask)['encoder_out'].transpose(0, 1) + hyp, output_lengths = imported(x, lengths) + for i, l in enumerate(output_lengths): + self.assertEqual(ref[i, :l, ...], hyp[i, :l, ...]) + + @FINETUNING_CONFIGS + def test_recreate_finetuning_model(self, config, factory_func): + """Imported finetuning models can be recreated via a factory function without fairseq.""" + num_out = 28 + batch_size, num_frames = 3, 1024 + + original = self._get_model(config, num_out).eval() + imported = import_fairseq_model(original).eval() + + reloaded = factory_func(aux_num_out=num_out) + reloaded.load_state_dict(imported.state_dict()) + reloaded.eval() + + # Without mask + torch.manual_seed(0) + x = torch.randn(batch_size, num_frames) + ref, _ = imported(x) + hyp, _ = reloaded(x) + self.assertEqual(ref, hyp) + + # With mask + lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ]) + ref, ref_lengths = imported(x, lengths) + hyp, hyp_lengths = reloaded(x, lengths) + self.assertEqual(ref, hyp) + self.assertEqual(ref_lengths, hyp_lengths) diff --git a/test/torchaudio_unittest/models/wav2vec2/huggingface_intergration_test.py b/test/torchaudio_unittest/models/wav2vec2/huggingface_intergration_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ff177b7cfae78a932762f1a017e8048a676935cc --- /dev/null +++ b/test/torchaudio_unittest/models/wav2vec2/huggingface_intergration_test.py @@ -0,0 +1,224 @@ +import json + +import torch +from torchaudio.models.wav2vec2 import ( + wav2vec2_base, + wav2vec2_large, + wav2vec2_large_lv60k, +) +from torchaudio.models.wav2vec2.utils import import_huggingface_model +from parameterized import parameterized + +from torchaudio_unittest.common_utils import ( + get_asset_path, + skipIfNoModule, + TorchaudioTestCase, +) + + +def _load_config(*paths): + with open(f'{get_asset_path("wav2vec2", "huggingface", *paths)}.json', 'r') as file_: + return json.load(file_) + + +def _name_func(testcase_func, i, param): + return f"{testcase_func.__name__}_{i}_{param[0][1].__name__}" + + +# Pretrained +HF_BASE = _load_config('facebook', 'wav2vec2-base') +HF_LARGE = _load_config('facebook', 'wav2vec2-large') +HF_LARGE_LV60 = _load_config('facebook', 'wav2vec2-large-lv60') +HF_LARGE_XLSR_53 = _load_config('facebook', 'wav2vec2-large-xlsr-53') +HF_BASE_10K_VOXPOPULI = _load_config('facebook', 'wav2vec2-base-10k-voxpopuli') +# Finetuned +HF_BASE_960H = _load_config('facebook', 'wav2vec2-base-960h') +HF_LARGE_960H = _load_config('facebook', 'wav2vec2-large-960h') +HF_LARGE_LV60_960H = _load_config('facebook', 'wav2vec2-large-960h-lv60') +HF_LARGE_LV60_SELF_960H = _load_config('facebook', 'wav2vec2-large-960h-lv60-self') +HF_LARGE_XLSR_DE = _load_config('facebook', 'wav2vec2-large-xlsr-53-german') + +# Config and corresponding factory functions +PRETRAIN_CONFIGS =
parameterized.expand([ + (HF_BASE, wav2vec2_base), + (HF_LARGE, wav2vec2_large), + (HF_LARGE_LV60, wav2vec2_large_lv60k), + (HF_LARGE_XLSR_53, wav2vec2_large_lv60k), + (HF_BASE_10K_VOXPOPULI, wav2vec2_base), +], name_func=_name_func) +FINETUNE_CONFIGS = parameterized.expand([ + (HF_BASE_960H, wav2vec2_base), + (HF_LARGE_960H, wav2vec2_large), + (HF_LARGE_LV60_960H, wav2vec2_large_lv60k), + (HF_LARGE_LV60_SELF_960H, wav2vec2_large_lv60k), + (HF_LARGE_XLSR_DE, wav2vec2_large_lv60k), +], name_func=_name_func) + + +@skipIfNoModule('transformers') +class TestHFIntegration(TorchaudioTestCase): + """Test the process of importing the models from Hugging Face Transformers. + + Test methods in this test suite check the following things: + 1. Models loaded with Hugging Face Transformers can be imported. + 2. The same model can be recreated without Hugging Face Transformers. + """ + def _get_model(self, config): + # Helper function to avoid importing transformers on module scope. + # Normally, we use `is_module_available` helper function to check if + # the library is available, and import it on module scope if available. + # However, somehow, once "transformers" is imported, `is_module_available` + # starts to fail. Therefore, we defer importing "transformers" until + # the actual tests are started. + from transformers.models.wav2vec2 import ( + Wav2Vec2Config, + Wav2Vec2Model, + Wav2Vec2ForCTC, + ) + if config['architectures'] == ['Wav2Vec2Model']: + return Wav2Vec2Model(Wav2Vec2Config(**config)) + if config['architectures'] == ['Wav2Vec2ForCTC']: + return Wav2Vec2ForCTC(Wav2Vec2Config(**config)) + raise ValueError(f'Unexpected arch: {config["architectures"]}') + + def _test_import_pretrain(self, original, imported, config): + torch.manual_seed(0) + # FeatureExtractor + x = torch.randn(3, 1024) + ref = original.feature_extractor(x).transpose(1, 2) + hyp, _ = imported.feature_extractor(x, None) + self.assertEqual(ref, hyp) + # Feature projection + x = torch.randn(3, 10, config['conv_dim'][-1]) + ref = original.feature_projection(x)[0] + hyp = imported.encoder.feature_projection(x) + self.assertEqual(ref, hyp) + # Convolutional Positional Encoder + x = torch.randn(3, 256, config['hidden_size']) + ref = original.encoder.pos_conv_embed(x) + hyp = imported.encoder.transformer.pos_conv_embed(x) + self.assertEqual(ref, hyp) + # Encoder Transformer Layer + for original_, imported_ in zip(original.encoder.layers, imported.encoder.transformer.layers): + b, l, e = 16, 3, config["hidden_size"] + x = torch.randn(b, l, e) + mask = torch.randn(b, 1, l, l) + + ref, = original_(x, attention_mask=mask, output_attentions=False) + hyp = imported_(x, mask) + self.assertEqual(ref, hyp) + # The whole Encoder Transformer + b, l, e = 16, 3, config["hidden_size"] + x = torch.randn(b, l, e) + ref = original.encoder(x).last_hidden_state + hyp = imported.encoder.transformer(x) + self.assertEqual(ref, hyp) + + def _test_import_finetune(self, original, imported, config): + # Aux + x = torch.randn(3, 10, config["hidden_size"]) + ref = original.lm_head(x) + hyp = imported.aux(x) + self.assertEqual(ref, hyp) + # The whole model without mask + x = torch.randn(3, 1024) + ref = original(x).logits + hyp, _ = imported(x) + self.assertEqual(ref, hyp) + # The whole model without mask + batch_size, num_frames = 3, 1024 + x = torch.randn(batch_size, num_frames) + ref = original(x).logits + hyp, _ = imported(x) + self.assertEqual(ref, hyp) + + # The whole model with mask + batch_size, num_frames = 3, 1024 + x = torch.randn(batch_size,
num_frames) + lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ]) + mask = torch.arange(num_frames).expand(batch_size, num_frames) < lengths[:, None] + + ref = original(x, attention_mask=mask).logits + hyp, output_lengths = imported(x, lengths) + + for i, l in enumerate(output_lengths): + self.assertEqual(ref[i, :l, ...], hyp[i, :l, ...]) + + @PRETRAIN_CONFIGS + def test_import_pretrain(self, config, _): + """wav2vec2 models from HF transformers can be imported and yields the same results""" + original = self._get_model(config).eval() + imported = import_huggingface_model(original).eval() + self._test_import_pretrain(original, imported, config) + + @FINETUNE_CONFIGS + def test_import_finetune(self, config, _): + """wav2vec2 models from HF transformers can be imported and yields the same results""" + original = self._get_model(config).eval() + imported = import_huggingface_model(original).eval() + self._test_import_pretrain(original.wav2vec2, imported, config) + self._test_import_finetune(original, imported, config) + + def _test_recreate(self, imported, reloaded, config): + torch.manual_seed(0) + # FeatureExtractor + x = torch.randn(3, 1024) + ref, _ = imported.feature_extractor(x, None) + hyp, _ = reloaded.feature_extractor(x, None) + self.assertEqual(ref, hyp) + # Feature projection + x = torch.randn(3, 10, config['conv_dim'][-1]) + ref = imported.encoder.feature_projection(x) + hyp = reloaded.encoder.feature_projection(x) + self.assertEqual(ref, hyp) + # Convolutional Positional Encoder + x = torch.randn(3, 256, config['hidden_size']) + ref = imported.encoder.transformer.pos_conv_embed(x) + hyp = reloaded.encoder.transformer.pos_conv_embed(x) + self.assertEqual(ref, hyp) + # Encoder Transformer Layer + for imported_, reloaded_ in zip(imported.encoder.transformer.layers, reloaded.encoder.transformer.layers): + b, l, e = 16, 3, config["hidden_size"] + x = torch.randn(b, l, e) + mask = torch.randn(b, 1, l, l) + + ref = imported_(x, mask) + hyp = reloaded_(x, mask) + self.assertEqual(ref, hyp) + # The whole Encoder Transformer + # TODO: Add mask pattern. Expected mask shapes and values are different. 
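# --- Illustrative aside (not part of the patch) -----------------------------
# A minimal sketch of how the importer exercised by this test suite is
# typically used outside of it. The checkpoint name is only an example and
# fetching it needs network access; any Wav2Vec2ForCTC checkpoint should
# behave the same way.
import torch
from transformers import Wav2Vec2ForCTC
from torchaudio.models.wav2vec2.utils import import_huggingface_model

original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model = import_huggingface_model(original).eval()
emissions, _ = model(torch.randn(1, 16000))  # (batch, frame, num_labels) logits
# -----------------------------------------------------------------------------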
+ b, l, e = 16, 3, config["hidden_size"] + x = torch.randn(b, l, e) + mask = torch.randn(b, 1, l, l) + ref = imported.encoder.transformer(x) + hyp = reloaded.encoder.transformer(x) + self.assertEqual(ref, hyp) + # Aux + if imported.aux is not None: + x = torch.randn(3, 10, config["hidden_size"]) + ref = imported.aux(x) + hyp = reloaded.aux(x) + self.assertEqual(ref, hyp) + # The whole model + x = torch.randn(3, 1024) + ref, _ = imported(x) + hyp, _ = reloaded(x) + self.assertEqual(ref, hyp) + + @PRETRAIN_CONFIGS + def test_recreate_pretrain(self, config, factory_func): + """Imported models can be recreated via a factory function without Hugging Face transformers.""" + imported = import_huggingface_model(self._get_model(config)).eval() + reloaded = factory_func() + reloaded.load_state_dict(imported.state_dict()) + reloaded.eval() + self._test_recreate(imported, reloaded, config) + + @FINETUNE_CONFIGS + def test_recreate_finetune(self, config, factory_func): + """Imported models can be recreated via a factory function without Hugging Face transformers.""" + imported = import_huggingface_model(self._get_model(config)).eval() + reloaded = factory_func(aux_num_out=imported.aux.out_features) + reloaded.load_state_dict(imported.state_dict()) + reloaded.eval() + self._test_recreate(imported, reloaded, config) diff --git a/test/torchaudio_unittest/models/wav2vec2/model_test.py b/test/torchaudio_unittest/models/wav2vec2/model_test.py new file mode 100644 index 0000000000000000000000000000000000000000..22a707d8d4424c5abc398d186db0f616f25dc5ca --- /dev/null +++ b/test/torchaudio_unittest/models/wav2vec2/model_test.py @@ -0,0 +1,243 @@ +import os + +import torch +import torch.nn.functional as F + +from torchaudio.models.wav2vec2 import ( + wav2vec2_base, + wav2vec2_large, + wav2vec2_large_lv60k, + hubert_base, + hubert_large, + hubert_xlarge, +) +from torchaudio_unittest.common_utils import ( + TorchaudioTestCase, + skipIfNoQengine, + skipIfNoCuda, + torch_script, +) +from parameterized import parameterized + + +def _name_func(testcase_func, i, param): + return f"{testcase_func.__name__}_{i}_{param[0][0].__name__}" + + +factory_funcs = parameterized.expand([ + (wav2vec2_base, ), + (wav2vec2_large, ), + (wav2vec2_large_lv60k, ), + (hubert_base, ), + (hubert_large, ), + (hubert_xlarge, ), +], name_func=_name_func) + + +class TestWav2Vec2Model(TorchaudioTestCase): + def _smoke_test(self, model, device, dtype): + model = model.to(device=device, dtype=dtype) + model = model.eval() + + torch.manual_seed(0) + batch_size, num_frames = 3, 1024 + + waveforms = torch.randn( + batch_size, num_frames, device=device, dtype=dtype) + lengths = torch.randint( + low=0, high=num_frames, size=[batch_size, ], device=device) + + model(waveforms, lengths) + + @parameterized.expand([(torch.float32, ), (torch.float64, )]) + def test_cpu_smoke_test(self, dtype): + model = wav2vec2_base() + self._smoke_test(model, torch.device('cpu'), dtype) + model = wav2vec2_base(aux_num_out=32) + self._smoke_test(model, torch.device('cpu'), dtype) + + @parameterized.expand([(torch.float32, ), (torch.float64, )]) + @skipIfNoCuda + def test_cuda_smoke_test(self, dtype): + model = wav2vec2_base() + self._smoke_test(model, torch.device('cuda'), dtype) + model = wav2vec2_base(aux_num_out=32) + self._smoke_test(model, torch.device('cuda'), dtype) + + def _feature_extractor_test(self, model): + batch_size, num_frames = 3, 1024 + + model.eval() + num_layers = len(model.encoder.transformer.layers) + + torch.manual_seed(0) + waveforms = 
torch.randn(batch_size, num_frames) + lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ]) + + # Not providing num_layers returns all the intermediate features from + # tranformer layers + all_features, lengths_ = model.extract_features(waveforms, lengths, num_layers=None) + assert len(all_features) == num_layers + for features in all_features: + assert features.ndim == 3 + assert features.shape[0] == batch_size + assert lengths_.shape == torch.Size([batch_size]) + + # Limiting the number of layers to `l`. + for l in range(1, num_layers + 1): + features, lengths_ = model.extract_features(waveforms, lengths, num_layers=l) + assert len(features) == l + for i in range(l): + self.assertEqual(all_features[i], features[i]) + assert lengths_.shape == torch.Size([batch_size]) + + @factory_funcs + def test_extract_feature(self, factory_func): + """`extract_features` method does not fail""" + self._feature_extractor_test(factory_func(aux_num_out=32)) + + def _test_batch_consistency(self, model): + model.eval() + batch_size, max_frames = 5, 5 * 1024 + torch.manual_seed(0) + waveforms = torch.randn(batch_size, max_frames) + input_lengths = torch.tensor([i * 3200 for i in range(1, 6)]) + + # Batch process with lengths + batch_logits, output_lengths = model(waveforms, input_lengths) + for i in range(batch_size): + # Par-sample process without feeding length + single_logit, _ = model(waveforms[i:i + 1, :input_lengths[i]], None) + batch_logit = batch_logits[i:i + 1, :output_lengths[i]] + + # Convert to probability so that it's easier to interpretate the diff + single_prob = F.softmax(single_logit, dim=2) + batch_prob = F.softmax(batch_logit, dim=2) + # We allow max atol=0.005 -> 0.5% + self.assertEqual(single_prob, batch_prob, atol=0.005, rtol=0) + + @factory_funcs + def test_pretrain_batch_consistency(self, factory_func): + """Results from single process and batched process should be reasonably close + """ + self._test_batch_consistency(factory_func()) + + @factory_funcs + def test_finetune_batch_consistency(self, factory_func): + """Results from single process and batched process should be reasonably close + """ + self._test_batch_consistency(factory_func(aux_num_out=32)) + + def _test_zero_length(self, model): + model.eval() + torch.manual_seed(0) + batch_size = 3 + waveforms = torch.randn(batch_size, 1024) + input_lengths = torch.zeros(batch_size) + _, output_lengths = model(waveforms, input_lengths) + self.assertEqual(torch.zeros_like(output_lengths), output_lengths) + _, output_lengths = model.extract_features(waveforms, input_lengths) + self.assertEqual(torch.zeros_like(output_lengths), output_lengths) + + @factory_funcs + def test_pretrain_zero_length(self, factory_func): + """Passing zero length should not fail""" + self._test_zero_length(factory_func()) + + @factory_funcs + def test_finetune_zero_length(self, factory_func): + """Passing zero length should not fail""" + self._test_zero_length(factory_func(aux_num_out=32)) + + def _test_torchscript(self, model): + model.eval() + + batch_size, num_frames = 3, 1024 + + torch.manual_seed(0) + waveforms = torch.randn(batch_size, num_frames) + lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ]) + + ref_out, ref_len = model(waveforms, lengths) + + scripted = torch_script(model) + + hyp_out, hyp_len = scripted(waveforms, lengths) + + self.assertEqual(hyp_out, ref_out) + self.assertEqual(hyp_len, ref_len) + + @factory_funcs + def test_pretrain_torchscript(self, factory_func): + """Wav2Vec2Model should be scriptable""" + if 
factory_func is hubert_xlarge and os.name == 'nt' and os.environ.get('CI') == 'true': + self.skipTest( + 'hubert_xlarge is known to fail on Windows CI. ' + 'See https://github.com/pytorch/pytorch/issues/65776') + self._test_torchscript(factory_func()) + + @factory_funcs + def test_finetune_torchscript(self, factory_func): + """Wav2Vec2Model should be scriptable""" + if factory_func is hubert_xlarge and os.name == 'nt' and os.environ.get('CI') == 'true': + self.skipTest( + 'hubert_xlarge is known to fail on Windows CI. ' + 'See https://github.com/pytorch/pytorch/issues/65776') + self._test_torchscript(factory_func(aux_num_out=32)) + + def _test_quantize_smoke_test(self, model): + model.eval() + batch_size, num_frames = 3, 1024 + + # Remove the weight normalization forward hook + model.encoder.transformer.pos_conv_embed.__prepare_scriptable__() + quantized = torch.quantization.quantize_dynamic( + model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8) + + # A lazy way to check that Modules are different + assert str(quantized) != str(model), "Dynamic quantization did not modify the module." + + torch.manual_seed(0) + waveforms = torch.randn(batch_size, num_frames) + lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ]) + + _, _ = quantized(waveforms, lengths) + + @factory_funcs + @skipIfNoQengine + def test_quantize(self, factory_func): + """Wav2Vec2Model should support basic quantization""" + self._test_quantize_smoke_test(factory_func(aux_num_out=32)) + + def _test_quantize_torchscript(self, model): + model.eval() + + batch_size, num_frames = 3, 1024 + + # Remove the weight normalization forward hook + model.encoder.transformer.pos_conv_embed.__prepare_scriptable__() + quantized = torch.quantization.quantize_dynamic( + model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8) + + # A lazy way to check that Modules are different + assert str(quantized) != str(model), "Dynamic quantization did not modify the module." 
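# --- Illustrative aside (not part of the patch) -----------------------------
# A self-contained sketch of the dynamic-quantization + TorchScript flow the
# helpers above and below exercise, shown on a hypothetical toy module.
# Dynamic quantization swaps every nn.Linear for a qint8-weight version, and
# the result remains scriptable.
import torch

toy = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 4),
).eval()
quantized_toy = torch.quantization.quantize_dynamic(
    toy, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
scripted_toy = torch.jit.script(quantized_toy)

x = torch.randn(2, 16)
assert torch.allclose(quantized_toy(x), scripted_toy(x))
# -----------------------------------------------------------------------------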
+ + torch.manual_seed(0) + waveforms = torch.randn(batch_size, num_frames) + lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ]) + + ref_out, ref_len = quantized(waveforms, lengths) + + # Script + scripted = torch_script(quantized) + + hyp_out, hyp_len = scripted(waveforms, lengths) + + self.assertEqual(hyp_out, ref_out) + self.assertEqual(hyp_len, ref_len) + + @factory_funcs + @skipIfNoQengine + def test_quantize_torchscript(self, factory_func): + """Quantized Wav2Vec2Model should be scriptable""" + self._test_quantize_torchscript(factory_func(aux_num_out=32)) diff --git a/test/torchaudio_unittest/sox_effect/__init__.py b/test/torchaudio_unittest/sox_effect/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/torchaudio_unittest/sox_effect/common.py b/test/torchaudio_unittest/sox_effect/common.py new file mode 100644 index 0000000000000000000000000000000000000000..00ac4a4565ddc027b1c68317f86cba44bf35416a --- /dev/null +++ b/test/torchaudio_unittest/sox_effect/common.py @@ -0,0 +1,26 @@ +import json + +from parameterized import param + +from torchaudio_unittest.common_utils import get_asset_path + + +def name_func(func, _, params): + if isinstance(params.args[0], str): + args = "_".join([str(arg) for arg in params.args]) + else: + args = "_".join([str(arg) for arg in params.args[0]]) + return f'{func.__name__}_{args}' + + +def load_params(*paths): + params = [] + with open(get_asset_path(*paths), 'r') as file: + for line in file: + data = json.loads(line) + for effect in data['effects']: + for i, arg in enumerate(effect): + if arg.startswith(""): + effect[i] = arg.replace("", get_asset_path()) + params.append(param(data)) + return params diff --git a/test/torchaudio_unittest/sox_effect/dataset_test.py b/test/torchaudio_unittest/sox_effect/dataset_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0b8cc9c4d8eccbc652d832bd9beb48cbccf12169 --- /dev/null +++ b/test/torchaudio_unittest/sox_effect/dataset_test.py @@ -0,0 +1,158 @@ +import sys +import platform +from unittest import skipIf +from typing import List, Tuple +from concurrent.futures import ProcessPoolExecutor + +import numpy as np +import torch +import torchaudio + +from torchaudio_unittest.common_utils import ( + TempDirMixin, + PytorchTestCase, + skipIfNoSox, + get_whitenoise, + save_wav, +) + + +class RandomPerturbationFile(torch.utils.data.Dataset): + """Given flist, apply random speed perturbation""" + def __init__(self, flist: List[str], sample_rate: int): + super().__init__() + self.flist = flist + self.sample_rate = sample_rate + self.rng = None + + def __getitem__(self, index): + speed = self.rng.uniform(0.5, 2.0) + effects = [ + ['gain', '-n', '-10'], + ['speed', f'{speed:.5f}'], # duration of data is 0.5 ~ 2.0 seconds. 
+ ['rate', f'{self.sample_rate}'], + ['pad', '0', '1.5'], # add 1.5 seconds silence at the end + ['trim', '0', '2'], # get the first 2 seconds + ] + data, _ = torchaudio.sox_effects.apply_effects_file(self.flist[index], effects) + return data + + def __len__(self): + return len(self.flist) + + +class RandomPerturbationTensor(torch.utils.data.Dataset): + """Apply speed purturbation to (synthetic) Tensor data""" + def __init__(self, signals: List[Tuple[torch.Tensor, int]], sample_rate: int): + super().__init__() + self.signals = signals + self.sample_rate = sample_rate + self.rng = None + + def __getitem__(self, index): + speed = self.rng.uniform(0.5, 2.0) + effects = [ + ['gain', '-n', '-10'], + ['speed', f'{speed:.5f}'], # duration of data is 0.5 ~ 2.0 seconds. + ['rate', f'{self.sample_rate}'], + ['pad', '0', '1.5'], # add 1.5 seconds silence at the end + ['trim', '0', '2'], # get the first 2 seconds + ] + tensor, sample_rate = self.signals[index] + data, _ = torchaudio.sox_effects.apply_effects_tensor(tensor, sample_rate, effects) + return data + + def __len__(self): + return len(self.signals) + + +def init_random_seed(worker_id): + dataset = torch.utils.data.get_worker_info().dataset + dataset.rng = np.random.RandomState(worker_id) + + +@skipIfNoSox +@skipIf( + platform.system() == 'Darwin' and + sys.version_info.major == 3 and + sys.version_info.minor in [6, 7], + 'This test is known to get stuck for macOS with Python < 3.8. ' + 'See https://github.com/pytorch/pytorch/issues/46409' +) +class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase): + """Test `apply_effects_file` in multi-process dataloader setting""" + + def _generate_dataset(self, num_samples=128): + flist = [] + for i in range(num_samples): + sample_rate = np.random.choice([8000, 16000, 44100]) + dtype = np.random.choice(['float32', 'int32', 'int16', 'uint8']) + data = get_whitenoise(n_channels=2, sample_rate=sample_rate, duration=1, dtype=dtype) + path = self.get_temp_path(f'{i:03d}_{dtype}_{sample_rate}.wav') + save_wav(path, data, sample_rate) + flist.append(path) + return flist + + def test_apply_effects_file(self): + sample_rate = 12000 + flist = self._generate_dataset() + dataset = RandomPerturbationFile(flist, sample_rate) + loader = torch.utils.data.DataLoader( + dataset, batch_size=32, num_workers=16, + worker_init_fn=init_random_seed, + ) + for batch in loader: + assert batch.shape == (32, 2, 2 * sample_rate) + + def _generate_signals(self, num_samples=128): + signals = [] + for _ in range(num_samples): + sample_rate = np.random.choice([8000, 16000, 44100]) + data = get_whitenoise( + n_channels=2, sample_rate=sample_rate, duration=1, dtype='float32') + signals.append((data, sample_rate)) + return signals + + def test_apply_effects_tensor(self): + sample_rate = 12000 + signals = self._generate_signals() + dataset = RandomPerturbationTensor(signals, sample_rate) + loader = torch.utils.data.DataLoader( + dataset, batch_size=32, num_workers=16, + worker_init_fn=init_random_seed, + ) + for batch in loader: + assert batch.shape == (32, 2, 2 * sample_rate) + + +def speed(path): + wav, sample_rate = torchaudio.backend.sox_io_backend.load(path) + effects = [ + ['speed', '1.03756523535464655'], + ['rate', f'{sample_rate}'], + ] + return torchaudio.sox_effects.apply_effects_tensor(wav, sample_rate, effects)[0] + + +@skipIfNoSox +class TestProcessPoolExecutor(TempDirMixin, PytorchTestCase): + backend = "sox_io" + + def setUp(self): + sample_rate = 16000 + self.flist = [] + for i in range(10): + path = 
self.get_temp_path(f'{i}.wav') + data = get_whitenoise(n_channels=1, sample_rate=sample_rate, duration=1, dtype='float') + save_wav(path, data, sample_rate) + self.flist.append(path) + + def test_executor(self): + """Test that apply_effects_tensor with speed + rate does not crush + + https://github.com/pytorch/audio/issues/1021 + """ + executor = ProcessPoolExecutor(1) + futures = [executor.submit(speed, path) for path in self.flist] + for future in futures: + future.result() diff --git a/test/torchaudio_unittest/sox_effect/smoke_test.py b/test/torchaudio_unittest/sox_effect/smoke_test.py new file mode 100644 index 0000000000000000000000000000000000000000..70a6a346ea3136be6417e9d0a9ba93a5a235382b --- /dev/null +++ b/test/torchaudio_unittest/sox_effect/smoke_test.py @@ -0,0 +1,78 @@ +from torchaudio import sox_effects +from parameterized import parameterized + +from torchaudio_unittest.common_utils import ( + TempDirMixin, + TorchaudioTestCase, + skipIfNoSox, + get_wav_data, + get_sinusoid, + save_wav, +) +from .common import ( + load_params, +) + + +@skipIfNoSox +class SmokeTest(TempDirMixin, TorchaudioTestCase): + """Run smoke test on various effects + + The purpose of this test suite is to verify that sox_effect functionalities do not exhibit + abnormal behaviors. + + This test suite should be able to run without any additional tools (such as sox command), + however without such tools, the correctness of each function cannot be verified. + """ + @parameterized.expand( + load_params("sox_effect_test_args.jsonl"), + name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}', + ) + def test_apply_effects_tensor(self, args): + """`apply_effects_tensor` should not crash""" + effects = args['effects'] + num_channels = args.get("num_channels", 2) + input_sr = args.get("input_sample_rate", 8000) + original = get_sinusoid( + frequency=800, sample_rate=input_sr, + n_channels=num_channels, dtype='float32') + _found, _sr = sox_effects.apply_effects_tensor(original, input_sr, effects) + + @parameterized.expand( + load_params("sox_effect_test_args.jsonl"), + name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}', + ) + def test_apply_effects_file(self, args): + """`apply_effects_file` should return identical data as sox command""" + dtype = 'int32' + channels_first = True + effects = args['effects'] + num_channels = args.get("num_channels", 2) + input_sr = args.get("input_sample_rate", 8000) + + input_path = self.get_temp_path('input.wav') + data = get_wav_data(dtype, num_channels, channels_first=channels_first) + save_wav(input_path, data, input_sr, channels_first=channels_first) + + _found, _sr = sox_effects.apply_effects_file( + input_path, effects, normalize=False, channels_first=channels_first) + + @parameterized.expand( + load_params("sox_effect_test_args.jsonl"), + name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}', + ) + def test_apply_effects_fileobj(self, args): + """`apply_effects_file` should return identical data as sox command""" + dtype = 'int32' + channels_first = True + effects = args['effects'] + num_channels = args.get("num_channels", 2) + input_sr = args.get("input_sample_rate", 8000) + + input_path = self.get_temp_path('input.wav') + data = get_wav_data(dtype, num_channels, channels_first=channels_first) + save_wav(input_path, data, input_sr, channels_first=channels_first) + + with open(input_path, 'rb') as fileobj: + _found, _sr = sox_effects.apply_effects_file( + fileobj, effects, normalize=False, 
channels_first=channels_first) diff --git a/test/torchaudio_unittest/sox_effect/sox_effect_test.py b/test/torchaudio_unittest/sox_effect/sox_effect_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ca93e7d41b90895940c8d66570b5a42b74653a9e --- /dev/null +++ b/test/torchaudio_unittest/sox_effect/sox_effect_test.py @@ -0,0 +1,423 @@ +import io +import itertools +from pathlib import Path +import tarfile + +from parameterized import parameterized +from torchaudio import sox_effects +from torchaudio._internal import module_utils as _mod_utils + +from torchaudio_unittest.common_utils import ( + TempDirMixin, + HttpServerMixin, + PytorchTestCase, + skipIfNoSox, + skipIfNoModule, + skipIfNoExec, + get_asset_path, + get_sinusoid, + get_wav_data, + save_wav, + load_wav, + sox_utils, +) +from .common import ( + load_params, + name_func, +) + + +if _mod_utils.is_module_available("requests"): + import requests + + +@skipIfNoSox +class TestSoxEffects(PytorchTestCase): + def test_init(self): + """Calling init_sox_effects multiple times does not crush""" + for _ in range(3): + sox_effects.init_sox_effects() + + +@skipIfNoSox +class TestSoxEffectsTensor(TempDirMixin, PytorchTestCase): + """Test suite for `apply_effects_tensor` function""" + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2, 4, 8], + [True, False] + )), name_func=name_func) + def test_apply_no_effect(self, dtype, sample_rate, num_channels, channels_first): + """`apply_effects_tensor` without effects should return identical data as input""" + original = get_wav_data(dtype, num_channels, channels_first=channels_first) + expected = original.clone() + found, output_sample_rate = sox_effects.apply_effects_tensor( + expected, sample_rate, [], channels_first) + + assert output_sample_rate == sample_rate + # SoxEffect should not alter the input Tensor object + self.assertEqual(original, expected) + # SoxEffect should not return the same Tensor object + assert expected is not found + # Returned Tensor should equal to the input Tensor + self.assertEqual(expected, found) + + @parameterized.expand( + load_params("sox_effect_test_args.jsonl"), + name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}', + ) + def test_apply_effects(self, args): + """`apply_effects_tensor` should return identical data as sox command""" + effects = args['effects'] + num_channels = args.get("num_channels", 2) + input_sr = args.get("input_sample_rate", 8000) + output_sr = args.get("output_sample_rate") + + input_path = self.get_temp_path('input.wav') + reference_path = self.get_temp_path('reference.wav') + + original = get_sinusoid( + frequency=800, sample_rate=input_sr, + n_channels=num_channels, dtype='float32') + save_wav(input_path, original, input_sr) + sox_utils.run_sox_effect( + input_path, reference_path, effects, output_sample_rate=output_sr) + + expected, expected_sr = load_wav(reference_path) + found, sr = sox_effects.apply_effects_tensor(original, input_sr, effects) + + assert sr == expected_sr + self.assertEqual(expected, found) + + +@skipIfNoSox +class TestSoxEffectsFile(TempDirMixin, PytorchTestCase): + """Test suite for `apply_effects_file` function""" + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2, 4, 8], + [False, True], + )), name_func=name_func) + def test_apply_no_effect(self, dtype, sample_rate, num_channels, channels_first): + """`apply_effects_file` without effects should return 
identical data as input""" + path = self.get_temp_path('input.wav') + expected = get_wav_data(dtype, num_channels, channels_first=channels_first) + save_wav(path, expected, sample_rate, channels_first=channels_first) + + found, output_sample_rate = sox_effects.apply_effects_file( + path, [], normalize=False, channels_first=channels_first) + + assert output_sample_rate == sample_rate + self.assertEqual(expected, found) + + @parameterized.expand( + load_params("sox_effect_test_args.jsonl"), + name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}', + ) + def test_apply_effects_str(self, args): + """`apply_effects_file` should return identical data as sox command""" + dtype = 'int32' + channels_first = True + effects = args['effects'] + num_channels = args.get("num_channels", 2) + input_sr = args.get("input_sample_rate", 8000) + output_sr = args.get("output_sample_rate") + + input_path = self.get_temp_path('input.wav') + reference_path = self.get_temp_path('reference.wav') + data = get_wav_data(dtype, num_channels, channels_first=channels_first) + save_wav(input_path, data, input_sr, channels_first=channels_first) + sox_utils.run_sox_effect( + input_path, reference_path, effects, output_sample_rate=output_sr) + + expected, expected_sr = load_wav(reference_path) + found, sr = sox_effects.apply_effects_file( + input_path, effects, normalize=False, channels_first=channels_first) + + assert sr == expected_sr + self.assertEqual(found, expected) + + def test_apply_effects_path(self): + """`apply_effects_file` should return identical data as sox command when file path is given as a Path Object""" + dtype = 'int32' + channels_first = True + effects = [["hilbert"]] + num_channels = 2 + input_sr = 8000 + output_sr = 8000 + + input_path = self.get_temp_path('input.wav') + reference_path = self.get_temp_path('reference.wav') + data = get_wav_data(dtype, num_channels, channels_first=channels_first) + save_wav(input_path, data, input_sr, channels_first=channels_first) + sox_utils.run_sox_effect( + input_path, reference_path, effects, output_sample_rate=output_sr) + + expected, expected_sr = load_wav(reference_path) + found, sr = sox_effects.apply_effects_file( + Path(input_path), effects, normalize=False, channels_first=channels_first) + + assert sr == expected_sr + self.assertEqual(found, expected) + + +@skipIfNoSox +class TestFileFormats(TempDirMixin, PytorchTestCase): + """`apply_effects_file` gives the same result as sox on various file formats""" + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2], + )), name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}') + def test_wav(self, dtype, sample_rate, num_channels): + """`apply_effects_file` works on various wav format""" + channels_first = True + effects = [['band', '300', '10']] + + input_path = self.get_temp_path('input.wav') + reference_path = self.get_temp_path('reference.wav') + data = get_wav_data(dtype, num_channels, channels_first=channels_first) + save_wav(input_path, data, sample_rate, channels_first=channels_first) + sox_utils.run_sox_effect(input_path, reference_path, effects) + + expected, expected_sr = load_wav(reference_path) + found, sr = sox_effects.apply_effects_file( + input_path, effects, normalize=False, channels_first=channels_first) + + assert sr == expected_sr + self.assertEqual(found, expected) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + )), name_func=lambda f, _, p: 
f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}') + def test_mp3(self, sample_rate, num_channels): + """`apply_effects_file` works on various mp3 format""" + channels_first = True + effects = [['band', '300', '10']] + + input_path = self.get_temp_path('input.mp3') + reference_path = self.get_temp_path('reference.wav') + sox_utils.gen_audio_file(input_path, sample_rate, num_channels) + sox_utils.run_sox_effect(input_path, reference_path, effects) + + expected, expected_sr = load_wav(reference_path) + found, sr = sox_effects.apply_effects_file( + input_path, effects, channels_first=channels_first) + save_wav(self.get_temp_path('result.wav'), found, sr, channels_first=channels_first) + + assert sr == expected_sr + self.assertEqual(found, expected, atol=1e-4, rtol=1e-8) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + )), name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}') + def test_flac(self, sample_rate, num_channels): + """`apply_effects_file` works on various flac format""" + channels_first = True + effects = [['band', '300', '10']] + + input_path = self.get_temp_path('input.flac') + reference_path = self.get_temp_path('reference.wav') + sox_utils.gen_audio_file(input_path, sample_rate, num_channels) + sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32) + + expected, expected_sr = load_wav(reference_path) + found, sr = sox_effects.apply_effects_file( + input_path, effects, channels_first=channels_first) + save_wav(self.get_temp_path('result.wav'), found, sr, channels_first=channels_first) + + assert sr == expected_sr + self.assertEqual(found, expected) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + )), name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}') + def test_vorbis(self, sample_rate, num_channels): + """`apply_effects_file` works on various vorbis format""" + channels_first = True + effects = [['band', '300', '10']] + + input_path = self.get_temp_path('input.vorbis') + reference_path = self.get_temp_path('reference.wav') + sox_utils.gen_audio_file(input_path, sample_rate, num_channels) + sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32) + + expected, expected_sr = load_wav(reference_path) + found, sr = sox_effects.apply_effects_file( + input_path, effects, channels_first=channels_first) + save_wav(self.get_temp_path('result.wav'), found, sr, channels_first=channels_first) + + assert sr == expected_sr + self.assertEqual(found, expected) + + +@skipIfNoSox +class TestApplyEffectFileWithoutExtension(PytorchTestCase): + def test_mp3(self): + """Providing format allows to read mp3 without extension + + libsox does not check header for mp3 + + https://github.com/pytorch/audio/issues/1040 + + The file was generated with the following command + ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext + """ + effects = [['band', '300', '10']] + path = get_asset_path("mp3_without_ext") + _, sr = sox_effects.apply_effects_file(path, effects, format="mp3") + assert sr == 16000 + + +@skipIfNoExec('sox') +@skipIfNoSox +class TestFileObject(TempDirMixin, PytorchTestCase): + @parameterized.expand([ + ('wav', None), + ('mp3', 128), + ('mp3', 320), + ('flac', 0), + ('flac', 5), + ('flac', 8), + ('vorbis', -1), + ('vorbis', 10), + ('amb', None), + ]) + def test_fileobj(self, ext, compression): + """Applying effects via file object works""" + sample_rate = 16000 + channels_first = True + 
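# --- Illustrative aside (not part of the patch) -----------------------------
# Sketch of the file-object path these tests cover: when the input is a
# file-like object (or lacks a usable extension), libsox cannot infer the
# container, so the format is passed explicitly. "sample.mp3" is a
# hypothetical path.
import io
import torchaudio

with open("sample.mp3", "rb") as src:
    buf = io.BytesIO(src.read())
waveform, sample_rate = torchaudio.sox_effects.apply_effects_file(
    buf, [["band", "300", "10"]], format="mp3")
# -----------------------------------------------------------------------------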
effects = [['band', '300', '10']] + format_ = ext if ext in ['mp3'] else None + input_path = self.get_temp_path(f'input.{ext}') + reference_path = self.get_temp_path('reference.wav') + + sox_utils.gen_audio_file( + input_path, sample_rate, num_channels=2, compression=compression) + sox_utils.run_sox_effect( + input_path, reference_path, effects, output_bitdepth=32) + expected, expected_sr = load_wav(reference_path) + + with open(input_path, 'rb') as fileobj: + found, sr = sox_effects.apply_effects_file( + fileobj, effects, channels_first=channels_first, format=format_) + save_wav(self.get_temp_path('result.wav'), found, sr, channels_first=channels_first) + assert sr == expected_sr + self.assertEqual(found, expected) + + @parameterized.expand([ + ('wav', None), + ('mp3', 128), + ('mp3', 320), + ('flac', 0), + ('flac', 5), + ('flac', 8), + ('vorbis', -1), + ('vorbis', 10), + ('amb', None), + ]) + def test_bytesio(self, ext, compression): + """Applying effects via BytesIO object works""" + sample_rate = 16000 + channels_first = True + effects = [['band', '300', '10']] + format_ = ext if ext in ['mp3'] else None + input_path = self.get_temp_path(f'input.{ext}') + reference_path = self.get_temp_path('reference.wav') + + sox_utils.gen_audio_file( + input_path, sample_rate, num_channels=2, compression=compression) + sox_utils.run_sox_effect( + input_path, reference_path, effects, output_bitdepth=32) + expected, expected_sr = load_wav(reference_path) + + with open(input_path, 'rb') as file_: + fileobj = io.BytesIO(file_.read()) + found, sr = sox_effects.apply_effects_file( + fileobj, effects, channels_first=channels_first, format=format_) + save_wav(self.get_temp_path('result.wav'), found, sr, channels_first=channels_first) + assert sr == expected_sr + self.assertEqual(found, expected) + + @parameterized.expand([ + ('wav', None), + ('mp3', 128), + ('mp3', 320), + ('flac', 0), + ('flac', 5), + ('flac', 8), + ('vorbis', -1), + ('vorbis', 10), + ('amb', None), + ]) + def test_tarfile(self, ext, compression): + """Applying effects to compressed audio via file-like file works""" + sample_rate = 16000 + channels_first = True + effects = [['band', '300', '10']] + format_ = ext if ext in ['mp3'] else None + audio_file = f'input.{ext}' + + input_path = self.get_temp_path(audio_file) + reference_path = self.get_temp_path('reference.wav') + archive_path = self.get_temp_path('archive.tar.gz') + + sox_utils.gen_audio_file( + input_path, sample_rate, num_channels=2, compression=compression) + sox_utils.run_sox_effect( + input_path, reference_path, effects, output_bitdepth=32) + expected, expected_sr = load_wav(reference_path) + + with tarfile.TarFile(archive_path, 'w') as tarobj: + tarobj.add(input_path, arcname=audio_file) + with tarfile.TarFile(archive_path, 'r') as tarobj: + fileobj = tarobj.extractfile(audio_file) + found, sr = sox_effects.apply_effects_file( + fileobj, effects, channels_first=channels_first, format=format_) + save_wav(self.get_temp_path('result.wav'), found, sr, channels_first=channels_first) + assert sr == expected_sr + self.assertEqual(found, expected) + + +@skipIfNoSox +@skipIfNoExec('sox') +@skipIfNoModule("requests") +class TestFileObjectHttp(HttpServerMixin, PytorchTestCase): + @parameterized.expand([ + ('wav', None), + ('mp3', 128), + ('mp3', 320), + ('flac', 0), + ('flac', 5), + ('flac', 8), + ('vorbis', -1), + ('vorbis', 10), + ('amb', None), + ]) + def test_requests(self, ext, compression): + sample_rate = 16000 + channels_first = True + effects = [['band', '300', '10']] + 
format_ = ext if ext in ['mp3'] else None + audio_file = f'input.{ext}' + input_path = self.get_temp_path(audio_file) + reference_path = self.get_temp_path('reference.wav') + + sox_utils.gen_audio_file( + input_path, sample_rate, num_channels=2, compression=compression) + sox_utils.run_sox_effect( + input_path, reference_path, effects, output_bitdepth=32) + expected, expected_sr = load_wav(reference_path) + + url = self.get_url(audio_file) + with requests.get(url, stream=True) as resp: + found, sr = sox_effects.apply_effects_file( + resp.raw, effects, channels_first=channels_first, format=format_) + save_wav(self.get_temp_path('result.wav'), found, sr, channels_first=channels_first) + assert sr == expected_sr + self.assertEqual(found, expected) diff --git a/test/torchaudio_unittest/sox_effect/torchscript_test.py b/test/torchaudio_unittest/sox_effect/torchscript_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4bc81c2b597e983a5d6d1a6408d492dc85997c94 --- /dev/null +++ b/test/torchaudio_unittest/sox_effect/torchscript_test.py @@ -0,0 +1,96 @@ +from typing import List + +import torch +from torchaudio import sox_effects +from parameterized import parameterized + +from torchaudio_unittest.common_utils import ( + TempDirMixin, + TorchaudioTestCase, + skipIfNoSox, + get_sinusoid, + save_wav, + torch_script, +) +from .common import ( + load_params, +) + + +class SoxEffectTensorTransform(torch.nn.Module): + effects: List[List[str]] + + def __init__(self, effects: List[List[str]], sample_rate: int, channels_first: bool): + super().__init__() + self.effects = effects + self.sample_rate = sample_rate + self.channels_first = channels_first + + def forward(self, tensor: torch.Tensor): + return sox_effects.apply_effects_tensor( + tensor, self.sample_rate, self.effects, self.channels_first) + + +class SoxEffectFileTransform(torch.nn.Module): + effects: List[List[str]] + channels_first: bool + + def __init__(self, effects: List[List[str]], channels_first: bool): + super().__init__() + self.effects = effects + self.channels_first = channels_first + + def forward(self, path: str): + return sox_effects.apply_effects_file(path, self.effects, self.channels_first) + + +@skipIfNoSox +class TestTorchScript(TempDirMixin, TorchaudioTestCase): + @parameterized.expand( + load_params("sox_effect_test_args.jsonl"), + name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}', + ) + def test_apply_effects_tensor(self, args): + effects = args['effects'] + channels_first = True + num_channels = args.get("num_channels", 2) + input_sr = args.get("input_sample_rate", 8000) + + trans = SoxEffectTensorTransform(effects, input_sr, channels_first) + + trans = torch_script(trans) + + wav = get_sinusoid( + frequency=800, sample_rate=input_sr, + n_channels=num_channels, dtype='float32', channels_first=channels_first) + found, sr_found = trans(wav) + expected, sr_expected = sox_effects.apply_effects_tensor( + wav, input_sr, effects, channels_first) + + assert sr_found == sr_expected + self.assertEqual(expected, found) + + @parameterized.expand( + load_params("sox_effect_test_args.jsonl"), + name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}', + ) + def test_apply_effects_file(self, args): + effects = args['effects'] + channels_first = True + num_channels = args.get("num_channels", 2) + input_sr = args.get("input_sample_rate", 8000) + + trans = SoxEffectFileTransform(effects, channels_first) + trans = torch_script(trans) + + path = self.get_temp_path('input.wav') + wav 
= get_sinusoid( + frequency=800, sample_rate=input_sr, + n_channels=num_channels, dtype='float32', channels_first=channels_first) + save_wav(path, wav, sample_rate=input_sr, channels_first=channels_first) + + found, sr_found = trans(path) + expected, sr_expected = sox_effects.apply_effects_file(path, effects, channels_first) + + assert sr_found == sr_expected + self.assertEqual(expected, found) diff --git a/test/torchaudio_unittest/transforms/__init__.py b/test/torchaudio_unittest/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/torchaudio_unittest/transforms/autograd_cpu_test.py b/test/torchaudio_unittest/transforms/autograd_cpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f9d4aa199490bfcded6f4d6acd4f0f27cfc1b2d1 --- /dev/null +++ b/test/torchaudio_unittest/transforms/autograd_cpu_test.py @@ -0,0 +1,10 @@ +from torchaudio_unittest.common_utils import PytorchTestCase +from .autograd_test_impl import AutogradTestMixin, AutogradTestFloat32 + + +class AutogradCPUTest(AutogradTestMixin, PytorchTestCase): + device = 'cpu' + + +class AutogradRNNTCPUTest(AutogradTestFloat32, PytorchTestCase): + device = 'cpu' diff --git a/test/torchaudio_unittest/transforms/autograd_cuda_test.py b/test/torchaudio_unittest/transforms/autograd_cuda_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7565f88540acaab817ce2faa09279a3381dd0738 --- /dev/null +++ b/test/torchaudio_unittest/transforms/autograd_cuda_test.py @@ -0,0 +1,15 @@ +from torchaudio_unittest.common_utils import ( + PytorchTestCase, + skipIfNoCuda, +) +from .autograd_test_impl import AutogradTestMixin, AutogradTestFloat32 + + +@skipIfNoCuda +class AutogradCUDATest(AutogradTestMixin, PytorchTestCase): + device = 'cuda' + + +@skipIfNoCuda +class AutogradRNNTCUDATest(AutogradTestFloat32, PytorchTestCase): + device = 'cuda' diff --git a/test/torchaudio_unittest/transforms/autograd_test_impl.py b/test/torchaudio_unittest/transforms/autograd_test_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..8bf37b0a394eb137dfd5aed593867d869050ee44 --- /dev/null +++ b/test/torchaudio_unittest/transforms/autograd_test_impl.py @@ -0,0 +1,337 @@ +from typing import List +import unittest + +from parameterized import parameterized +import torch +from torch.autograd import gradcheck, gradgradcheck +import torchaudio.transforms as T + +from torchaudio_unittest.common_utils import ( + TestBaseMixin, + get_whitenoise, + get_spectrogram, + nested_params, + rnnt_utils, +) + + +class _DeterministicWrapper(torch.nn.Module): + """Helper transform wrapper to make the given transform deterministic""" + def __init__(self, transform, seed=0): + super().__init__() + self.seed = seed + self.transform = transform + + def forward(self, input: torch.Tensor): + torch.random.manual_seed(self.seed) + return self.transform(input) + + +class AutogradTestMixin(TestBaseMixin): + def assert_grad( + self, + transform: torch.nn.Module, + inputs: List[torch.Tensor], + *, + nondet_tol: float = 0.0, + ): + transform = transform.to(dtype=torch.float64, device=self.device) + + # gradcheck and gradgradcheck only pass if the input tensors are of dtype `torch.double` or + # `torch.cdouble`, when the default eps and tolerance values are used. 
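# --- Illustrative aside (not part of the patch) -----------------------------
# A minimal standalone gradcheck call showing the double-precision requirement
# noted above: gradcheck compares numerical and analytical Jacobians, and with
# float32 inputs the default eps/tolerances typically fail.
import torch
from torch.autograd import gradcheck

x = torch.randn(3, dtype=torch.double, requires_grad=True)
assert gradcheck(torch.sin, (x,))
# -----------------------------------------------------------------------------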
+ inputs_ = [] + for i in inputs: + if torch.is_tensor(i): + i = i.to( + dtype=torch.cdouble if i.is_complex() else torch.double, + device=self.device) + i.requires_grad = True + inputs_.append(i) + assert gradcheck(transform, inputs_) + assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol) + + @parameterized.expand([ + ({'pad': 0, 'normalized': False, 'power': None, 'return_complex': True}, ), + ({'pad': 3, 'normalized': False, 'power': None, 'return_complex': True}, ), + ({'pad': 0, 'normalized': True, 'power': None, 'return_complex': True}, ), + ({'pad': 3, 'normalized': True, 'power': None, 'return_complex': True}, ), + ({'pad': 0, 'normalized': False, 'power': None}, ), + ({'pad': 3, 'normalized': False, 'power': None}, ), + ({'pad': 0, 'normalized': True, 'power': None}, ), + ({'pad': 3, 'normalized': True, 'power': None}, ), + ({'pad': 0, 'normalized': False, 'power': 1.0}, ), + ({'pad': 3, 'normalized': False, 'power': 1.0}, ), + ({'pad': 0, 'normalized': True, 'power': 1.0}, ), + ({'pad': 3, 'normalized': True, 'power': 1.0}, ), + ({'pad': 0, 'normalized': False, 'power': 2.0}, ), + ({'pad': 3, 'normalized': False, 'power': 2.0}, ), + ({'pad': 0, 'normalized': True, 'power': 2.0}, ), + ({'pad': 3, 'normalized': True, 'power': 2.0}, ), + ]) + def test_spectrogram(self, kwargs): + # replication_pad1d_backward_cuda is not deteministic and + # gives very small (~2.7756e-17) difference. + # + # See https://github.com/pytorch/pytorch/issues/54093 + transform = T.Spectrogram(**kwargs) + waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2) + self.assert_grad(transform, [waveform], nondet_tol=1e-10) + + @parameterized.expand([(False, ), (True, )]) + def test_inverse_spectrogram(self, return_complex): + # create a realistic input: + waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2) + length = waveform.shape[-1] + spectrogram = get_spectrogram(waveform, n_fft=400) + if not return_complex: + spectrogram = torch.view_as_real(spectrogram) + + # test + inv_transform = T.InverseSpectrogram(n_fft=400) + self.assert_grad(inv_transform, [spectrogram, length]) + + def test_melspectrogram(self): + # replication_pad1d_backward_cuda is not deteministic and + # gives very small (~2.7756e-17) difference. 
+ # + # See https://github.com/pytorch/pytorch/issues/54093 + sample_rate = 8000 + transform = T.MelSpectrogram(sample_rate=sample_rate) + waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2) + self.assert_grad(transform, [waveform], nondet_tol=1e-10) + + @nested_params( + [0, 0.99], + [False, True], + ) + def test_griffinlim(self, momentum, rand_init): + n_fft = 400 + power = 1 + n_iter = 3 + spec = get_spectrogram( + get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2), + n_fft=n_fft, power=power) + transform = _DeterministicWrapper( + T.GriffinLim(n_fft=n_fft, n_iter=n_iter, momentum=momentum, rand_init=rand_init, power=power)) + self.assert_grad(transform, [spec]) + + @parameterized.expand([(False, ), (True, )]) + def test_mfcc(self, log_mels): + sample_rate = 8000 + transform = T.MFCC(sample_rate=sample_rate, log_mels=log_mels) + waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2) + self.assert_grad(transform, [waveform]) + + @parameterized.expand([(False, ), (True, )]) + def test_lfcc(self, log_lf): + sample_rate = 8000 + transform = T.LFCC(sample_rate=sample_rate, log_lf=log_lf) + waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2) + self.assert_grad(transform, [waveform]) + + def test_compute_deltas(self): + transform = T.ComputeDeltas() + spec = torch.rand(10, 20) + self.assert_grad(transform, [spec]) + + @parameterized.expand([(8000, 8000), (8000, 4000), (4000, 8000)]) + def test_resample(self, orig_freq, new_freq): + transform = T.Resample(orig_freq=orig_freq, new_freq=new_freq) + waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2) + self.assert_grad(transform, [waveform]) + + @parameterized.expand([("linear", ), ("exponential", ), ("logarithmic", ), ("quarter_sine", ), ("half_sine", )]) + def test_fade(self, fade_shape): + transform = T.Fade(fade_shape=fade_shape) + waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2) + self.assert_grad(transform, [waveform], nondet_tol=1e-10) + + @parameterized.expand([(T.TimeMasking,), (T.FrequencyMasking,)]) + def test_masking(self, masking_transform): + sample_rate = 8000 + n_fft = 400 + spectrogram = get_spectrogram( + get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2), + n_fft=n_fft, power=1) + deterministic_transform = _DeterministicWrapper(masking_transform(400)) + self.assert_grad(deterministic_transform, [spectrogram]) + + @parameterized.expand([(T.TimeMasking,), (T.FrequencyMasking,)]) + def test_masking_iid(self, masking_transform): + sample_rate = 8000 + n_fft = 400 + specs = [get_spectrogram( + get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2, seed=i), + n_fft=n_fft, power=1) + for i in range(3) + ] + + batch = torch.stack(specs) + assert batch.ndim == 4 + deterministic_transform = _DeterministicWrapper(masking_transform(400, True)) + self.assert_grad(deterministic_transform, [batch]) + + def test_spectral_centroid(self): + sample_rate = 8000 + transform = T.SpectralCentroid(sample_rate=sample_rate) + waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2) + self.assert_grad(transform, [waveform], nondet_tol=1e-10) + + def test_amplitude_to_db(self): + sample_rate = 8000 + transform = T.AmplitudeToDB() + waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2) + self.assert_grad(transform, [waveform]) + + def test_melscale(self): + sample_rate = 8000 + n_fft = 400 + n_mels = n_fft // 2 + 1 + transform = 
T.MelScale(sample_rate=sample_rate, n_mels=n_mels) + spec = get_spectrogram( + get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2), + n_fft=n_fft, power=1) + self.assert_grad(transform, [spec]) + + @parameterized.expand([(1.5, "amplitude"), (2, "power"), (10, "db")]) + def test_vol(self, gain, gain_type): + sample_rate = 8000 + transform = T.Vol(gain=gain, gain_type=gain_type) + waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2) + self.assert_grad(transform, [waveform]) + + @parameterized.expand([ + ({'cmn_window': 100, 'min_cmn_window': 50, 'center': False, 'norm_vars': False}, ), + ({'cmn_window': 100, 'min_cmn_window': 50, 'center': True, 'norm_vars': False}, ), + ({'cmn_window': 100, 'min_cmn_window': 50, 'center': False, 'norm_vars': True}, ), + ({'cmn_window': 100, 'min_cmn_window': 50, 'center': True, 'norm_vars': True}, ), + ]) + def test_sliding_window_cmn(self, kwargs): + n_fft = 10 + power = 1 + spec = get_spectrogram( + get_whitenoise(sample_rate=200, duration=0.05, n_channels=2), + n_fft=n_fft, power=power) + spec_reshaped = spec.transpose(-1, -2) + + transform = T.SlidingWindowCmn(**kwargs) + self.assert_grad(transform, [spec_reshaped]) + + @unittest.expectedFailure + def test_timestretch_zeros_fail(self): + """Test that ``T.TimeStretch`` fails gradcheck at 0 + + This is because ``F.phase_vocoder`` converts data from cartesian to polar coordinate, + which performs ``atan2(img, real)``, and gradient is not defined at 0. + """ + n_fft = 16 + transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=0.99) + waveform = torch.zeros(2, 40) + spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None) + self.assert_grad(transform, [spectrogram]) + + @nested_params( + [0.7, 0.8, 0.9, 1.0, 1.3], + [False, True], + ) + def test_timestretch_non_zero(self, rate, test_pseudo_complex): + """Verify that ``T.TimeStretch`` does not fail if it's not close to 0 + + ``T.TimeStrech`` is not differentiable around 0, so this test checks the differentiability + for cases where input is not zero. + + As tested above, when spectrogram contains values close to zero, the gradients are unstable + and gradcheck fails. + + In this test, we generate spectrogram from random signal, then we push the points around + zero away from the origin. + + This process does not reflect the real use-case, and it is not practical for users, but + this helps us understand to what degree the function is differentiable and when not. 
+ """ + n_fft = 16 + transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=rate) + waveform = get_whitenoise(sample_rate=40, duration=1, n_channels=2) + spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None) + + # 1e-3 is too small (on CPU) + epsilon = 1e-2 + too_close = spectrogram.abs() < epsilon + spectrogram[too_close] = epsilon * spectrogram[too_close] / spectrogram[too_close].abs() + if test_pseudo_complex: + spectrogram = torch.view_as_real(spectrogram) + self.assert_grad(transform, [spectrogram]) + + def test_psd(self): + transform = T.PSD() + waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2) + spectrogram = get_spectrogram(waveform, n_fft=400) + self.assert_grad(transform, [spectrogram]) + + @parameterized.expand([ + [True], + [False], + ]) + def test_psd_with_mask(self, multi_mask): + transform = T.PSD(multi_mask=multi_mask) + waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2) + spectrogram = get_spectrogram(waveform, n_fft=400) + if multi_mask: + mask = torch.rand(spectrogram.shape[-3:]) + else: + mask = torch.rand(spectrogram.shape[-2:]) + + self.assert_grad(transform, [spectrogram, mask]) + + @parameterized.expand([ + "ref_channel", + # stv_power test time too long, comment for now + # "stv_power", + # stv_evd will fail since the eigenvalues are not distinct + # "stv_evd", + ]) + def test_mvdr(self, solution): + transform = T.MVDR(solution=solution) + waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2) + spectrogram = get_spectrogram(waveform, n_fft=400) + mask_s = torch.rand(spectrogram.shape[-2:]) + mask_n = torch.rand(spectrogram.shape[-2:]) + self.assert_grad(transform, [spectrogram, mask_s, mask_n]) + + +class AutogradTestFloat32(TestBaseMixin): + def assert_grad( + self, + transform: torch.nn.Module, + inputs: List[torch.Tensor], + ): + inputs_ = [] + for i in inputs: + if torch.is_tensor(i): + i = i.to(dtype=torch.float32, device=self.device) + inputs_.append(i) + # gradcheck with float32 requires higher atol and epsilon + assert gradcheck(transform, inputs, eps=1e-3, atol=1e-3, nondet_tol=0.) 
+ + @parameterized.expand([ + (rnnt_utils.get_B1_T10_U3_D4_data, ), + (rnnt_utils.get_B2_T4_U3_D3_data, ), + (rnnt_utils.get_B1_T2_U3_D5_data, ), + ]) + def test_rnnt_loss(self, data_func): + def get_data(data_func, device): + data = data_func() + if type(data) == tuple: + data = data[0] + return data + + data = get_data(data_func, self.device) + inputs = ( + data["logits"].to(torch.float32), + data["targets"], + data["logit_lengths"], + data["target_lengths"], + ) + loss = T.RNNTLoss(blank=data["blank"]) + + self.assert_grad(loss, inputs) diff --git a/test/torchaudio_unittest/transforms/batch_consistency_test.py b/test/torchaudio_unittest/transforms/batch_consistency_test.py new file mode 100644 index 0000000000000000000000000000000000000000..aa1f38174ce6f26ff4d60838acb58a25c3050259 --- /dev/null +++ b/test/torchaudio_unittest/transforms/batch_consistency_test.py @@ -0,0 +1,228 @@ +"""Test numerical consistency among single input and batched input.""" +import torch +from parameterized import parameterized +from torchaudio import transforms as T + +from torchaudio_unittest import common_utils + + +class TestTransforms(common_utils.TorchaudioTestCase): + """Test suite for classes defined in `transforms` module""" + backend = 'default' + + def assert_batch_consistency( + self, transform, batch, *args, atol=1e-8, rtol=1e-5, seed=42, + **kwargs): + n = batch.size(0) + + # Compute items separately, then batch the result + torch.random.manual_seed(seed) + items_input = batch.clone() + items_result = torch.stack([ + transform(items_input[i], *args, **kwargs) for i in range(n) + ]) + + # Batch the input and run + torch.random.manual_seed(seed) + batch_input = batch.clone() + batch_result = transform(batch_input, *args, **kwargs) + + self.assertEqual(items_input, batch_input, rtol=rtol, atol=atol) + self.assertEqual(items_result, batch_result, rtol=rtol, atol=atol) + + def test_batch_AmplitudeToDB(self): + spec = torch.rand((3, 2, 6, 201)) + transform = T.AmplitudeToDB() + + self.assert_batch_consistency(transform, spec) + + def test_batch_Resample(self): + waveform = torch.randn(3, 2, 2786) + transform = T.Resample() + + self.assert_batch_consistency(transform, waveform) + + def test_batch_MelScale(self): + specgram = torch.randn(3, 2, 201, 256) + transform = T.MelScale() + + self.assert_batch_consistency(transform, specgram) + + def test_batch_InverseMelScale(self): + n_mels = 32 + n_stft = 5 + mel_spec = torch.randn(3, 2, n_mels, 32) ** 2 + transform = T.InverseMelScale(n_stft, n_mels) + + # Because InverseMelScale runs SGD on randomly initialized values so they do not yield + # exactly same result. For this reason, tolerance is very relaxed here. 
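# --- Illustrative aside (not part of the patch) -----------------------------
# Why the tolerance is so loose: InverseMelScale estimates the linear-frequency
# spectrogram by SGD, so a MelScale -> InverseMelScale round trip is only
# approximate. Sizes below are arbitrary; max_iter is capped to keep the
# sketch quick.
import torch
import torchaudio.transforms as T

n_stft, n_mels = 65, 32
spec = torch.rand(1, n_stft, 10)
mel = T.MelScale(n_mels=n_mels, n_stft=n_stft)(spec)
recovered = T.InverseMelScale(n_stft, n_mels=n_mels, max_iter=1000)(mel)
print((recovered - spec).abs().max())  # nonzero reconstruction error
# -----------------------------------------------------------------------------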
+ self.assert_batch_consistency(transform, mel_spec, atol=1.0, rtol=1e-5) + + def test_batch_compute_deltas(self): + specgram = torch.randn(3, 2, 31, 2786) + transform = T.ComputeDeltas() + + self.assert_batch_consistency(transform, specgram) + + def test_batch_mulaw(self): + waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6) + waveform = waveform.reshape(3, 2, -1) + + # Single then transform then batch + expected = [T.MuLawEncoding()(waveform[i]) for i in range(3)] + expected = torch.stack(expected) + + # Batch then transform + computed = T.MuLawEncoding()(waveform) + + # shape = (3, 2, 201, 1394) + self.assertEqual(computed, expected) + + # Single then transform then batch + expected_decoded = [T.MuLawDecoding()(expected[i]) for i in range(3)] + expected_decoded = torch.stack(expected_decoded) + + # Batch then transform + computed_decoded = T.MuLawDecoding()(computed) + + # shape = (3, 2, 201, 1394) + self.assertEqual(computed_decoded, expected_decoded) + + def test_batch_spectrogram(self): + waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6) + waveform = waveform.reshape(3, 2, -1) + transform = T.Spectrogram() + + self.assert_batch_consistency(transform, waveform) + + def test_batch_inverse_spectrogram(self): + waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6) + specgram = common_utils.get_spectrogram(waveform, n_fft=400) + specgram = specgram.reshape(3, 2, specgram.shape[-2], specgram.shape[-1]) + transform = T.InverseSpectrogram(n_fft=400) + + self.assert_batch_consistency(transform, specgram) + + def test_batch_melspectrogram(self): + waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6) + waveform = waveform.reshape(3, 2, -1) + transform = T.MelSpectrogram() + + self.assert_batch_consistency(transform, waveform) + + def test_batch_mfcc(self): + waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6) + waveform = waveform.reshape(3, 2, -1) + transform = T.MFCC() + + self.assert_batch_consistency(transform, waveform, atol=1e-4, rtol=1e-5) + + def test_batch_lfcc(self): + waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6) + waveform = waveform.reshape(3, 2, -1) + transform = T.LFCC() + + self.assert_batch_consistency(transform, waveform, atol=1e-4, rtol=1e-5) + + @parameterized.expand([(True, ), (False, )]) + def test_batch_TimeStretch(self, test_pseudo_complex): + rate = 2 + num_freq = 1025 + num_frames = 400 + batch = 3 + + spec = torch.randn(batch, num_freq, num_frames, dtype=torch.complex64) + if test_pseudo_complex: + spec = torch.view_as_real(spec) + + transform = T.TimeStretch( + fixed_rate=rate, + n_freq=num_freq, + hop_length=512 + ) + + self.assert_batch_consistency(transform, spec, atol=1e-5, rtol=1e-5) + + def test_batch_Fade(self): + waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6) + waveform = waveform.reshape(3, 2, -1) + fade_in_len = 3000 + fade_out_len = 3000 + transform = T.Fade(fade_in_len, fade_out_len) + + self.assert_batch_consistency(transform, waveform) + + def test_batch_Vol(self): + waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6) + waveform = waveform.reshape(3, 2, -1) + transform = T.Vol(gain=1.1) + + self.assert_batch_consistency(transform, waveform) + + def test_batch_spectral_centroid(self): + sample_rate = 44100 + waveform = common_utils.get_whitenoise(sample_rate=sample_rate, n_channels=6) + waveform = 
waveform.reshape(3, 2, -1) + transform = T.SpectralCentroid(sample_rate) + + self.assert_batch_consistency(transform, waveform) + + def test_batch_pitch_shift(self): + sample_rate = 8000 + n_steps = -2 + waveform = common_utils.get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=6) + waveform = waveform.reshape(3, 2, -1) + transform = T.PitchShift(sample_rate, n_steps, n_fft=400) + + self.assert_batch_consistency(transform, waveform) + + def test_batch_PSD(self): + waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6) + specgram = common_utils.get_spectrogram(waveform, n_fft=400) + specgram = specgram.reshape(3, 2, specgram.shape[-2], specgram.shape[-1]) + transform = T.PSD() + + self.assert_batch_consistency(transform, specgram) + + def test_batch_PSD_with_mask(self): + waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6) + waveform = waveform.to(torch.double) + specgram = common_utils.get_spectrogram(waveform, n_fft=400) + specgram = specgram.reshape(3, 2, specgram.shape[-2], specgram.shape[-1]) + mask = torch.rand((3, specgram.shape[-2], specgram.shape[-1])) + transform = T.PSD() + + # Single then transform then batch + expected = [transform(specgram[i], mask[i]) for i in range(3)] + expected = torch.stack(expected) + + # Batch then transform + computed = transform(specgram, mask) + + self.assertEqual(computed, expected) + + @parameterized.expand([ + [True], + [False], + ]) + def test_MVDR(self, multi_mask): + waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6) + waveform = waveform.to(torch.double) + specgram = common_utils.get_spectrogram(waveform, n_fft=400) + specgram = specgram.reshape(3, 2, specgram.shape[-2], specgram.shape[-1]) + if multi_mask: + mask_s = torch.rand((3, 2, specgram.shape[-2], specgram.shape[-1])) + mask_n = torch.rand((3, 2, specgram.shape[-2], specgram.shape[-1])) + else: + mask_s = torch.rand((3, specgram.shape[-2], specgram.shape[-1])) + mask_n = torch.rand((3, specgram.shape[-2], specgram.shape[-1])) + transform = T.MVDR(multi_mask=multi_mask) + + # Single then transform then batch + expected = [transform(specgram[i], mask_s[i], mask_n[i]) for i in range(3)] + expected = torch.stack(expected) + + # Batch then transform + computed = transform(specgram, mask_s, mask_n) + + self.assertEqual(computed, expected) diff --git a/test/torchaudio_unittest/transforms/kaldi_compatibility_cpu_test.py b/test/torchaudio_unittest/transforms/kaldi_compatibility_cpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..43be412b4798c7cd141130d3933d61b18e02191c --- /dev/null +++ b/test/torchaudio_unittest/transforms/kaldi_compatibility_cpu_test.py @@ -0,0 +1,14 @@ +import torch + +from torchaudio_unittest import common_utils +from .kaldi_compatibility_impl import Kaldi + + +class TestKaldiFloat32(Kaldi, common_utils.PytorchTestCase): + dtype = torch.float32 + device = torch.device('cpu') + + +class TestKaldiFloat64(Kaldi, common_utils.PytorchTestCase): + dtype = torch.float64 + device = torch.device('cpu') diff --git a/test/torchaudio_unittest/transforms/kaldi_compatibility_cuda_test.py b/test/torchaudio_unittest/transforms/kaldi_compatibility_cuda_test.py new file mode 100644 index 0000000000000000000000000000000000000000..28adb7fce570620d25d856f000ea1dc2ec243bec --- /dev/null +++ b/test/torchaudio_unittest/transforms/kaldi_compatibility_cuda_test.py @@ -0,0 +1,16 @@ +import torch + +from torchaudio_unittest import common_utils +from 
.kaldi_compatibility_impl import Kaldi + + +@common_utils.skipIfNoCuda +class TestKaldiFloat32(Kaldi, common_utils.PytorchTestCase): + dtype = torch.float32 + device = torch.device('cuda') + + +@common_utils.skipIfNoCuda +class TestKaldiFloat64(Kaldi, common_utils.PytorchTestCase): + dtype = torch.float64 + device = torch.device('cuda') diff --git a/test/torchaudio_unittest/transforms/kaldi_compatibility_impl.py b/test/torchaudio_unittest/transforms/kaldi_compatibility_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..9a0e6ab81d2c52cfbeb7b6930e7fa946d3f17999 --- /dev/null +++ b/test/torchaudio_unittest/transforms/kaldi_compatibility_impl.py @@ -0,0 +1,55 @@ +"""Test suites for checking numerical compatibility against Kaldi""" +import torchaudio.compliance.kaldi +from parameterized import parameterized + +from torchaudio_unittest.common_utils import ( + TestBaseMixin, + TempDirMixin, + load_params, + skipIfNoExec, + get_asset_path, + load_wav, +) +from torchaudio_unittest.common_utils.kaldi_utils import ( + convert_args, + run_kaldi, +) + + +class Kaldi(TempDirMixin, TestBaseMixin): + def assert_equal(self, output, *, expected, rtol=None, atol=None): + expected = expected.to(dtype=self.dtype, device=self.device) + self.assertEqual(output, expected, rtol=rtol, atol=atol) + + @parameterized.expand(load_params('kaldi_test_fbank_args.jsonl')) + @skipIfNoExec('compute-fbank-feats') + def test_fbank(self, kwargs): + """fbank should be numerically compatible with compute-fbank-feats""" + wave_file = get_asset_path('kaldi_file.wav') + waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device) + result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs) + command = ['compute-fbank-feats'] + convert_args(**kwargs) + ['scp:-', 'ark:-'] + kaldi_result = run_kaldi(command, 'scp', wave_file) + self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) + + @parameterized.expand(load_params('kaldi_test_spectrogram_args.jsonl')) + @skipIfNoExec('compute-spectrogram-feats') + def test_spectrogram(self, kwargs): + """spectrogram should be numerically compatible with compute-spectrogram-feats""" + wave_file = get_asset_path('kaldi_file.wav') + waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device) + result = torchaudio.compliance.kaldi.spectrogram(waveform, **kwargs) + command = ['compute-spectrogram-feats'] + convert_args(**kwargs) + ['scp:-', 'ark:-'] + kaldi_result = run_kaldi(command, 'scp', wave_file) + self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) + + @parameterized.expand(load_params('kaldi_test_mfcc_args.jsonl')) + @skipIfNoExec('compute-mfcc-feats') + def test_mfcc(self, kwargs): + """mfcc should be numerically compatible with compute-mfcc-feats""" + wave_file = get_asset_path('kaldi_file.wav') + waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device) + result = torchaudio.compliance.kaldi.mfcc(waveform, **kwargs) + command = ['compute-mfcc-feats'] + convert_args(**kwargs) + ['scp:-', 'ark:-'] + kaldi_result = run_kaldi(command, 'scp', wave_file) + self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) diff --git a/test/torchaudio_unittest/transforms/librosa_compatibility_cpu_test.py b/test/torchaudio_unittest/transforms/librosa_compatibility_cpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..300f4a3f71680fbf671f5e4f2ac5a95410cb5614 --- /dev/null +++ 
b/test/torchaudio_unittest/transforms/librosa_compatibility_cpu_test.py @@ -0,0 +1,9 @@ +import torch + +from torchaudio_unittest.common_utils import PytorchTestCase +from .librosa_compatibility_test_impl import TransformsTestBase + + +class TestTransforms(TransformsTestBase, PytorchTestCase): + dtype = torch.float64 + device = torch.device('cpu') diff --git a/test/torchaudio_unittest/transforms/librosa_compatibility_cuda_test.py b/test/torchaudio_unittest/transforms/librosa_compatibility_cuda_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d553458aca6909de7665b878b3a4174522083485 --- /dev/null +++ b/test/torchaudio_unittest/transforms/librosa_compatibility_cuda_test.py @@ -0,0 +1,10 @@ +import torch + +from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda +from .librosa_compatibility_test_impl import TransformsTestBase + + +@skipIfNoCuda +class TestTransforms(TransformsTestBase, PytorchTestCase): + dtype = torch.float64 + device = torch.device('cuda') diff --git a/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py b/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..9bc8edbb2046cb1ad5735c23bdfffe99e5ee5b53 --- /dev/null +++ b/test/torchaudio_unittest/transforms/librosa_compatibility_test_impl.py @@ -0,0 +1,141 @@ +import unittest + +import torch +import torchaudio.transforms as T +from torchaudio._internal.module_utils import is_module_available +from parameterized import param, parameterized + +from torchaudio_unittest.common_utils import ( + TestBaseMixin, + get_whitenoise, + get_sinusoid, + get_spectrogram, + nested_params, +) + +LIBROSA_AVAILABLE = is_module_available('librosa') + +if LIBROSA_AVAILABLE: + import librosa + + +@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available") +class TransformsTestBase(TestBaseMixin): + @parameterized.expand([ + param(n_fft=400, hop_length=200, power=2.0), + param(n_fft=600, hop_length=100, power=2.0), + param(n_fft=400, hop_length=200, power=3.0), + param(n_fft=200, hop_length=50, power=2.0), + ]) + def test_Spectrogram(self, n_fft, hop_length, power): + sample_rate = 16000 + waveform = get_whitenoise( + sample_rate=sample_rate, n_channels=1, + ).to(self.device, self.dtype) + + expected = librosa.core.spectrum._spectrogram( + y=waveform[0].cpu().numpy(), + n_fft=n_fft, hop_length=hop_length, power=power)[0] + + result = T.Spectrogram( + n_fft=n_fft, hop_length=hop_length, power=power, + ).to(self.device, self.dtype)(waveform)[0] + self.assertEqual(result, torch.from_numpy(expected), atol=1e-5, rtol=1e-5) + + def test_Spectrogram_complex(self): + n_fft = 400 + hop_length = 200 + sample_rate = 16000 + waveform = get_whitenoise( + sample_rate=sample_rate, n_channels=1, + ).to(self.device, self.dtype) + + expected = librosa.core.spectrum._spectrogram( + y=waveform[0].cpu().numpy(), + n_fft=n_fft, hop_length=hop_length, power=1)[0] + + result = T.Spectrogram( + n_fft=n_fft, hop_length=hop_length, power=None, return_complex=True, + ).to(self.device, self.dtype)(waveform)[0] + self.assertEqual(result.abs(), torch.from_numpy(expected), atol=1e-5, rtol=1e-5) + + @nested_params( + [ + param(n_fft=400, hop_length=200, n_mels=64), + param(n_fft=600, hop_length=100, n_mels=128), + param(n_fft=200, hop_length=50, n_mels=32), + ], + [param(norm=norm) for norm in [None, 'slaney']], + [param(mel_scale=mel_scale) for mel_scale in ['htk', 'slaney']], + ) + def test_MelSpectrogram(self, n_fft, 
hop_length, n_mels, norm, mel_scale): + sample_rate = 16000 + waveform = get_sinusoid( + sample_rate=sample_rate, n_channels=1, + ).to(self.device, self.dtype) + + expected = librosa.feature.melspectrogram( + y=waveform[0].cpu().numpy(), + sr=sample_rate, n_fft=n_fft, + hop_length=hop_length, n_mels=n_mels, norm=norm, + htk=mel_scale == "htk") + result = T.MelSpectrogram( + sample_rate=sample_rate, window_fn=torch.hann_window, + hop_length=hop_length, n_mels=n_mels, + n_fft=n_fft, norm=norm, mel_scale=mel_scale, + ).to(self.device, self.dtype)(waveform)[0] + self.assertEqual(result, torch.from_numpy(expected), atol=5e-4, rtol=1e-5) + + def test_magnitude_to_db(self): + spectrogram = get_spectrogram( + get_whitenoise(), n_fft=400, power=2).to(self.device, self.dtype) + result = T.AmplitudeToDB('magnitude', 80.).to(self.device, self.dtype)(spectrogram)[0] + expected = librosa.core.spectrum.amplitude_to_db(spectrogram[0].cpu().numpy()) + self.assertEqual(result, torch.from_numpy(expected)) + + def test_power_to_db(self): + spectrogram = get_spectrogram( + get_whitenoise(), n_fft=400, power=2).to(self.device, self.dtype) + result = T.AmplitudeToDB('power', 80.).to(self.device, self.dtype)(spectrogram)[0] + expected = librosa.core.spectrum.power_to_db(spectrogram[0].cpu().numpy()) + self.assertEqual(result, torch.from_numpy(expected)) + + @nested_params([ + param(n_fft=400, hop_length=200, n_mels=64, n_mfcc=40), + param(n_fft=600, hop_length=100, n_mels=128, n_mfcc=20), + param(n_fft=200, hop_length=50, n_mels=32, n_mfcc=25), + ]) + def test_mfcc(self, n_fft, hop_length, n_mels, n_mfcc): + sample_rate = 16000 + waveform = get_whitenoise( + sample_rate=sample_rate, n_channels=1).to(self.device, self.dtype) + result = T.MFCC( + sample_rate=sample_rate, n_mfcc=n_mfcc, norm='ortho', + melkwargs={'hop_length': hop_length, 'n_fft': n_fft, 'n_mels': n_mels}, + ).to(self.device, self.dtype)(waveform)[0] + + melspec = librosa.feature.melspectrogram( + y=waveform[0].cpu().numpy(), sr=sample_rate, n_fft=n_fft, + win_length=n_fft, hop_length=hop_length, + n_mels=n_mels, htk=True, norm=None) + expected = librosa.feature.mfcc( + S=librosa.core.spectrum.power_to_db(melspec), + n_mfcc=n_mfcc, dct_type=2, norm='ortho') + self.assertEqual(result, torch.from_numpy(expected), atol=5e-4, rtol=1e-5) + + @parameterized.expand([ + param(n_fft=400, hop_length=200), + param(n_fft=600, hop_length=100), + param(n_fft=200, hop_length=50), + ]) + def test_spectral_centroid(self, n_fft, hop_length): + sample_rate = 16000 + waveform = get_whitenoise( + sample_rate=sample_rate, n_channels=1).to(self.device, self.dtype) + + result = T.SpectralCentroid( + sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, + ).to(self.device, self.dtype)(waveform) + expected = librosa.feature.spectral_centroid( + y=waveform[0].cpu().numpy(), sr=sample_rate, n_fft=n_fft, hop_length=hop_length) + self.assertEqual(result, torch.from_numpy(expected), atol=5e-4, rtol=1e-5) diff --git a/test/torchaudio_unittest/transforms/sox_compatibility_test.py b/test/torchaudio_unittest/transforms/sox_compatibility_test.py new file mode 100644 index 0000000000000000000000000000000000000000..be6c9020ab9270a6842cfb65137f1983f3654424 --- /dev/null +++ b/test/torchaudio_unittest/transforms/sox_compatibility_test.py @@ -0,0 +1,88 @@ +import warnings + +import torch +import torchaudio.transforms as T +from parameterized import parameterized + +from torchaudio_unittest.common_utils import ( + skipIfNoSox, + skipIfNoExec, + TempDirMixin, + TorchaudioTestCase, + 
get_asset_path, + sox_utils, + load_wav, + save_wav, + get_whitenoise, +) + + +@skipIfNoSox +@skipIfNoExec('sox') +class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): + def run_sox_effect(self, input_file, effect): + output_file = self.get_temp_path('expected.wav') + sox_utils.run_sox_effect(input_file, output_file, [str(e) for e in effect]) + return load_wav(output_file) + + def assert_sox_effect(self, result, input_path, effects, atol=1e-04, rtol=1e-5): + expected, _ = self.run_sox_effect(input_path, effects) + self.assertEqual(result, expected, atol=atol, rtol=rtol) + + def get_whitenoise(self, sample_rate=8000): + noise = get_whitenoise( + sample_rate=sample_rate, duration=3, scale_factor=0.9, + ) + path = self.get_temp_path("whitenoise.wav") + save_wav(path, noise, sample_rate) + return noise, path + + @parameterized.expand([ + ('q', 'quarter_sine'), + ('h', 'half_sine'), + ('t', 'linear'), + ]) + def test_fade(self, fade_shape_sox, fade_shape): + fade_in_len, fade_out_len = 44100, 44100 + data, path = self.get_whitenoise(sample_rate=44100) + result = T.Fade(fade_in_len, fade_out_len, fade_shape)(data) + self.assert_sox_effect(result, path, ['fade', fade_shape_sox, '1', '0', '1']) + + @parameterized.expand([ + ('amplitude', 1.1), + ('db', 2), + ('power', 2), + ]) + def test_vol(self, gain_type, gain): + data, path = self.get_whitenoise() + result = T.Vol(gain, gain_type)(data) + self.assert_sox_effect(result, path, ['vol', f'{gain}', gain_type]) + + @parameterized.expand(['vad-go-stereo-44100.wav', 'vad-go-mono-32000.wav']) + def test_vad(self, filename): + path = get_asset_path(filename) + data, sample_rate = load_wav(path) + result = T.Vad(sample_rate)(data) + self.assert_sox_effect(result, path, ['vad']) + + def test_vad_warning(self): + """vad should throw a warning if input dimension is greater than 2""" + sample_rate = 41100 + + data = torch.rand(5, 5, sample_rate) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + T.Vad(sample_rate)(data) + assert len(w) == 1 + + data = torch.rand(5, sample_rate) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + T.Vad(sample_rate)(data) + assert len(w) == 0 + + data = torch.rand(sample_rate) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + T.Vad(sample_rate)(data) + assert len(w) == 0 diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py b/test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..eb5b2d8b04dc0cf8c51aed5387c8660ec4e1cbbb --- /dev/null +++ b/test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py @@ -0,0 +1,14 @@ +import torch + +from torchaudio_unittest.common_utils import PytorchTestCase +from .torchscript_consistency_impl import Transforms, TransformsFloat32Only, TransformsFloat64Only + + +class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase): + dtype = torch.float32 + device = torch.device('cpu') + + +class TestTransformsFloat64(Transforms, TransformsFloat64Only, PytorchTestCase): + dtype = torch.float64 + device = torch.device('cpu') diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py b/test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py new file mode 100644 index 0000000000000000000000000000000000000000..81a81b82423d10476ad02737d1356ba1ff7300a4 --- /dev/null +++ 
b/test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py @@ -0,0 +1,16 @@ +import torch + +from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase +from .torchscript_consistency_impl import Transforms, TransformsFloat32Only, TransformsFloat64Only + + +@skipIfNoCuda +class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase): + dtype = torch.float32 + device = torch.device('cuda') + + +@skipIfNoCuda +class TestTransformsFloat64(Transforms, TransformsFloat64Only, PytorchTestCase): + dtype = torch.float64 + device = torch.device('cuda') diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..27f57adba15a57d89ffb1fe95c4ddf899a7bb2c2 --- /dev/null +++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py @@ -0,0 +1,202 @@ +"""Test suites for jit-ability and its numerical compatibility""" + +import torch +import torchaudio.transforms as T +from parameterized import parameterized + +from torchaudio_unittest import common_utils +from torchaudio_unittest.common_utils import ( + skipIfRocm, + TestBaseMixin, + torch_script, +) + + +class Transforms(TestBaseMixin): + """Implements test for Transforms that are performed for different devices""" + def _assert_consistency(self, transform, tensor, *args): + tensor = tensor.to(device=self.device, dtype=self.dtype) + transform = transform.to(device=self.device, dtype=self.dtype) + + ts_transform = torch_script(transform) + + output = transform(tensor, *args) + ts_output = ts_transform(tensor, *args) + self.assertEqual(ts_output, output) + + def _assert_consistency_complex(self, transform, tensor, test_pseudo_complex=False, *args): + assert tensor.is_complex() + tensor = tensor.to(device=self.device, dtype=self.complex_dtype) + transform = transform.to(device=self.device, dtype=self.dtype) + + ts_transform = torch_script(transform) + + if test_pseudo_complex: + tensor = torch.view_as_real(tensor) + output = transform(tensor, *args) + ts_output = ts_transform(tensor, *args) + self.assertEqual(ts_output, output) + + def test_Spectrogram(self): + tensor = torch.rand((1, 1000)) + self._assert_consistency(T.Spectrogram(), tensor) + + def test_Spectrogram_return_complex(self): + tensor = torch.rand((1, 1000)) + self._assert_consistency(T.Spectrogram(power=None, return_complex=True), tensor) + + def test_InverseSpectrogram(self): + tensor = common_utils.get_whitenoise(sample_rate=8000) + spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100) + self._assert_consistency_complex(T.InverseSpectrogram(n_fft=400, hop_length=100), spectrogram) + + def test_InverseSpectrogram_pseudocomplex(self): + tensor = common_utils.get_whitenoise(sample_rate=8000) + spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100) + spectrogram = torch.view_as_real(spectrogram) + self._assert_consistency(T.InverseSpectrogram(n_fft=400, hop_length=100), spectrogram) + + @skipIfRocm + def test_GriffinLim(self): + tensor = torch.rand((1, 201, 6)) + self._assert_consistency(T.GriffinLim(length=1000, rand_init=False), tensor) + + def test_AmplitudeToDB(self): + spec = torch.rand((6, 201)) + self._assert_consistency(T.AmplitudeToDB(), spec) + + def test_MelScale(self): + spec_f = torch.rand((1, 201, 6)) + self._assert_consistency(T.MelScale(n_stft=201), spec_f) + + def test_MelSpectrogram(self): + tensor = torch.rand((1, 
1000)) + self._assert_consistency(T.MelSpectrogram(), tensor) + + def test_MFCC(self): + tensor = torch.rand((1, 1000)) + self._assert_consistency(T.MFCC(), tensor) + + def test_LFCC(self): + tensor = torch.rand((1, 1000)) + self._assert_consistency(T.LFCC(), tensor) + + def test_Resample(self): + sr1, sr2 = 16000, 8000 + tensor = common_utils.get_whitenoise(sample_rate=sr1) + self._assert_consistency(T.Resample(sr1, sr2), tensor) + + def test_ComplexNorm(self): + tensor = torch.rand((1, 2, 201, 2)) + self._assert_consistency(T.ComplexNorm(), tensor) + + def test_MuLawEncoding(self): + tensor = common_utils.get_whitenoise() + self._assert_consistency(T.MuLawEncoding(), tensor) + + def test_MuLawDecoding(self): + tensor = torch.rand((1, 10)) + self._assert_consistency(T.MuLawDecoding(), tensor) + + def test_Fade(self): + waveform = common_utils.get_whitenoise() + fade_in_len = 3000 + fade_out_len = 3000 + self._assert_consistency(T.Fade(fade_in_len, fade_out_len), waveform) + + def test_FrequencyMasking(self): + tensor = torch.rand((10, 2, 50, 10, 2)) + self._assert_consistency(T.FrequencyMasking(freq_mask_param=60, iid_masks=False), tensor) + + def test_TimeMasking(self): + tensor = torch.rand((10, 2, 50, 10, 2)) + self._assert_consistency(T.TimeMasking(time_mask_param=30, iid_masks=False), tensor) + + def test_Vol(self): + waveform = common_utils.get_whitenoise() + self._assert_consistency(T.Vol(1.1), waveform) + + def test_SlidingWindowCmn(self): + tensor = torch.rand((1000, 10)) + self._assert_consistency(T.SlidingWindowCmn(), tensor) + + def test_Vad(self): + filepath = common_utils.get_asset_path("vad-go-mono-32000.wav") + waveform, sample_rate = common_utils.load_wav(filepath) + self._assert_consistency(T.Vad(sample_rate=sample_rate), waveform) + + def test_SpectralCentroid(self): + sample_rate = 44100 + waveform = common_utils.get_whitenoise(sample_rate=sample_rate) + self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform) + + @parameterized.expand([(True, ), (False, )]) + def test_TimeStretch(self, test_pseudo_complex): + n_freq = 400 + hop_length = 512 + fixed_rate = 1.3 + tensor = torch.view_as_complex(torch.rand((10, 2, n_freq, 10, 2))) + self._assert_consistency_complex( + T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate), + tensor, + test_pseudo_complex + ) + + def test_PitchShift(self): + sample_rate = 8000 + n_steps = 4 + waveform = common_utils.get_whitenoise(sample_rate=sample_rate) + self._assert_consistency( + T.PitchShift(sample_rate=sample_rate, n_steps=n_steps), + waveform + ) + + def test_PSD(self): + tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4) + spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100) + spectrogram = spectrogram.to(self.device) + self._assert_consistency_complex(T.PSD(), spectrogram) + + def test_PSD_with_mask(self): + tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4) + spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100) + spectrogram = spectrogram.to(self.device) + mask = torch.rand(spectrogram.shape[-2:], device=self.device) + self._assert_consistency_complex(T.PSD(), spectrogram, False, mask) + + +class TransformsFloat32Only(TestBaseMixin): + def test_rnnt_loss(self): + logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.6, 0.1, 0.1], + [0.1, 0.1, 0.2, 0.8, 0.1]], + [[0.1, 0.6, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.2, 0.1, 0.1], + [0.7, 0.1, 0.2, 0.1, 0.1]]]]) + tensor = logits.to(device=self.device, 
dtype=torch.float32) + targets = torch.tensor([[1, 2]], device=tensor.device, dtype=torch.int32) + logit_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32) + target_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32) + + self._assert_consistency(T.RNNTLoss(), logits, targets, logit_lengths, target_lengths) + + +class TransformsFloat64Only(TestBaseMixin): + @parameterized.expand([ + ["ref_channel", True], + ["stv_evd", True], + ["stv_power", True], + ["ref_channel", False], + ["stv_evd", False], + ["stv_power", False], + ]) + def test_MVDR(self, solution, online): + tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4) + spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100) + spectrogram = spectrogram.to(device=self.device, dtype=torch.cdouble) + mask_s = torch.rand(spectrogram.shape[-2:], device=self.device) + mask_n = torch.rand(spectrogram.shape[-2:], device=self.device) + self._assert_consistency_complex( + T.MVDR(solution=solution, online=online), + spectrogram, False, mask_s, mask_n + ) diff --git a/test/torchaudio_unittest/transforms/transforms_cpu_test.py b/test/torchaudio_unittest/transforms/transforms_cpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5177b9a455864b742c2acf936512cd1792de79da --- /dev/null +++ b/test/torchaudio_unittest/transforms/transforms_cpu_test.py @@ -0,0 +1,14 @@ +import torch + +from torchaudio_unittest.common_utils import PytorchTestCase +from . transforms_test_impl import TransformsTestBase + + +class TransformsCPUFloat32Test(TransformsTestBase, PytorchTestCase): + device = 'cpu' + dtype = torch.float32 + + +class TransformsCPUFloat64Test(TransformsTestBase, PytorchTestCase): + device = 'cpu' + dtype = torch.float64 diff --git a/test/torchaudio_unittest/transforms/transforms_cuda_test.py b/test/torchaudio_unittest/transforms/transforms_cuda_test.py new file mode 100644 index 0000000000000000000000000000000000000000..948966038483438bf7fcc857ae3e68ac4078f378 --- /dev/null +++ b/test/torchaudio_unittest/transforms/transforms_cuda_test.py @@ -0,0 +1,19 @@ +import torch + +from torchaudio_unittest.common_utils import ( + PytorchTestCase, + skipIfNoCuda, +) +from . 
transforms_test_impl import TransformsTestBase + + +@skipIfNoCuda +class TransformsCUDAFloat32Test(TransformsTestBase, PytorchTestCase): + device = 'cuda' + dtype = torch.float32 + + +@skipIfNoCuda +class TransformsCUDAFloat64Test(TransformsTestBase, PytorchTestCase): + device = 'cuda' + dtype = torch.float64 diff --git a/test/torchaudio_unittest/transforms/transforms_test.py b/test/torchaudio_unittest/transforms/transforms_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7808789cc6ae442dd2f07f9f35545a6c43642314 --- /dev/null +++ b/test/torchaudio_unittest/transforms/transforms_test.py @@ -0,0 +1,314 @@ +import math + +import torch +import torchaudio +import torchaudio.transforms as transforms +import torchaudio.functional as F + +from torchaudio_unittest import common_utils + + +class Tester(common_utils.TorchaudioTestCase): + backend = 'default' + + # create a sinewave signal for testing + sample_rate = 16000 + freq = 440 + volume = .3 + waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate)) + waveform.unsqueeze_(0) # (1, 64000) + waveform = (waveform * volume * 2**31).long() + + def scale(self, waveform, factor=2.0**31): + # scales a waveform by a factor + if not waveform.is_floating_point(): + waveform = waveform.to(torch.get_default_dtype()) + return waveform / factor + + def test_mu_law_companding(self): + + quantization_channels = 256 + + waveform = self.waveform.clone() + if not waveform.is_floating_point(): + waveform = waveform.to(torch.get_default_dtype()) + waveform /= torch.abs(waveform).max() + + self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.) + + waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform) + self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels) + + waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu) + self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.) + + def test_AmplitudeToDB(self): + filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav') + waveform = common_utils.load_wav(filepath)[0] + + mag_to_db_transform = transforms.AmplitudeToDB('magnitude', 80.) + power_to_db_transform = transforms.AmplitudeToDB('power', 80.) 
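+        # For a magnitude m > 0, 20 * log10(m) == 10 * log10(m ** 2), so applying the
+        # 'magnitude' transform to |waveform| and the 'power' transform to waveform ** 2
+        # should yield the same dB values (both are then clamped by the shared
+        # top_db of 80).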
+ + mag_to_db_torch = mag_to_db_transform(torch.abs(waveform)) + power_to_db_torch = power_to_db_transform(torch.pow(waveform, 2)) + + self.assertEqual(mag_to_db_torch, power_to_db_torch) + + def test_melscale_load_save(self): + specgram = torch.ones(1, 201, 100) + melscale_transform = transforms.MelScale() + melscale_transform(specgram) + + melscale_transform_copy = transforms.MelScale() + melscale_transform_copy.load_state_dict(melscale_transform.state_dict()) + + fb = melscale_transform.fb + fb_copy = melscale_transform_copy.fb + + self.assertEqual(fb_copy.size(), (201, 128)) + self.assertEqual(fb, fb_copy) + + def test_melspectrogram_load_save(self): + waveform = self.waveform.float() + mel_spectrogram_transform = transforms.MelSpectrogram() + mel_spectrogram_transform(waveform) + + mel_spectrogram_transform_copy = transforms.MelSpectrogram() + mel_spectrogram_transform_copy.load_state_dict(mel_spectrogram_transform.state_dict()) + + window = mel_spectrogram_transform.spectrogram.window + window_copy = mel_spectrogram_transform_copy.spectrogram.window + + fb = mel_spectrogram_transform.mel_scale.fb + fb_copy = mel_spectrogram_transform_copy.mel_scale.fb + + self.assertEqual(window, window_copy) + # the default for n_fft = 400 and n_mels = 128 + self.assertEqual(fb_copy.size(), (201, 128)) + self.assertEqual(fb, fb_copy) + + def test_mel2(self): + top_db = 80. + s2db = transforms.AmplitudeToDB('power', top_db) + + waveform = self.waveform.clone() # (1, 16000) + waveform_scaled = self.scale(waveform) # (1, 16000) + mel_transform = transforms.MelSpectrogram() + # check defaults + spectrogram_torch = s2db(mel_transform(waveform_scaled)) # (1, 128, 321) + self.assertTrue(spectrogram_torch.dim() == 3) + self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all()) + self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels) + # check correctness of filterbank conversion matrix + self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all()) + self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all()) + # check options + kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500, + 'hop_length': 125, 'n_fft': 800, 'n_mels': 50} + mel_transform2 = transforms.MelSpectrogram(**kwargs) + spectrogram2_torch = s2db(mel_transform2(waveform_scaled)) # (1, 50, 513) + self.assertTrue(spectrogram2_torch.dim() == 3) + self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all()) + self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels) + self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all()) + self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all()) + # check on multi-channel audio + filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav') + x_stereo = common_utils.load_wav(filepath)[0] # (2, 278756), 44100 + spectrogram_stereo = s2db(mel_transform(x_stereo)) # (2, 128, 1394) + self.assertTrue(spectrogram_stereo.dim() == 3) + self.assertTrue(spectrogram_stereo.size(0) == 2) + self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all()) + self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels) + # check filterbank matrix creation + fb_matrix_transform = transforms.MelScale( + n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400) + self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all()) + self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all()) + self.assertEqual(fb_matrix_transform.fb.size(), (400, 100)) + + def test_mfcc_defaults(self): + """Check 
the default configuration of the MFCC transform. + """ + sample_rate = 16000 + audio = common_utils.get_whitenoise(sample_rate=sample_rate) + + n_mfcc = 40 + mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate, + n_mfcc=n_mfcc, + norm='ortho') + torch_mfcc = mfcc_transform(audio) # (1, 40, 81) + self.assertEqual(torch_mfcc.dim(), 3) + self.assertEqual(torch_mfcc.shape[1], n_mfcc) + self.assertEqual(torch_mfcc.shape[2], 81) + + def test_mfcc_kwargs_passthrough(self): + """Check kwargs get correctly passed to the MelSpectrogram transform. + """ + sample_rate = 16000 + audio = common_utils.get_whitenoise(sample_rate=sample_rate) + + n_mfcc = 40 + melkwargs = {'win_length': 200} + mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate, + n_mfcc=n_mfcc, + norm='ortho', + melkwargs=melkwargs) + torch_mfcc = mfcc_transform(audio) # (1, 40, 161) + self.assertEqual(torch_mfcc.shape[2], 161) + + def test_mfcc_norms(self): + """Check if MFCC-DCT norms work correctly. + """ + sample_rate = 16000 + audio = common_utils.get_whitenoise(sample_rate=sample_rate) + + n_mfcc = 40 + n_mels = 128 + mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate, + n_mfcc=n_mfcc, + norm='ortho') + # check norms work correctly + mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate, + n_mfcc=n_mfcc, + norm=None) + torch_mfcc_norm_none = mfcc_transform_norm_none(audio) # (1, 40, 81) + + norm_check = mfcc_transform(audio) + norm_check[:, 0, :] *= math.sqrt(n_mels) * 2 + norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2 + + self.assertEqual(torch_mfcc_norm_none, norm_check) + + def test_lfcc_defaults(self): + """Check default settings for LFCC transform. + """ + sample_rate = 16000 + audio = common_utils.get_whitenoise(sample_rate=sample_rate) + + n_lfcc = 40 + n_filter = 128 + lfcc_transform = torchaudio.transforms.LFCC(sample_rate=sample_rate, + n_filter=n_filter, + n_lfcc=n_lfcc, + norm='ortho') + torch_lfcc = lfcc_transform(audio) # (1, 40, 81) + self.assertEqual(torch_lfcc.dim(), 3) + self.assertEqual(torch_lfcc.shape[1], n_lfcc) + self.assertEqual(torch_lfcc.shape[2], 81) + + def test_lfcc_arg_passthrough(self): + """Check if kwargs get correctly passed to the underlying Spectrogram transform. + """ + sample_rate = 16000 + audio = common_utils.get_whitenoise(sample_rate=sample_rate) + + n_lfcc = 40 + n_filter = 128 + speckwargs = {'win_length': 200} + lfcc_transform = torchaudio.transforms.LFCC(sample_rate=sample_rate, + n_filter=n_filter, + n_lfcc=n_lfcc, + norm='ortho', + speckwargs=speckwargs) + torch_lfcc = lfcc_transform(audio) # (1, 40, 161) + self.assertEqual(torch_lfcc.shape[2], 161) + + def test_lfcc_norms(self): + """Check if LFCC-DCT norm works correctly. 
+ """ + sample_rate = 16000 + audio = common_utils.get_whitenoise(sample_rate=sample_rate) + + n_lfcc = 40 + n_filter = 128 + lfcc_transform = torchaudio.transforms.LFCC(sample_rate=sample_rate, + n_filter=n_filter, + n_lfcc=n_lfcc, + norm='ortho') + + lfcc_transform_norm_none = torchaudio.transforms.LFCC(sample_rate=sample_rate, + n_filter=n_filter, + n_lfcc=n_lfcc, + norm=None) + torch_lfcc_norm_none = lfcc_transform_norm_none(audio) # (1, 40, 161) + + norm_check = lfcc_transform(audio) # (1, 40, 161) + norm_check[:, 0, :] *= math.sqrt(n_filter) * 2 + norm_check[:, 1:, :] *= math.sqrt(n_filter / 2) * 2 + + self.assertEqual(torch_lfcc_norm_none, norm_check) + + def test_resample_size(self): + input_path = common_utils.get_asset_path('sinewave.wav') + waveform, sample_rate = common_utils.load_wav(input_path) + + upsample_rate = sample_rate * 2 + downsample_rate = sample_rate // 2 + invalid_resampling_method = 'foo' + + with self.assertRaises(ValueError): + torchaudio.transforms.Resample(sample_rate, upsample_rate, + resampling_method=invalid_resampling_method) + + upsample_resample = torchaudio.transforms.Resample( + sample_rate, upsample_rate, resampling_method='sinc_interpolation') + up_sampled = upsample_resample(waveform) + + # we expect the upsampled signal to have twice as many samples + self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2) + + downsample_resample = torchaudio.transforms.Resample( + sample_rate, downsample_rate, resampling_method='sinc_interpolation') + down_sampled = downsample_resample(waveform) + + # we expect the downsampled signal to have half as many samples + self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2) + + def test_compute_deltas(self): + channel = 13 + n_mfcc = channel * 3 + time = 1021 + win_length = 2 * 7 + 1 + specgram = torch.randn(channel, n_mfcc, time) + transform = transforms.ComputeDeltas(win_length=win_length) + computed = transform(specgram) + self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) + + def test_compute_deltas_transform_same_as_functional(self, atol=1e-6, rtol=1e-8): + channel = 13 + n_mfcc = channel * 3 + time = 1021 + win_length = 2 * 7 + 1 + specgram = torch.randn(channel, n_mfcc, time) + + transform = transforms.ComputeDeltas(win_length=win_length) + computed_transform = transform(specgram) + + computed_functional = F.compute_deltas(specgram, win_length=win_length) + self.assertEqual(computed_functional, computed_transform, atol=atol, rtol=rtol) + + def test_compute_deltas_twochannel(self): + specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1) + expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5], + [0.5, 1.0, 1.0, 0.5]]]) + transform = transforms.ComputeDeltas(win_length=3) + computed = transform(specgram) + assert computed.shape == expected.shape, (computed.shape, expected.shape) + self.assertEqual(computed, expected, atol=1e-6, rtol=1e-8) + + +class SmokeTest(common_utils.TorchaudioTestCase): + + def test_spectrogram(self): + specgram = transforms.Spectrogram(center=False, pad_mode="reflect", onesided=False) + self.assertEqual(specgram.center, False) + self.assertEqual(specgram.pad_mode, "reflect") + self.assertEqual(specgram.onesided, False) + + def test_melspectrogram(self): + melspecgram = transforms.MelSpectrogram(center=True, pad_mode="reflect", onesided=False) + specgram = melspecgram.spectrogram + self.assertEqual(specgram.center, True) + self.assertEqual(specgram.pad_mode, "reflect") + self.assertEqual(specgram.onesided, False) diff --git 
a/test/torchaudio_unittest/transforms/transforms_test_impl.py b/test/torchaudio_unittest/transforms/transforms_test_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..9aa5496a18531416c7e3f2a0cad643f9963ff271 --- /dev/null +++ b/test/torchaudio_unittest/transforms/transforms_test_impl.py @@ -0,0 +1,134 @@ +import torch +import torchaudio.transforms as T +from parameterized import parameterized, param +from torchaudio_unittest.common_utils import ( + TestBaseMixin, + get_whitenoise, + get_spectrogram, + nested_params, +) +from torchaudio_unittest.common_utils.psd_utils import psd_numpy + + +def _get_ratio(mat): + return (mat.sum() / mat.numel()).item() + + +class TransformsTestBase(TestBaseMixin): + def test_InverseMelScale(self): + """Gauge the quality of InverseMelScale transform. + + As InverseMelScale is currently implemented with + random initialization + iterative optimization, + it is not practically possible to assert the difference between + the estimated spectrogram and the original spectrogram as a whole. + Estimated spectrogram has very huge descrepency locally. + Thus in this test we gauge what percentage of elements are bellow + certain tolerance. + At the moment, the quality of estimated spectrogram is not good. + When implementation is changed in a way it makes the quality even worse, + this test will fail. + """ + n_fft = 400 + power = 1 + n_mels = 64 + sample_rate = 8000 + + n_stft = n_fft // 2 + 1 + + # Generate reference spectrogram and input mel-scaled spectrogram + expected = get_spectrogram( + get_whitenoise(sample_rate=sample_rate, duration=1, n_channels=2), + n_fft=n_fft, power=power).to(self.device, self.dtype) + input = T.MelScale( + n_mels=n_mels, sample_rate=sample_rate, n_stft=n_stft + ).to(self.device, self.dtype)(expected) + + # Run transform + transform = T.InverseMelScale( + n_stft, n_mels=n_mels, sample_rate=sample_rate).to(self.device, self.dtype) + torch.random.manual_seed(0) + result = transform(input) + + # Compare + epsilon = 1e-60 + relative_diff = torch.abs((result - expected) / (expected + epsilon)) + + for tol in [1e-1, 1e-3, 1e-5, 1e-10]: + print( + f"Ratio of relative diff smaller than {tol:e} is " + f"{_get_ratio(relative_diff < tol)}") + assert _get_ratio(relative_diff < 1e-1) > 0.2 + assert _get_ratio(relative_diff < 1e-3) > 5e-3 + assert _get_ratio(relative_diff < 1e-5) > 1e-5 + + @nested_params( + ["sinc_interpolation", "kaiser_window"], + [16000, 44100], + ) + def test_resample_identity(self, resampling_method, sample_rate): + """When sampling rate is not changed, the transform returns an identical Tensor""" + waveform = get_whitenoise(sample_rate=sample_rate, duration=1) + + resampler = T.Resample(sample_rate, sample_rate, resampling_method) + resampled = resampler(waveform) + self.assertEqual(waveform, resampled) + + @nested_params( + ["sinc_interpolation", "kaiser_window"], + [None, torch.float64], + ) + def test_resample_cache_dtype(self, resampling_method, dtype): + """Providing dtype changes the kernel cache dtype""" + transform = T.Resample(16000, 44100, resampling_method, dtype=dtype) + + assert transform.kernel.dtype == dtype if dtype is not None else torch.float32 + + @parameterized.expand([ + param(n_fft=300, center=True, onesided=True), + param(n_fft=400, center=True, onesided=False), + param(n_fft=400, center=True, onesided=False), + param(n_fft=300, center=True, onesided=False), + param(n_fft=400, hop_length=10), + param(n_fft=800, win_length=400, hop_length=20), + param(n_fft=800, win_length=400, 
hop_length=20, normalized=True), + param(), + param(n_fft=400, pad=32), + # These tests do not work - cause runtime error + # See https://github.com/pytorch/pytorch/issues/62323 + # param(n_fft=400, center=False, onesided=True), + # param(n_fft=400, center=False, onesided=False), + ]) + def test_roundtrip_spectrogram(self, **args): + """Test the spectrogram + inverse spectrogram results in approximate identity.""" + + waveform = get_whitenoise(sample_rate=8000, duration=0.5, dtype=self.dtype) + + s = T.Spectrogram(**args, power=None) + inv_s = T.InverseSpectrogram(**args) + transformed = s.forward(waveform) + restored = inv_s.forward(transformed, length=waveform.shape[-1]) + self.assertEqual(waveform, restored, atol=1e-6, rtol=1e-6) + + @parameterized.expand([ + param(0.5, 1, True, False), + param(0.5, 1, None, False), + param(1, 4, True, True), + param(1, 6, None, True), + ]) + def test_psd(self, duration, channel, mask, multi_mask): + """Providing dtype changes the kernel cache dtype""" + transform = T.PSD(multi_mask) + waveform = get_whitenoise(sample_rate=8000, duration=duration, n_channels=channel) + spectrogram = get_spectrogram(waveform, n_fft=400) # (channel, freq, time) + spectrogram = spectrogram.to(torch.cdouble) + if mask is not None: + if multi_mask: + mask = torch.rand(spectrogram.shape[-3:]) + else: + mask = torch.rand(spectrogram.shape[-2:]) + psd_np = psd_numpy(spectrogram.detach().numpy(), mask.detach().numpy(), multi_mask) + else: + psd_np = psd_numpy(spectrogram.detach().numpy(), mask, multi_mask) + psd = transform(spectrogram, mask) + self.assertEqual(psd, psd_np, atol=1e-5, rtol=1e-5) diff --git a/test/torchaudio_unittest/utils/__init__.py b/test/torchaudio_unittest/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/torchaudio_unittest/utils/sox_utils_test.py b/test/torchaudio_unittest/utils/sox_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7bd901843764f5184ef455821729cdcfde2c5a9a --- /dev/null +++ b/test/torchaudio_unittest/utils/sox_utils_test.py @@ -0,0 +1,49 @@ +from torchaudio.utils import sox_utils + +from torchaudio_unittest.common_utils import ( + PytorchTestCase, + skipIfNoSox, +) + + +@skipIfNoSox +class TestSoxUtils(PytorchTestCase): + """Smoke tests for sox_util module""" + def test_set_seed(self): + """`set_seed` does not crush""" + sox_utils.set_seed(0) + + def test_set_verbosity(self): + """`set_verbosity` does not crush""" + for val in range(6, 0, -1): + sox_utils.set_verbosity(val) + + def test_set_buffer_size(self): + """`set_buffer_size` does not crush""" + sox_utils.set_buffer_size(131072) + # back to default + sox_utils.set_buffer_size(8192) + + def test_set_use_threads(self): + """`set_use_threads` does not crush""" + sox_utils.set_use_threads(True) + # back to default + sox_utils.set_use_threads(False) + + def test_list_effects(self): + """`list_effects` returns the list of available effects""" + effects = sox_utils.list_effects() + # We cannot infer what effects are available, so only check some of them. 
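+        # 'highpass', 'phaser' and 'gain' are long-standing core SoX effects, so this
+        # spot check should hold for any SoX version torchaudio links against.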
+ assert 'highpass' in effects + assert 'phaser' in effects + assert 'gain' in effects + + def test_list_read_formats(self): + """`list_read_formats` returns the list of supported formats""" + formats = sox_utils.list_read_formats() + assert 'wav' in formats + + def test_list_write_formats(self): + """`list_write_formats` returns the list of supported formats""" + formats = sox_utils.list_write_formats() + assert 'opus' not in formats diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ef984d57b618ec6516b1d9d3f13af0be5697bf8a --- /dev/null +++ b/third_party/CMakeLists.txt @@ -0,0 +1,24 @@ +set(TORCHAUDIO_THIRD_PARTIES "") + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden") + +################################################################################ +# sox +################################################################################ +add_library(libsox INTERFACE) +if (BUILD_SOX) + add_subdirectory(sox) + target_include_directories(libsox INTERFACE ${SOX_INCLUDE_DIR}) + target_link_libraries(libsox INTERFACE ${SOX_LIBRARIES}) + list(APPEND TORCHAUDIO_THIRD_PARTIES libsox) +endif() + +################################################################################ +# kaldi +################################################################################ +if (BUILD_KALDI) + add_subdirectory(kaldi) + list(APPEND TORCHAUDIO_THIRD_PARTIES kaldi) +endif() + +set_property(GLOBAL PROPERTY TORCHAUDIO_THIRD_PARTIES "${TORCHAUDIO_THIRD_PARTIES}") diff --git a/third_party/kaldi/CMakeLists.txt b/third_party/kaldi/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a121c7b87d02b46cc3cbef6fe6c45e0fbaa84fa8 --- /dev/null +++ b/third_party/kaldi/CMakeLists.txt @@ -0,0 +1,32 @@ +set(KALDI_REPO ${CMAKE_CURRENT_SOURCE_DIR}/submodule) + +if (NOT EXISTS ${KALDI_REPO}/src/base/version.h) +# Apply custom patch +execute_process( + WORKING_DIRECTORY ${KALDI_REPO} + COMMAND "git" "checkout" "." + ) +execute_process( + WORKING_DIRECTORY ${KALDI_REPO} + COMMAND git apply ../kaldi.patch + ) +# Update the version string +execute_process( + WORKING_DIRECTORY ${KALDI_REPO}/src/base + COMMAND sh get_version.sh + ) +endif() + +set(KALDI_SOURCES + src/matrix/kaldi-vector.cc + src/matrix/kaldi-matrix.cc + submodule/src/base/kaldi-error.cc + submodule/src/base/kaldi-math.cc + submodule/src/feat/feature-functions.cc + submodule/src/feat/pitch-functions.cc + submodule/src/feat/resample.cc + ) + +add_library(kaldi STATIC ${KALDI_SOURCES}) +target_include_directories(kaldi PUBLIC src submodule/src) +target_include_directories(kaldi PRIVATE ${TORCH_INCLUDE_DIRS}) diff --git a/third_party/kaldi/README.md b/third_party/kaldi/README.md new file mode 100644 index 0000000000000000000000000000000000000000..58c48747a8e7666a39437a7d38e8d7123dec83c9 --- /dev/null +++ b/third_party/kaldi/README.md @@ -0,0 +1,6 @@ +# Custom Kaldi build + +This directory contains original Kaldi repository (as submodule), [the custom implementation of Kaldi's vector/matrix](./src) and the build script. + +We use the custom build process so that the resulting library only contains what torchaudio needs. +We use the custom vector/matrix implementation so that we can use the same BLAS library that PyTorch is compiled with, and so that we can (hopefully, in future) take advantage of other PyTorch features (such as differentiability and GPU support). 
The down side of this approach is that it adds a lot of overhead compared to the original Kaldi (operator dispatch and element-wise processing, which PyTorch is not efficient at). We can improve this gradually, and if you are interested in helping, please let us know by opening an issue. \ No newline at end of file diff --git a/third_party/kaldi/kaldi.patch b/third_party/kaldi/kaldi.patch new file mode 100644 index 0000000000000000000000000000000000000000..40667bced881e4a2a57048822467d99cf1284e6c --- /dev/null +++ b/third_party/kaldi/kaldi.patch @@ -0,0 +1,76 @@ +diff --git a/src/base/kaldi-types.h b/src/base/kaldi-types.h +index 7ebf4f853..c15b288b2 100644 +--- a/src/base/kaldi-types.h ++++ b/src/base/kaldi-types.h +@@ -41,6 +41,7 @@ typedef float BaseFloat; + + // for discussion on what to do if you need compile kaldi + // without OpenFST, see the bottom of this this file ++/* + #include + + namespace kaldi { +@@ -53,10 +54,10 @@ namespace kaldi { + typedef float float32; + typedef double double64; + } // end namespace kaldi ++*/ + + // In a theoretical case you decide compile Kaldi without the OpenFST + // comment the previous namespace statement and uncomment the following +-/* + namespace kaldi { + typedef int8_t int8; + typedef int16_t int16; +@@ -70,6 +71,5 @@ namespace kaldi { + typedef float float32; + typedef double double64; + } // end namespace kaldi +-*/ + + #endif // KALDI_BASE_KALDI_TYPES_H_ +diff --git a/src/matrix/matrix-lib.h b/src/matrix/matrix-lib.h +index b6059b06c..4fb9e1b16 100644 +--- a/src/matrix/matrix-lib.h ++++ b/src/matrix/matrix-lib.h +@@ -25,14 +25,14 @@ + #include "base/kaldi-common.h" + #include "matrix/kaldi-vector.h" + #include "matrix/kaldi-matrix.h" +-#include "matrix/sp-matrix.h" +-#include "matrix/tp-matrix.h" ++// #include "matrix/sp-matrix.h" ++// #include "matrix/tp-matrix.h" + #include "matrix/matrix-functions.h" + #include "matrix/srfft.h" + #include "matrix/compressed-matrix.h" +-#include "matrix/sparse-matrix.h" ++// #include "matrix/sparse-matrix.h" + #include "matrix/optimization.h" +-#include "matrix/numpy-array.h" ++// #include "matrix/numpy-array.h" + + #endif + +diff --git a/src/util/common-utils.h b/src/util/common-utils.h +index cfb0c255c..48d199e97 100644 +--- a/src/util/common-utils.h ++++ b/src/util/common-utils.h +@@ -21,11 +21,11 @@ + + #include "base/kaldi-common.h" + #include "util/parse-options.h" +-#include "util/kaldi-io.h" +-#include "util/simple-io-funcs.h" +-#include "util/kaldi-holder.h" +-#include "util/kaldi-table.h" +-#include "util/table-types.h" +-#include "util/text-utils.h" ++// #include "util/kaldi-io.h" ++// #include "util/simple-io-funcs.h" ++// #include "util/kaldi-holder.h" ++// #include "util/kaldi-table.h" ++// #include "util/table-types.h" ++// #include "util/text-utils.h" + + #endif // KALDI_UTIL_COMMON_UTILS_H_ diff --git a/third_party/kaldi/src/matrix/kaldi-matrix.cc b/third_party/kaldi/src/matrix/kaldi-matrix.cc new file mode 100644 index 0000000000000000000000000000000000000000..a89c3809c9137db3a21c15b8147bb369fbf3d5e6 --- /dev/null +++ b/third_party/kaldi/src/matrix/kaldi-matrix.cc @@ -0,0 +1,39 @@ +#include "matrix/kaldi-matrix.h" +#include + +namespace { + +template +void assert_matrix_shape(const torch::Tensor& tensor_); + +template <> +void assert_matrix_shape(const torch::Tensor& tensor_) { + TORCH_INTERNAL_ASSERT(tensor_.ndimension() == 2); + TORCH_INTERNAL_ASSERT(tensor_.dtype() == torch::kFloat32); + TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU."); +} + +template <> 
+void assert_matrix_shape(const torch::Tensor& tensor_) { + TORCH_INTERNAL_ASSERT(tensor_.ndimension() == 2); + TORCH_INTERNAL_ASSERT(tensor_.dtype() == torch::kFloat64); + TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU."); +} + +} // namespace + +namespace kaldi { + +template +MatrixBase::MatrixBase(torch::Tensor tensor) : tensor_(tensor) { + assert_matrix_shape(tensor_); +}; + +template class Matrix; +template class Matrix; +template class MatrixBase; +template class MatrixBase; +template class SubMatrix; +template class SubMatrix; + +} // namespace kaldi diff --git a/third_party/kaldi/src/matrix/kaldi-matrix.h b/third_party/kaldi/src/matrix/kaldi-matrix.h new file mode 100644 index 0000000000000000000000000000000000000000..f64828b84f4e96ed2820802e0dc68fc8a9389a5c --- /dev/null +++ b/third_party/kaldi/src/matrix/kaldi-matrix.h @@ -0,0 +1,178 @@ +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h + +#ifndef KALDI_MATRIX_KALDI_MATRIX_H_ +#define KALDI_MATRIX_KALDI_MATRIX_H_ + +#include +#include "matrix/kaldi-vector.h" +#include "matrix/matrix-common.h" + +using namespace torch::indexing; + +namespace kaldi { + +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L44-L48 +template +class MatrixBase { + public: + //////////////////////////////////////////////////////////////////////////////// + // PyTorch-specific items + //////////////////////////////////////////////////////////////////////////////// + torch::Tensor tensor_; + /// Construct VectorBase which is an interface to an existing torch::Tensor + /// object. + MatrixBase(torch::Tensor tensor); + + //////////////////////////////////////////////////////////////////////////////// + // Kaldi-compatible items + //////////////////////////////////////////////////////////////////////////////// + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L62-L63 + inline MatrixIndexT NumRows() const { + return tensor_.size(0); + }; + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L65-L66 + inline MatrixIndexT NumCols() const { + return tensor_.size(1); + }; + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L177-L178 + void CopyColFromVec(const VectorBase& v, const MatrixIndexT col) { + tensor_.index_put_({Slice(), col}, v.tensor_); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L99-L107 + inline Real& operator()(MatrixIndexT r, MatrixIndexT c) { + // CPU only + return tensor_.accessor()[r][c]; + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L112-L120 + inline const Real operator()(MatrixIndexT r, MatrixIndexT c) const { + return tensor_.index({Slice(r), Slice(c)}).item().template to(); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L138-L141 + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.cc#L859-L898 + template + void CopyFromMat( + const MatrixBase& M, + MatrixTransposeType trans = kNoTrans) { + auto src = M.tensor_; + if (trans == kTrans) + src = src.transpose(1, 0); + tensor_.index_put_({Slice(), Slice()}, src); + } + + // 
https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L186-L191 + inline const SubVector Row(MatrixIndexT i) const { + return SubVector(*this, i); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L208-L211 + inline SubMatrix RowRange( + const MatrixIndexT row_offset, + const MatrixIndexT num_rows) const { + return SubMatrix(*this, row_offset, num_rows, 0, NumCols()); + } + + protected: + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L749-L753 + explicit MatrixBase() : tensor_(torch::empty({0, 0})) { + KALDI_ASSERT_IS_FLOATING_TYPE(Real); + } +}; + +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L781-L784 +template +class Matrix : public MatrixBase { + public: + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L786-L787 + Matrix() : MatrixBase() {} + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L789-L793 + Matrix( + const MatrixIndexT r, + const MatrixIndexT c, + MatrixResizeType resize_type = kSetZero, + MatrixStrideType stride_type = kDefaultStride) + : MatrixBase() { + Resize(r, c, resize_type, stride_type); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L808-L811 + explicit Matrix( + const MatrixBase& M, + MatrixTransposeType trans = kNoTrans) + : MatrixBase( + trans == kNoTrans ? M.tensor_ : M.tensor_.transpose(1, 0)) {} + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L816-L819 + template + explicit Matrix( + const MatrixBase& M, + MatrixTransposeType trans = kNoTrans) + : MatrixBase( + trans == kNoTrans ? M.tensor_ : M.tensor_.transpose(1, 0)) {} + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L859-L874 + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.cc#L817-L857 + void Resize( + const MatrixIndexT r, + const MatrixIndexT c, + MatrixResizeType resize_type = kSetZero, + MatrixStrideType stride_type = kDefaultStride) { + auto& tensor_ = MatrixBase::tensor_; + switch (resize_type) { + case kSetZero: + tensor_.resize_({r, c}).zero_(); + break; + case kUndefined: + tensor_.resize_({r, c}); + break; + case kCopyData: + auto tmp = tensor_; + auto tmp_rows = tmp.size(0); + auto tmp_cols = tmp.size(1); + tensor_.resize_({r, c}).zero_(); + auto rows = Slice(None, r < tmp_rows ? r : tmp_rows); + auto cols = Slice(None, c < tmp_cols ? 
+        tensor_.index_put_({rows, cols}, tmp.index({rows, cols}));
+        break;
+    }
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L876-L883
+  Matrix& operator=(const MatrixBase<Real>& other) {
+    if (MatrixBase<Real>::NumRows() != other.NumRows() ||
+        MatrixBase<Real>::NumCols() != other.NumCols())
+      Resize(other.NumRows(), other.NumCols(), kUndefined);
+    MatrixBase<Real>::CopyFromMat(other);
+    return *this;
+  }
+};
+
+// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L940-L948
+template <typename Real>
+class SubMatrix : public MatrixBase<Real> {
+ public:
+  SubMatrix(
+      const MatrixBase<Real>& T,
+      const MatrixIndexT ro, // row offset, 0 < ro < NumRows()
+      const MatrixIndexT r, // number of rows, r > 0
+      const MatrixIndexT co, // column offset, 0 < co < NumCols()
+      const MatrixIndexT c) // number of columns, c > 0
+      : MatrixBase<Real>(
+            T.tensor_.index({Slice(ro, ro + r), Slice(co, co + c)})) {}
+};
+
+// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L1059-L1060
+template <typename Real>
+std::ostream& operator<<(std::ostream& Out, const MatrixBase<Real>& M) {
+  Out << M.tensor_;
+  return Out;
+}
+
+} // namespace kaldi
+
+#endif
diff --git a/third_party/kaldi/src/matrix/kaldi-vector.cc b/third_party/kaldi/src/matrix/kaldi-vector.cc
new file mode 100644
index 0000000000000000000000000000000000000000..df59f13a369f52813ac19556ab1948bc70487365
--- /dev/null
+++ b/third_party/kaldi/src/matrix/kaldi-vector.cc
@@ -0,0 +1,42 @@
+#include "matrix/kaldi-vector.h"
+#include <torch/torch.h>
+#include "matrix/kaldi-matrix.h"
+
+namespace {
+
+template <typename Real>
+void assert_vector_shape(const torch::Tensor& tensor_);
+
+template <>
+void assert_vector_shape<float>(const torch::Tensor& tensor_) {
+  TORCH_INTERNAL_ASSERT(tensor_.ndimension() == 1);
+  TORCH_INTERNAL_ASSERT(tensor_.dtype() == torch::kFloat32);
+  TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU.");
+}
+
+template <>
+void assert_vector_shape<double>(const torch::Tensor& tensor_) {
+  TORCH_INTERNAL_ASSERT(tensor_.ndimension() == 1);
+  TORCH_INTERNAL_ASSERT(tensor_.dtype() == torch::kFloat64);
+  TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU.");
+}
+
+} // namespace
+
+namespace kaldi {
+
+template <typename Real>
+VectorBase<Real>::VectorBase(torch::Tensor tensor)
+    : tensor_(tensor), data_(tensor.data_ptr<Real>()) {
+  assert_vector_shape<Real>(tensor_);
+};
+
+template <typename Real>
+VectorBase<Real>::VectorBase() : VectorBase(torch::empty({0})) {}
+
+template class Vector<float>;
+template class Vector<double>;
+template class VectorBase<float>;
+template class VectorBase<double>;
+
+} // namespace kaldi
diff --git a/third_party/kaldi/src/matrix/kaldi-vector.h b/third_party/kaldi/src/matrix/kaldi-vector.h
new file mode 100644
index 0000000000000000000000000000000000000000..ea11ca4ddde2a6038e62f92be34844681f8e418b
--- /dev/null
+++ b/third_party/kaldi/src/matrix/kaldi-vector.h
@@ -0,0 +1,319 @@
+// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h
+
+#ifndef KALDI_MATRIX_KALDI_VECTOR_H_
+#define KALDI_MATRIX_KALDI_VECTOR_H_
+
+#include <torch/torch.h>
+#include "matrix/matrix-common.h"
+
+using namespace torch::indexing;
+
+namespace kaldi {
+
+// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L36-L40
+template <typename Real>
+class VectorBase {
+ public:
+  ////////////////////////////////////////////////////////////////////////////////
+  // PyTorch-specific things
+  ////////////////////////////////////////////////////////////////////////////////
+  torch::Tensor tensor_;
+
+  /// Construct VectorBase which is an interface to an existing torch::Tensor
+  /// object.
+  VectorBase(torch::Tensor tensor);
+
+  ////////////////////////////////////////////////////////////////////////////////
+  // Kaldi-compatible methods
+  ////////////////////////////////////////////////////////////////////////////////
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L42-L43
+  void SetZero() {
+    Set(0);
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L48-L49
+  void Set(Real f) {
+    tensor_.index_put_({"..."}, f);
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L62-L63
+  inline MatrixIndexT Dim() const {
+    return tensor_.numel();
+  };
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L68-L69
+  inline Real* Data() {
+    return data_;
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L71-L72
+  inline const Real* Data() const {
+    return data_;
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L74-L79
+  inline Real operator()(MatrixIndexT i) const {
+    return data_[i];
+  };
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L81-L86
+  inline Real& operator()(MatrixIndexT i) {
+    return tensor_.accessor<Real, 1>()[i];
+  };
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L88-L95
+  SubVector<Real> Range(const MatrixIndexT o, const MatrixIndexT l) {
+    return SubVector<Real>(*this, o, l);
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L97-L105
+  const SubVector<Real> Range(const MatrixIndexT o, const MatrixIndexT l)
+      const {
+    return SubVector<Real>(*this, o, l);
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L107-L108
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L226-L233
+  void CopyFromVec(const VectorBase<Real>& v) {
+    TORCH_INTERNAL_ASSERT(tensor_.sizes() == v.tensor_.sizes());
+    tensor_.copy_(v.tensor_);
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L137-L139
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L816-L832
+  void ApplyFloor(Real floor_val, MatrixIndexT* floored_count = nullptr) {
+    auto index = tensor_ < floor_val;
+    auto tmp = tensor_.index_put_({index}, floor_val);
+    if (floored_count) {
+      *floored_count = index.sum().item().template to<MatrixIndexT>();
+    }
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L164-L165
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L449-L479
+  void ApplyPow(Real power) {
+    tensor_.pow_(power);
+    TORCH_INTERNAL_ASSERT(!tensor_.isnan().sum().item().template to<bool>());
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L181-L184
+  template <typename OtherReal>
+  void AddVec(const Real alpha, const VectorBase<OtherReal>& v) {
+    tensor_ += alpha * v.tensor_;
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L186-L187
+  void AddVec2(const Real alpha, const VectorBase<Real>& v) {
+    tensor_ += alpha * (v.tensor_.square());
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L196-L198
+  void AddMatVec(
+      const Real alpha,
+      const MatrixBase<Real>& M,
+      const MatrixTransposeType trans,
+      const VectorBase<Real>& v,
+      const Real beta) { // **beta previously defaulted to 0.0**
+    auto mat = M.tensor_;
+    if (trans == kTrans) {
+      mat = mat.transpose(1, 0);
+    }
+    tensor_.addmv_(mat, v.tensor_, beta, alpha);
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L221-L222
+  void MulElements(const VectorBase<Real>& v) {
+    tensor_ *= v.tensor_;
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L233-L234
+  void Add(Real c) {
+    tensor_ += c;
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L236-L239
+  void AddVecVec(
+      Real alpha,
+      const VectorBase<Real>& v,
+      const VectorBase<Real>& r,
+      Real beta) {
+    tensor_ = beta * tensor_ + alpha * v.tensor_ * r.tensor_;
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L246-L247
+  void Scale(Real alpha) {
+    tensor_ *= alpha;
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L305-L306
+  Real Min() const {
+    if (tensor_.numel()) {
+      return tensor_.min().item().template to<Real>();
+    }
+    return std::numeric_limits<Real>::infinity();
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L308-L310
+  Real Min(MatrixIndexT* index) const {
+    TORCH_INTERNAL_ASSERT(tensor_.numel());
+    torch::Tensor value, ind;
+    std::tie(value, ind) = tensor_.min(0);
+    *index = ind.item().to<MatrixIndexT>();
+    return value.item().to<Real>();
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L312-L313
+  Real Sum() const {
+    return tensor_.sum().item().template to<Real>();
+  };
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L320-L321
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L718-L736
+  void AddRowSumMat(Real alpha, const MatrixBase<Real>& M, Real beta = 1.0) {
+    Vector<Real> ones(M.NumRows());
+    ones.Set(1.0);
+    this->AddMatVec(alpha, M, kTrans, ones, beta);
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L323-L324
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L738-L757
+  void AddColSumMat(Real alpha, const MatrixBase<Real>& M, Real beta = 1.0) {
+    Vector<Real> ones(M.NumCols());
+    ones.Set(1.0);
+    this->AddMatVec(alpha, M, kNoTrans, ones, beta);
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L326-L330
+  void AddDiagMat2(
+      Real alpha,
+      const MatrixBase<Real>& M,
+      MatrixTransposeType trans = kNoTrans,
+      Real beta = 1.0) {
+    auto mat = M.tensor_;
+    if (trans == kNoTrans) {
+      tensor_ =
+          beta * tensor_ + torch::diag(torch::mm(mat, mat.transpose(1, 0)));
+    } else {
+      tensor_ =
+          beta * tensor_ + torch::diag(torch::mm(mat.transpose(1, 0), mat));
+    }
+  }
+
+ protected:
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L362-L365
+  explicit VectorBase();
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L378-L379
+  Real* data_;
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L382
+  KALDI_DISALLOW_COPY_AND_ASSIGN(VectorBase);
+};
+
+// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L385-L390
+template <typename Real>
+class Vector : public VectorBase<Real> {
+ public:
+  ////////////////////////////////////////////////////////////////////////////////
+  // PyTorch-compatibility things
+  ////////////////////////////////////////////////////////////////////////////////
+  /// Construct Vector which is an interface to an existing torch::Tensor
+  /// object.
+  Vector(torch::Tensor tensor) : VectorBase<Real>(tensor) {};
+
+  ////////////////////////////////////////////////////////////////////////////////
+  // Kaldi-compatible methods
+  ////////////////////////////////////////////////////////////////////////////////
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L392-L393
+  Vector() : VectorBase<Real>() {};
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L395-L399
+  explicit Vector(const MatrixIndexT s, MatrixResizeType resize_type = kSetZero)
+      : VectorBase<Real>() {
+    Resize(s, resize_type);
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L406-L410
+  // Note: unlike the original implementation, this is "explicit".
+  explicit Vector(const Vector<Real>& v)
+      : VectorBase<Real>(v.tensor_.clone()) {}
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L412-L416
+  explicit Vector(const VectorBase<Real>& v)
+      : VectorBase<Real>(v.tensor_.clone()) {}
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L434-L435
+  void Swap(Vector<Real>* other) {
+    auto tmp = VectorBase<Real>::tensor_;
+    this->tensor_ = other->tensor_;
+    other->tensor_ = tmp;
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L444-L451
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L189-L223
+  void Resize(MatrixIndexT length, MatrixResizeType resize_type = kSetZero) {
+    auto& tensor_ = this->tensor_;
+    switch (resize_type) {
+      case kSetZero:
+        tensor_.resize_({length}).zero_();
+        break;
+      case kUndefined:
+        tensor_.resize_({length});
+        break;
+      case kCopyData:
+        auto tmp = tensor_;
+        auto tmp_numel = tensor_.numel();
+        tensor_.resize_({length}).zero_();
+        auto numel = Slice(length < tmp_numel ? length : tmp_numel);
+        tensor_.index_put_({numel}, tmp.index({numel}));
+        break;
+    }
+    // data_ptr<Real>() causes compiler error
+    this->data_ = static_cast<Real*>(tensor_.data_ptr());
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L463-L468
+  Vector<Real>& operator=(const VectorBase<Real>& other) {
+    Resize(other.Dim(), kUndefined);
+    this->CopyFromVec(other);
+    return *this;
+  }
+};
+
+// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L482-L485
+template <typename Real>
+class SubVector : public VectorBase<Real> {
+ public:
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L487-L499
+  SubVector(
+      const VectorBase<Real>& t,
+      const MatrixIndexT origin,
+      const MatrixIndexT length)
+      : VectorBase<Real>(t.tensor_.index({Slice(origin, origin + length)})) {}
+
+  SubVector(kaldi::SubVector<Real>&& v) : VectorBase<Real>(v.tensor_) {
+    this->data_ = v.data_;
+    // v.tensor_ =
+    v.data_ = NULL;
+  }
+
+  // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L524-L528
+  SubVector(const MatrixBase<Real>& matrix, MatrixIndexT row)
+      : VectorBase<Real>(matrix.tensor_.index({row})) {}
+};
+
+// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L540-L543
+template <typename Real>
+std::ostream& operator<<(std::ostream& out, const VectorBase<Real>& v) {
+  out << v.tensor_;
+  return out;
+}
+
+// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L573-L575
+template <typename Real>
+Real VecVec(const VectorBase<Real>& v1, const VectorBase<Real>& v2) {
+  return torch::dot(v1.tensor_, v2.tensor_).item().template to<Real>();
+}
+
+} // namespace kaldi
+
+#endif
diff --git a/third_party/sox/CMakeLists.txt b/third_party/sox/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..124d605e8e00ca18d62a3c5d12a9c26ac7e56940
--- /dev/null
+++ b/third_party/sox/CMakeLists.txt
@@ -0,0 +1,216 @@
+include(ExternalProject)
+
+set(INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../install)
+set(ARCHIVE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/archives)
+set(COMMON_ARGS --quiet --disable-shared --enable-static --prefix=${INSTALL_DIR} --with-pic --disable-dependency-tracking --disable-debug --disable-examples --disable-doc)
+
+# To pass custom environment variables to ExternalProject_Add command,
+# we need to do `${CMAKE_COMMAND} -E env ${envs} <command>`.
+# https://stackoverflow.com/a/62437353 +# We constrcut the custom environment variables here +set(envs + "PKG_CONFIG_PATH=${INSTALL_DIR}/lib/pkgconfig" + "LDFLAGS=-L${INSTALL_DIR}/lib $ENV{LDFLAGS}" + "CFLAGS=-I${INSTALL_DIR}/include -fvisibility=hidden $ENV{CFLAGS}" +) + +ExternalProject_Add(mad + PREFIX ${CMAKE_CURRENT_BINARY_DIR} + DOWNLOAD_DIR ${ARCHIVE_DIR} + URL https://downloads.sourceforge.net/project/mad/libmad/0.15.1b/libmad-0.15.1b.tar.gz + URL_HASH SHA256=bbfac3ed6bfbc2823d3775ebb931087371e142bb0e9bb1bee51a76a6e0078690 + PATCH_COMMAND patch < ${CMAKE_CURRENT_SOURCE_DIR}/patch/libmad.patch && cp ${CMAKE_CURRENT_SOURCE_DIR}/patch/config.guess ${CMAKE_CURRENT_BINARY_DIR}/src/mad/config.guess && cp ${CMAKE_CURRENT_SOURCE_DIR}/patch/config.sub ${CMAKE_CURRENT_BINARY_DIR}/src/mad/config.sub + CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${envs} ${CMAKE_CURRENT_BINARY_DIR}/src/mad/configure ${COMMON_ARGS} + DOWNLOAD_NO_PROGRESS ON + LOG_DOWNLOAD ON + LOG_UPDATE ON + LOG_CONFIGURE ON + LOG_BUILD ON + LOG_INSTALL ON + LOG_MERGED_STDOUTERR ON + LOG_OUTPUT_ON_FAILURE ON +) + +ExternalProject_Add(amr + PREFIX ${CMAKE_CURRENT_BINARY_DIR} + DOWNLOAD_DIR ${ARCHIVE_DIR} + URL https://sourceforge.net/projects/opencore-amr/files/opencore-amr/opencore-amr-0.1.5.tar.gz + URL_HASH SHA256=2c006cb9d5f651bfb5e60156dbff6af3c9d35c7bbcc9015308c0aff1e14cd341 + CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${envs} ${CMAKE_CURRENT_BINARY_DIR}/src/amr/configure ${COMMON_ARGS} + DOWNLOAD_NO_PROGRESS ON + LOG_DOWNLOAD ON + LOG_UPDATE ON + LOG_CONFIGURE ON + LOG_BUILD ON + LOG_INSTALL ON + LOG_MERGED_STDOUTERR ON + LOG_OUTPUT_ON_FAILURE ON +) + +ExternalProject_Add(lame + PREFIX ${CMAKE_CURRENT_BINARY_DIR} + DOWNLOAD_DIR ${ARCHIVE_DIR} + URL https://downloads.sourceforge.net/project/lame/lame/3.99/lame-3.99.5.tar.gz + URL_HASH SHA256=24346b4158e4af3bd9f2e194bb23eb473c75fb7377011523353196b19b9a23ff + CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${envs} ${CMAKE_CURRENT_BINARY_DIR}/src/lame/configure ${COMMON_ARGS} --enable-nasm + PATCH_COMMAND cp ${CMAKE_CURRENT_SOURCE_DIR}/patch/config.guess ${CMAKE_CURRENT_BINARY_DIR}/src/lame/config.guess && cp ${CMAKE_CURRENT_SOURCE_DIR}/patch/config.sub ${CMAKE_CURRENT_BINARY_DIR}/src/lame/config.sub + DOWNLOAD_NO_PROGRESS ON + LOG_DOWNLOAD ON + LOG_UPDATE ON + LOG_CONFIGURE ON + LOG_BUILD ON + LOG_INSTALL ON + LOG_MERGED_STDOUTERR ON + LOG_OUTPUT_ON_FAILURE ON +) + +ExternalProject_Add(ogg + PREFIX ${CMAKE_CURRENT_BINARY_DIR} + DOWNLOAD_DIR ${ARCHIVE_DIR} + URL https://ftp.osuosl.org/pub/xiph/releases/ogg/libogg-1.3.3.tar.gz + URL_HASH SHA256=c2e8a485110b97550f453226ec644ebac6cb29d1caef2902c007edab4308d985 + CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${envs} ${CMAKE_CURRENT_BINARY_DIR}/src/ogg/configure ${COMMON_ARGS} + DOWNLOAD_NO_PROGRESS ON + LOG_DOWNLOAD ON + LOG_UPDATE ON + LOG_CONFIGURE ON + LOG_BUILD ON + LOG_INSTALL ON + LOG_MERGED_STDOUTERR ON + LOG_OUTPUT_ON_FAILURE ON +) + +ExternalProject_Add(flac + PREFIX ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ogg + DOWNLOAD_DIR ${ARCHIVE_DIR} + URL https://ftp.osuosl.org/pub/xiph/releases/flac/flac-1.3.2.tar.xz + URL_HASH SHA256=91cfc3ed61dc40f47f050a109b08610667d73477af6ef36dcad31c31a4a8d53f + CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${envs} ${CMAKE_CURRENT_BINARY_DIR}/src/flac/configure ${COMMON_ARGS} --with-ogg --disable-cpplibs + DOWNLOAD_NO_PROGRESS ON + LOG_DOWNLOAD ON + LOG_UPDATE ON + LOG_CONFIGURE ON + LOG_BUILD ON + LOG_INSTALL ON + LOG_MERGED_STDOUTERR ON + LOG_OUTPUT_ON_FAILURE ON +) + +ExternalProject_Add(vorbis + PREFIX 
${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ogg + DOWNLOAD_DIR ${ARCHIVE_DIR} + URL https://ftp.osuosl.org/pub/xiph/releases/vorbis/libvorbis-1.3.6.tar.gz + URL_HASH SHA256=6ed40e0241089a42c48604dc00e362beee00036af2d8b3f46338031c9e0351cb + CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${envs} ${CMAKE_CURRENT_BINARY_DIR}/src/vorbis/configure ${COMMON_ARGS} --with-ogg + DOWNLOAD_NO_PROGRESS ON + LOG_DOWNLOAD ON + LOG_UPDATE ON + LOG_CONFIGURE ON + LOG_BUILD ON + LOG_INSTALL ON + LOG_MERGED_STDOUTERR ON + LOG_OUTPUT_ON_FAILURE ON +) + +ExternalProject_Add(opus + PREFIX ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ogg + DOWNLOAD_DIR ${ARCHIVE_DIR} + URL https://ftp.osuosl.org/pub/xiph/releases/opus/opus-1.3.1.tar.gz + URL_HASH SHA256=65b58e1e25b2a114157014736a3d9dfeaad8d41be1c8179866f144a2fb44ff9d + CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${envs} ${CMAKE_CURRENT_BINARY_DIR}/src/opus/configure ${COMMON_ARGS} --with-ogg + DOWNLOAD_NO_PROGRESS ON + LOG_DOWNLOAD ON + LOG_UPDATE ON + LOG_CONFIGURE ON + LOG_BUILD ON + LOG_INSTALL ON + LOG_MERGED_STDOUTERR ON + LOG_OUTPUT_ON_FAILURE ON +) + +ExternalProject_Add(opusfile + PREFIX ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS opus + DOWNLOAD_DIR ${ARCHIVE_DIR} + URL https://ftp.osuosl.org/pub/xiph/releases/opus/opusfile-0.12.tar.gz + URL_HASH SHA256=118d8601c12dd6a44f52423e68ca9083cc9f2bfe72da7a8c1acb22a80ae3550b + CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${envs} ${CMAKE_CURRENT_BINARY_DIR}/src/opusfile/configure ${COMMON_ARGS} --disable-http + DOWNLOAD_NO_PROGRESS ON + LOG_DOWNLOAD ON + LOG_UPDATE ON + LOG_CONFIGURE ON + LOG_BUILD ON + LOG_INSTALL ON + LOG_MERGED_STDOUTERR ON + LOG_OUTPUT_ON_FAILURE ON +) + +# OpenMP is by default compiled against GNU OpenMP, which conflicts with the version of OpenMP that PyTorch uses. +# See https://github.com/pytorch/audio/pull/1026 +# TODO: Add flags like https://github.com/suphoff/pytorch_parallel_extension_cpp/blob/master/setup.py +set(SOX_OPTIONS + --disable-openmp + --with-amrnb + --with-amrwb + --with-flac + --with-lame + --with-mad + --with-oggvorbis + --with-opus + --without-alsa + --without-ao + --without-coreaudio + --without-oss + --without-id3tag + --without-ladspa + --without-magic + --without-png + --without-pulseaudio + --without-sndfile + --without-sndio + --without-sunaudio + --without-waveaudio + --without-wavpack + --without-twolame + ) + +set(SOX_LIBRARIES + ${INSTALL_DIR}/lib/libsox.a + ${INSTALL_DIR}/lib/libopencore-amrnb.a + ${INSTALL_DIR}/lib/libopencore-amrwb.a + ${INSTALL_DIR}/lib/libmad.a + ${INSTALL_DIR}/lib/libmp3lame.a + ${INSTALL_DIR}/lib/libFLAC.a + ${INSTALL_DIR}/lib/libopusfile.a + ${INSTALL_DIR}/lib/libopus.a + ${INSTALL_DIR}/lib/libvorbisenc.a + ${INSTALL_DIR}/lib/libvorbisfile.a + ${INSTALL_DIR}/lib/libvorbis.a + ${INSTALL_DIR}/lib/libogg.a + ) + +ExternalProject_Add(sox + PREFIX ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ogg flac vorbis opusfile lame mad amr + DOWNLOAD_DIR ${ARCHIVE_DIR} + URL https://downloads.sourceforge.net/project/sox/sox/14.4.2/sox-14.4.2.tar.bz2 + URL_HASH SHA256=81a6956d4330e75b5827316e44ae381e6f1e8928003c6aa45896da9041ea149c + CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${envs} ${CMAKE_CURRENT_BINARY_DIR}/src/sox/configure ${COMMON_ARGS} ${SOX_OPTIONS} + PATCH_COMMAND patch -p1 < ${CMAKE_CURRENT_SOURCE_DIR}/patch/sox.patch + BUILD_BYPRODUCTS ${SOX_LIBRARIES} + DOWNLOAD_NO_PROGRESS ON + LOG_DOWNLOAD ON + LOG_UPDATE ON + LOG_CONFIGURE ON + LOG_BUILD ON + LOG_INSTALL ON + LOG_MERGED_STDOUTERR ON + LOG_OUTPUT_ON_FAILURE ON +) + +add_dependencies(libsox sox) +set(SOX_INCLUDE_DIR 
${INSTALL_DIR}/include PARENT_SCOPE) +set(SOX_LIBRARIES ${SOX_LIBRARIES} PARENT_SCOPE) diff --git a/third_party/sox/patch/config.guess b/third_party/sox/patch/config.guess new file mode 100644 index 0000000000000000000000000000000000000000..dc0a6b29976a9990ef412d3c5fd696ac2641ae0b --- /dev/null +++ b/third_party/sox/patch/config.guess @@ -0,0 +1,1702 @@ +#! /bin/sh +# Attempt to guess a canonical system name. +# Copyright 1992-2021 Free Software Foundation, Inc. + +timestamp='2021-05-24' + +# This file is free software; you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, see . +# +# As a special exception to the GNU General Public License, if you +# distribute this file as part of a program that contains a +# configuration script generated by Autoconf, you may include it under +# the same distribution terms that you use for the rest of that +# program. This Exception is an additional permission under section 7 +# of the GNU General Public License, version 3 ("GPLv3"). +# +# Originally written by Per Bothner; maintained since 2000 by Ben Elliston. +# +# You can get the latest version of this script from: +# https://git.savannah.gnu.org/cgit/config.git/plain/config.guess +# +# Please send patches to . + + +me=$(echo "$0" | sed -e 's,.*/,,') + +usage="\ +Usage: $0 [OPTION] + +Output the configuration name of the system \`$me' is run on. + +Options: + -h, --help print this help, then exit + -t, --time-stamp print date of last modification, then exit + -v, --version print version number, then exit + +Report bugs and patches to ." + +version="\ +GNU config.guess ($timestamp) + +Originally written by Per Bothner. +Copyright 1992-2021 Free Software Foundation, Inc. + +This is free software; see the source for copying conditions. There is NO +warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE." + +help=" +Try \`$me --help' for more information." + +# Parse command line +while test $# -gt 0 ; do + case $1 in + --time-stamp | --time* | -t ) + echo "$timestamp" ; exit ;; + --version | -v ) + echo "$version" ; exit ;; + --help | --h* | -h ) + echo "$usage"; exit ;; + -- ) # Stop option processing + shift; break ;; + - ) # Use stdin as input. + break ;; + -* ) + echo "$me: invalid option $1$help" >&2 + exit 1 ;; + * ) + break ;; + esac +done + +if test $# != 0; then + echo "$me: too many arguments$help" >&2 + exit 1 +fi + +# CC_FOR_BUILD -- compiler used by this script. Note that the use of a +# compiler to aid in system detection is discouraged as it requires +# temporary files to be created and, as you can see below, it is a +# headache to deal with in a portable fashion. + +# Historically, `CC_FOR_BUILD' used to be named `HOST_CC'. We still +# use `HOST_CC' if defined, but it is deprecated. + +# Portable tmp directory creation inspired by the Autoconf team. 
+ +tmp= +# shellcheck disable=SC2172 +trap 'test -z "$tmp" || rm -fr "$tmp"' 0 1 2 13 15 + +set_cc_for_build() { + # prevent multiple calls if $tmp is already set + test "$tmp" && return 0 + : "${TMPDIR=/tmp}" + # shellcheck disable=SC2039 + { tmp=$( (umask 077 && mktemp -d "$TMPDIR/cgXXXXXX") 2>/dev/null) && test -n "$tmp" && test -d "$tmp" ; } || + { test -n "$RANDOM" && tmp=$TMPDIR/cg$$-$RANDOM && (umask 077 && mkdir "$tmp" 2>/dev/null) ; } || + { tmp=$TMPDIR/cg-$$ && (umask 077 && mkdir "$tmp" 2>/dev/null) && echo "Warning: creating insecure temp directory" >&2 ; } || + { echo "$me: cannot create a temporary directory in $TMPDIR" >&2 ; exit 1 ; } + dummy=$tmp/dummy + case ${CC_FOR_BUILD-},${HOST_CC-},${CC-} in + ,,) echo "int x;" > "$dummy.c" + for driver in cc gcc c89 c99 ; do + if ($driver -c -o "$dummy.o" "$dummy.c") >/dev/null 2>&1 ; then + CC_FOR_BUILD="$driver" + break + fi + done + if test x"$CC_FOR_BUILD" = x ; then + CC_FOR_BUILD=no_compiler_found + fi + ;; + ,,*) CC_FOR_BUILD=$CC ;; + ,*,*) CC_FOR_BUILD=$HOST_CC ;; + esac +} + +# This is needed to find uname on a Pyramid OSx when run in the BSD universe. +# (ghazi@noc.rutgers.edu 1994-08-24) +if test -f /.attbin/uname ; then + PATH=$PATH:/.attbin ; export PATH +fi + +UNAME_MACHINE=$( (uname -m) 2>/dev/null) || UNAME_MACHINE=unknown +UNAME_RELEASE=$( (uname -r) 2>/dev/null) || UNAME_RELEASE=unknown +UNAME_SYSTEM=$( (uname -s) 2>/dev/null) || UNAME_SYSTEM=unknown +UNAME_VERSION=$( (uname -v) 2>/dev/null) || UNAME_VERSION=unknown + +case $UNAME_SYSTEM in +Linux|GNU|GNU/*) + LIBC=unknown + + set_cc_for_build + cat <<-EOF > "$dummy.c" + #include + #if defined(__UCLIBC__) + LIBC=uclibc + #elif defined(__dietlibc__) + LIBC=dietlibc + #elif defined(__GLIBC__) + LIBC=gnu + #else + #include + /* First heuristic to detect musl libc. */ + #ifdef __DEFINED_va_list + LIBC=musl + #endif + #endif + EOF + eval "$($CC_FOR_BUILD -E "$dummy.c" 2>/dev/null | grep '^LIBC' | sed 's, ,,g')" + + # Second heuristic to detect musl libc. + if [ "$LIBC" = unknown ] && + command -v ldd >/dev/null && + ldd --version 2>&1 | grep -q ^musl; then + LIBC=musl + fi + + # If the system lacks a compiler, then just pick glibc. + # We could probably try harder. + if [ "$LIBC" = unknown ]; then + LIBC=gnu + fi + ;; +esac + +# Note: order is significant - the case branches are not exclusive. + +case $UNAME_MACHINE:$UNAME_SYSTEM:$UNAME_RELEASE:$UNAME_VERSION in + *:NetBSD:*:*) + # NetBSD (nbsd) targets should (where applicable) match one or + # more of the tuples: *-*-netbsdelf*, *-*-netbsdaout*, + # *-*-netbsdecoff* and *-*-netbsd*. For targets that recently + # switched to ELF, *-*-netbsd* would select the old + # object file format. This provides both forward + # compatibility and a consistent mechanism for selecting the + # object file format. + # + # Note: NetBSD doesn't particularly care about the vendor + # portion of the name. We always set it to "unknown". 
+ UNAME_MACHINE_ARCH=$( (uname -p 2>/dev/null || \ + /sbin/sysctl -n hw.machine_arch 2>/dev/null || \ + /usr/sbin/sysctl -n hw.machine_arch 2>/dev/null || \ + echo unknown)) + case $UNAME_MACHINE_ARCH in + aarch64eb) machine=aarch64_be-unknown ;; + armeb) machine=armeb-unknown ;; + arm*) machine=arm-unknown ;; + sh3el) machine=shl-unknown ;; + sh3eb) machine=sh-unknown ;; + sh5el) machine=sh5le-unknown ;; + earmv*) + arch=$(echo "$UNAME_MACHINE_ARCH" | sed -e 's,^e\(armv[0-9]\).*$,\1,') + endian=$(echo "$UNAME_MACHINE_ARCH" | sed -ne 's,^.*\(eb\)$,\1,p') + machine="${arch}${endian}"-unknown + ;; + *) machine="$UNAME_MACHINE_ARCH"-unknown ;; + esac + # The Operating System including object format, if it has switched + # to ELF recently (or will in the future) and ABI. + case $UNAME_MACHINE_ARCH in + earm*) + os=netbsdelf + ;; + arm*|i386|m68k|ns32k|sh3*|sparc|vax) + set_cc_for_build + if echo __ELF__ | $CC_FOR_BUILD -E - 2>/dev/null \ + | grep -q __ELF__ + then + # Once all utilities can be ECOFF (netbsdecoff) or a.out (netbsdaout). + # Return netbsd for either. FIX? + os=netbsd + else + os=netbsdelf + fi + ;; + *) + os=netbsd + ;; + esac + # Determine ABI tags. + case $UNAME_MACHINE_ARCH in + earm*) + expr='s/^earmv[0-9]/-eabi/;s/eb$//' + abi=$(echo "$UNAME_MACHINE_ARCH" | sed -e "$expr") + ;; + esac + # The OS release + # Debian GNU/NetBSD machines have a different userland, and + # thus, need a distinct triplet. However, they do not need + # kernel version information, so it can be replaced with a + # suitable tag, in the style of linux-gnu. + case $UNAME_VERSION in + Debian*) + release='-gnu' + ;; + *) + release=$(echo "$UNAME_RELEASE" | sed -e 's/[-_].*//' | cut -d. -f1,2) + ;; + esac + # Since CPU_TYPE-MANUFACTURER-KERNEL-OPERATING_SYSTEM: + # contains redundant information, the shorter form: + # CPU_TYPE-MANUFACTURER-OPERATING_SYSTEM is used. + echo "$machine-${os}${release}${abi-}" + exit ;; + *:Bitrig:*:*) + UNAME_MACHINE_ARCH=$(arch | sed 's/Bitrig.//') + echo "$UNAME_MACHINE_ARCH"-unknown-bitrig"$UNAME_RELEASE" + exit ;; + *:OpenBSD:*:*) + UNAME_MACHINE_ARCH=$(arch | sed 's/OpenBSD.//') + echo "$UNAME_MACHINE_ARCH"-unknown-openbsd"$UNAME_RELEASE" + exit ;; + *:SecBSD:*:*) + UNAME_MACHINE_ARCH=$(arch | sed 's/SecBSD.//') + echo "$UNAME_MACHINE_ARCH"-unknown-secbsd"$UNAME_RELEASE" + exit ;; + *:LibertyBSD:*:*) + UNAME_MACHINE_ARCH=$(arch | sed 's/^.*BSD\.//') + echo "$UNAME_MACHINE_ARCH"-unknown-libertybsd"$UNAME_RELEASE" + exit ;; + *:MidnightBSD:*:*) + echo "$UNAME_MACHINE"-unknown-midnightbsd"$UNAME_RELEASE" + exit ;; + *:ekkoBSD:*:*) + echo "$UNAME_MACHINE"-unknown-ekkobsd"$UNAME_RELEASE" + exit ;; + *:SolidBSD:*:*) + echo "$UNAME_MACHINE"-unknown-solidbsd"$UNAME_RELEASE" + exit ;; + *:OS108:*:*) + echo "$UNAME_MACHINE"-unknown-os108_"$UNAME_RELEASE" + exit ;; + macppc:MirBSD:*:*) + echo powerpc-unknown-mirbsd"$UNAME_RELEASE" + exit ;; + *:MirBSD:*:*) + echo "$UNAME_MACHINE"-unknown-mirbsd"$UNAME_RELEASE" + exit ;; + *:Sortix:*:*) + echo "$UNAME_MACHINE"-unknown-sortix + exit ;; + *:Twizzler:*:*) + echo "$UNAME_MACHINE"-unknown-twizzler + exit ;; + *:Redox:*:*) + echo "$UNAME_MACHINE"-unknown-redox + exit ;; + mips:OSF1:*.*) + echo mips-dec-osf1 + exit ;; + alpha:OSF1:*:*) + # Reset EXIT trap before exiting to avoid spurious non-zero exit code. 
+ trap '' 0 + case $UNAME_RELEASE in + *4.0) + UNAME_RELEASE=$(/usr/sbin/sizer -v | awk '{print $3}') + ;; + *5.*) + UNAME_RELEASE=$(/usr/sbin/sizer -v | awk '{print $4}') + ;; + esac + # According to Compaq, /usr/sbin/psrinfo has been available on + # OSF/1 and Tru64 systems produced since 1995. I hope that + # covers most systems running today. This code pipes the CPU + # types through head -n 1, so we only detect the type of CPU 0. + ALPHA_CPU_TYPE=$(/usr/sbin/psrinfo -v | sed -n -e 's/^ The alpha \(.*\) processor.*$/\1/p' | head -n 1) + case $ALPHA_CPU_TYPE in + "EV4 (21064)") + UNAME_MACHINE=alpha ;; + "EV4.5 (21064)") + UNAME_MACHINE=alpha ;; + "LCA4 (21066/21068)") + UNAME_MACHINE=alpha ;; + "EV5 (21164)") + UNAME_MACHINE=alphaev5 ;; + "EV5.6 (21164A)") + UNAME_MACHINE=alphaev56 ;; + "EV5.6 (21164PC)") + UNAME_MACHINE=alphapca56 ;; + "EV5.7 (21164PC)") + UNAME_MACHINE=alphapca57 ;; + "EV6 (21264)") + UNAME_MACHINE=alphaev6 ;; + "EV6.7 (21264A)") + UNAME_MACHINE=alphaev67 ;; + "EV6.8CB (21264C)") + UNAME_MACHINE=alphaev68 ;; + "EV6.8AL (21264B)") + UNAME_MACHINE=alphaev68 ;; + "EV6.8CX (21264D)") + UNAME_MACHINE=alphaev68 ;; + "EV6.9A (21264/EV69A)") + UNAME_MACHINE=alphaev69 ;; + "EV7 (21364)") + UNAME_MACHINE=alphaev7 ;; + "EV7.9 (21364A)") + UNAME_MACHINE=alphaev79 ;; + esac + # A Pn.n version is a patched version. + # A Vn.n version is a released version. + # A Tn.n version is a released field test version. + # A Xn.n version is an unreleased experimental baselevel. + # 1.2 uses "1.2" for uname -r. + echo "$UNAME_MACHINE"-dec-osf"$(echo "$UNAME_RELEASE" | sed -e 's/^[PVTX]//' | tr ABCDEFGHIJKLMNOPQRSTUVWXYZ abcdefghijklmnopqrstuvwxyz)" + exit ;; + Amiga*:UNIX_System_V:4.0:*) + echo m68k-unknown-sysv4 + exit ;; + *:[Aa]miga[Oo][Ss]:*:*) + echo "$UNAME_MACHINE"-unknown-amigaos + exit ;; + *:[Mm]orph[Oo][Ss]:*:*) + echo "$UNAME_MACHINE"-unknown-morphos + exit ;; + *:OS/390:*:*) + echo i370-ibm-openedition + exit ;; + *:z/VM:*:*) + echo s390-ibm-zvmoe + exit ;; + *:OS400:*:*) + echo powerpc-ibm-os400 + exit ;; + arm:RISC*:1.[012]*:*|arm:riscix:1.[012]*:*) + echo arm-acorn-riscix"$UNAME_RELEASE" + exit ;; + arm*:riscos:*:*|arm*:RISCOS:*:*) + echo arm-unknown-riscos + exit ;; + SR2?01:HI-UX/MPP:*:* | SR8000:HI-UX/MPP:*:*) + echo hppa1.1-hitachi-hiuxmpp + exit ;; + Pyramid*:OSx*:*:* | MIS*:OSx*:*:* | MIS*:SMP_DC-OSx*:*:*) + # akee@wpdis03.wpafb.af.mil (Earle F. Ake) contributed MIS and NILE. + if test "$( (/bin/universe) 2>/dev/null)" = att ; then + echo pyramid-pyramid-sysv3 + else + echo pyramid-pyramid-bsd + fi + exit ;; + NILE*:*:*:dcosx) + echo pyramid-pyramid-svr4 + exit ;; + DRS?6000:unix:4.0:6*) + echo sparc-icl-nx6 + exit ;; + DRS?6000:UNIX_SV:4.2*:7* | DRS?6000:isis:4.2*:7*) + case $(/usr/bin/uname -p) in + sparc) echo sparc-icl-nx7; exit ;; + esac ;; + s390x:SunOS:*:*) + echo "$UNAME_MACHINE"-ibm-solaris2"$(echo "$UNAME_RELEASE" | sed -e 's/[^.]*//')" + exit ;; + sun4H:SunOS:5.*:*) + echo sparc-hal-solaris2"$(echo "$UNAME_RELEASE"|sed -e 's/[^.]*//')" + exit ;; + sun4*:SunOS:5.*:* | tadpole*:SunOS:5.*:*) + echo sparc-sun-solaris2"$(echo "$UNAME_RELEASE" | sed -e 's/[^.]*//')" + exit ;; + i86pc:AuroraUX:5.*:* | i86xen:AuroraUX:5.*:*) + echo i386-pc-auroraux"$UNAME_RELEASE" + exit ;; + i86pc:SunOS:5.*:* | i86xen:SunOS:5.*:*) + set_cc_for_build + SUN_ARCH=i386 + # If there is a compiler, see if it is configured for 64-bit objects. + # Note that the Sun cc does not turn __LP64__ into 1 like gcc does. + # This test works for both compilers. 
+ if test "$CC_FOR_BUILD" != no_compiler_found; then + if (echo '#ifdef __amd64'; echo IS_64BIT_ARCH; echo '#endif') | \ + (CCOPTS="" $CC_FOR_BUILD -E - 2>/dev/null) | \ + grep IS_64BIT_ARCH >/dev/null + then + SUN_ARCH=x86_64 + fi + fi + echo "$SUN_ARCH"-pc-solaris2"$(echo "$UNAME_RELEASE"|sed -e 's/[^.]*//')" + exit ;; + sun4*:SunOS:6*:*) + # According to config.sub, this is the proper way to canonicalize + # SunOS6. Hard to guess exactly what SunOS6 will be like, but + # it's likely to be more like Solaris than SunOS4. + echo sparc-sun-solaris3"$(echo "$UNAME_RELEASE"|sed -e 's/[^.]*//')" + exit ;; + sun4*:SunOS:*:*) + case $(/usr/bin/arch -k) in + Series*|S4*) + UNAME_RELEASE=$(uname -v) + ;; + esac + # Japanese Language versions have a version number like `4.1.3-JL'. + echo sparc-sun-sunos"$(echo "$UNAME_RELEASE"|sed -e 's/-/_/')" + exit ;; + sun3*:SunOS:*:*) + echo m68k-sun-sunos"$UNAME_RELEASE" + exit ;; + sun*:*:4.2BSD:*) + UNAME_RELEASE=$( (sed 1q /etc/motd | awk '{print substr($5,1,3)}') 2>/dev/null) + test "x$UNAME_RELEASE" = x && UNAME_RELEASE=3 + case $(/bin/arch) in + sun3) + echo m68k-sun-sunos"$UNAME_RELEASE" + ;; + sun4) + echo sparc-sun-sunos"$UNAME_RELEASE" + ;; + esac + exit ;; + aushp:SunOS:*:*) + echo sparc-auspex-sunos"$UNAME_RELEASE" + exit ;; + # The situation for MiNT is a little confusing. The machine name + # can be virtually everything (everything which is not + # "atarist" or "atariste" at least should have a processor + # > m68000). The system name ranges from "MiNT" over "FreeMiNT" + # to the lowercase version "mint" (or "freemint"). Finally + # the system name "TOS" denotes a system which is actually not + # MiNT. But MiNT is downward compatible to TOS, so this should + # be no problem. + atarist[e]:*MiNT:*:* | atarist[e]:*mint:*:* | atarist[e]:*TOS:*:*) + echo m68k-atari-mint"$UNAME_RELEASE" + exit ;; + atari*:*MiNT:*:* | atari*:*mint:*:* | atarist[e]:*TOS:*:*) + echo m68k-atari-mint"$UNAME_RELEASE" + exit ;; + *falcon*:*MiNT:*:* | *falcon*:*mint:*:* | *falcon*:*TOS:*:*) + echo m68k-atari-mint"$UNAME_RELEASE" + exit ;; + milan*:*MiNT:*:* | milan*:*mint:*:* | *milan*:*TOS:*:*) + echo m68k-milan-mint"$UNAME_RELEASE" + exit ;; + hades*:*MiNT:*:* | hades*:*mint:*:* | *hades*:*TOS:*:*) + echo m68k-hades-mint"$UNAME_RELEASE" + exit ;; + *:*MiNT:*:* | *:*mint:*:* | *:*TOS:*:*) + echo m68k-unknown-mint"$UNAME_RELEASE" + exit ;; + m68k:machten:*:*) + echo m68k-apple-machten"$UNAME_RELEASE" + exit ;; + powerpc:machten:*:*) + echo powerpc-apple-machten"$UNAME_RELEASE" + exit ;; + RISC*:Mach:*:*) + echo mips-dec-mach_bsd4.3 + exit ;; + RISC*:ULTRIX:*:*) + echo mips-dec-ultrix"$UNAME_RELEASE" + exit ;; + VAX*:ULTRIX*:*:*) + echo vax-dec-ultrix"$UNAME_RELEASE" + exit ;; + 2020:CLIX:*:* | 2430:CLIX:*:*) + echo clipper-intergraph-clix"$UNAME_RELEASE" + exit ;; + mips:*:*:UMIPS | mips:*:*:RISCos) + set_cc_for_build + sed 's/^ //' << EOF > "$dummy.c" +#ifdef __cplusplus +#include /* for printf() prototype */ + int main (int argc, char *argv[]) { +#else + int main (argc, argv) int argc; char *argv[]; { +#endif + #if defined (host_mips) && defined (MIPSEB) + #if defined (SYSTYPE_SYSV) + printf ("mips-mips-riscos%ssysv\\n", argv[1]); exit (0); + #endif + #if defined (SYSTYPE_SVR4) + printf ("mips-mips-riscos%ssvr4\\n", argv[1]); exit (0); + #endif + #if defined (SYSTYPE_BSD43) || defined(SYSTYPE_BSD) + printf ("mips-mips-riscos%sbsd\\n", argv[1]); exit (0); + #endif + #endif + exit (-1); + } +EOF + $CC_FOR_BUILD -o "$dummy" "$dummy.c" && + dummyarg=$(echo "$UNAME_RELEASE" | 
sed -n 's/\([0-9]*\).*/\1/p') && + SYSTEM_NAME=$("$dummy" "$dummyarg") && + { echo "$SYSTEM_NAME"; exit; } + echo mips-mips-riscos"$UNAME_RELEASE" + exit ;; + Motorola:PowerMAX_OS:*:*) + echo powerpc-motorola-powermax + exit ;; + Motorola:*:4.3:PL8-*) + echo powerpc-harris-powermax + exit ;; + Night_Hawk:*:*:PowerMAX_OS | Synergy:PowerMAX_OS:*:*) + echo powerpc-harris-powermax + exit ;; + Night_Hawk:Power_UNIX:*:*) + echo powerpc-harris-powerunix + exit ;; + m88k:CX/UX:7*:*) + echo m88k-harris-cxux7 + exit ;; + m88k:*:4*:R4*) + echo m88k-motorola-sysv4 + exit ;; + m88k:*:3*:R3*) + echo m88k-motorola-sysv3 + exit ;; + AViiON:dgux:*:*) + # DG/UX returns AViiON for all architectures + UNAME_PROCESSOR=$(/usr/bin/uname -p) + if test "$UNAME_PROCESSOR" = mc88100 || test "$UNAME_PROCESSOR" = mc88110 + then + if test "$TARGET_BINARY_INTERFACE"x = m88kdguxelfx || \ + test "$TARGET_BINARY_INTERFACE"x = x + then + echo m88k-dg-dgux"$UNAME_RELEASE" + else + echo m88k-dg-dguxbcs"$UNAME_RELEASE" + fi + else + echo i586-dg-dgux"$UNAME_RELEASE" + fi + exit ;; + M88*:DolphinOS:*:*) # DolphinOS (SVR3) + echo m88k-dolphin-sysv3 + exit ;; + M88*:*:R3*:*) + # Delta 88k system running SVR3 + echo m88k-motorola-sysv3 + exit ;; + XD88*:*:*:*) # Tektronix XD88 system running UTekV (SVR3) + echo m88k-tektronix-sysv3 + exit ;; + Tek43[0-9][0-9]:UTek:*:*) # Tektronix 4300 system running UTek (BSD) + echo m68k-tektronix-bsd + exit ;; + *:IRIX*:*:*) + echo mips-sgi-irix"$(echo "$UNAME_RELEASE"|sed -e 's/-/_/g')" + exit ;; + ????????:AIX?:[12].1:2) # AIX 2.2.1 or AIX 2.1.1 is RT/PC AIX. + echo romp-ibm-aix # uname -m gives an 8 hex-code CPU id + exit ;; # Note that: echo "'$(uname -s)'" gives 'AIX ' + i*86:AIX:*:*) + echo i386-ibm-aix + exit ;; + ia64:AIX:*:*) + if test -x /usr/bin/oslevel ; then + IBM_REV=$(/usr/bin/oslevel) + else + IBM_REV="$UNAME_VERSION.$UNAME_RELEASE" + fi + echo "$UNAME_MACHINE"-ibm-aix"$IBM_REV" + exit ;; + *:AIX:2:3) + if grep bos325 /usr/include/stdio.h >/dev/null 2>&1; then + set_cc_for_build + sed 's/^ //' << EOF > "$dummy.c" + #include + + main() + { + if (!__power_pc()) + exit(1); + puts("powerpc-ibm-aix3.2.5"); + exit(0); + } +EOF + if $CC_FOR_BUILD -o "$dummy" "$dummy.c" && SYSTEM_NAME=$("$dummy") + then + echo "$SYSTEM_NAME" + else + echo rs6000-ibm-aix3.2.5 + fi + elif grep bos324 /usr/include/stdio.h >/dev/null 2>&1; then + echo rs6000-ibm-aix3.2.4 + else + echo rs6000-ibm-aix3.2 + fi + exit ;; + *:AIX:*:[4567]) + IBM_CPU_ID=$(/usr/sbin/lsdev -C -c processor -S available | sed 1q | awk '{ print $1 }') + if /usr/sbin/lsattr -El "$IBM_CPU_ID" | grep ' POWER' >/dev/null 2>&1; then + IBM_ARCH=rs6000 + else + IBM_ARCH=powerpc + fi + if test -x /usr/bin/lslpp ; then + IBM_REV=$(/usr/bin/lslpp -Lqc bos.rte.libc | + awk -F: '{ print $3 }' | sed s/[0-9]*$/0/) + else + IBM_REV="$UNAME_VERSION.$UNAME_RELEASE" + fi + echo "$IBM_ARCH"-ibm-aix"$IBM_REV" + exit ;; + *:AIX:*:*) + echo rs6000-ibm-aix + exit ;; + ibmrt:4.4BSD:*|romp-ibm:4.4BSD:*) + echo romp-ibm-bsd4.4 + exit ;; + ibmrt:*BSD:*|romp-ibm:BSD:*) # covers RT/PC BSD and + echo romp-ibm-bsd"$UNAME_RELEASE" # 4.3 with uname added to + exit ;; # report: romp-ibm BSD 4.3 + *:BOSX:*:*) + echo rs6000-bull-bosx + exit ;; + DPX/2?00:B.O.S.:*:*) + echo m68k-bull-sysv3 + exit ;; + 9000/[34]??:4.3bsd:1.*:*) + echo m68k-hp-bsd + exit ;; + hp300:4.4BSD:*:* | 9000/[34]??:4.3bsd:2.*:*) + echo m68k-hp-bsd4.4 + exit ;; + 9000/[34678]??:HP-UX:*:*) + HPUX_REV=$(echo "$UNAME_RELEASE"|sed -e 's/[^.]*.[0B]*//') + case $UNAME_MACHINE in + 9000/31?) 
HP_ARCH=m68000 ;; + 9000/[34]??) HP_ARCH=m68k ;; + 9000/[678][0-9][0-9]) + if test -x /usr/bin/getconf; then + sc_cpu_version=$(/usr/bin/getconf SC_CPU_VERSION 2>/dev/null) + sc_kernel_bits=$(/usr/bin/getconf SC_KERNEL_BITS 2>/dev/null) + case $sc_cpu_version in + 523) HP_ARCH=hppa1.0 ;; # CPU_PA_RISC1_0 + 528) HP_ARCH=hppa1.1 ;; # CPU_PA_RISC1_1 + 532) # CPU_PA_RISC2_0 + case $sc_kernel_bits in + 32) HP_ARCH=hppa2.0n ;; + 64) HP_ARCH=hppa2.0w ;; + '') HP_ARCH=hppa2.0 ;; # HP-UX 10.20 + esac ;; + esac + fi + if test "$HP_ARCH" = ""; then + set_cc_for_build + sed 's/^ //' << EOF > "$dummy.c" + + #define _HPUX_SOURCE + #include + #include + + int main () + { + #if defined(_SC_KERNEL_BITS) + long bits = sysconf(_SC_KERNEL_BITS); + #endif + long cpu = sysconf (_SC_CPU_VERSION); + + switch (cpu) + { + case CPU_PA_RISC1_0: puts ("hppa1.0"); break; + case CPU_PA_RISC1_1: puts ("hppa1.1"); break; + case CPU_PA_RISC2_0: + #if defined(_SC_KERNEL_BITS) + switch (bits) + { + case 64: puts ("hppa2.0w"); break; + case 32: puts ("hppa2.0n"); break; + default: puts ("hppa2.0"); break; + } break; + #else /* !defined(_SC_KERNEL_BITS) */ + puts ("hppa2.0"); break; + #endif + default: puts ("hppa1.0"); break; + } + exit (0); + } +EOF + (CCOPTS="" $CC_FOR_BUILD -o "$dummy" "$dummy.c" 2>/dev/null) && HP_ARCH=$("$dummy") + test -z "$HP_ARCH" && HP_ARCH=hppa + fi ;; + esac + if test "$HP_ARCH" = hppa2.0w + then + set_cc_for_build + + # hppa2.0w-hp-hpux* has a 64-bit kernel and a compiler generating + # 32-bit code. hppa64-hp-hpux* has the same kernel and a compiler + # generating 64-bit code. GNU and HP use different nomenclature: + # + # $ CC_FOR_BUILD=cc ./config.guess + # => hppa2.0w-hp-hpux11.23 + # $ CC_FOR_BUILD="cc +DA2.0w" ./config.guess + # => hppa64-hp-hpux11.23 + + if echo __LP64__ | (CCOPTS="" $CC_FOR_BUILD -E - 2>/dev/null) | + grep -q __LP64__ + then + HP_ARCH=hppa2.0w + else + HP_ARCH=hppa64 + fi + fi + echo "$HP_ARCH"-hp-hpux"$HPUX_REV" + exit ;; + ia64:HP-UX:*:*) + HPUX_REV=$(echo "$UNAME_RELEASE"|sed -e 's/[^.]*.[0B]*//') + echo ia64-hp-hpux"$HPUX_REV" + exit ;; + 3050*:HI-UX:*:*) + set_cc_for_build + sed 's/^ //' << EOF > "$dummy.c" + #include + int + main () + { + long cpu = sysconf (_SC_CPU_VERSION); + /* The order matters, because CPU_IS_HP_MC68K erroneously returns + true for CPU_PA_RISC1_0. CPU_IS_PA_RISC returns correct + results, however. 
*/ + if (CPU_IS_PA_RISC (cpu)) + { + switch (cpu) + { + case CPU_PA_RISC1_0: puts ("hppa1.0-hitachi-hiuxwe2"); break; + case CPU_PA_RISC1_1: puts ("hppa1.1-hitachi-hiuxwe2"); break; + case CPU_PA_RISC2_0: puts ("hppa2.0-hitachi-hiuxwe2"); break; + default: puts ("hppa-hitachi-hiuxwe2"); break; + } + } + else if (CPU_IS_HP_MC68K (cpu)) + puts ("m68k-hitachi-hiuxwe2"); + else puts ("unknown-hitachi-hiuxwe2"); + exit (0); + } +EOF + $CC_FOR_BUILD -o "$dummy" "$dummy.c" && SYSTEM_NAME=$("$dummy") && + { echo "$SYSTEM_NAME"; exit; } + echo unknown-hitachi-hiuxwe2 + exit ;; + 9000/7??:4.3bsd:*:* | 9000/8?[79]:4.3bsd:*:*) + echo hppa1.1-hp-bsd + exit ;; + 9000/8??:4.3bsd:*:*) + echo hppa1.0-hp-bsd + exit ;; + *9??*:MPE/iX:*:* | *3000*:MPE/iX:*:*) + echo hppa1.0-hp-mpeix + exit ;; + hp7??:OSF1:*:* | hp8?[79]:OSF1:*:*) + echo hppa1.1-hp-osf + exit ;; + hp8??:OSF1:*:*) + echo hppa1.0-hp-osf + exit ;; + i*86:OSF1:*:*) + if test -x /usr/sbin/sysversion ; then + echo "$UNAME_MACHINE"-unknown-osf1mk + else + echo "$UNAME_MACHINE"-unknown-osf1 + fi + exit ;; + parisc*:Lites*:*:*) + echo hppa1.1-hp-lites + exit ;; + C1*:ConvexOS:*:* | convex:ConvexOS:C1*:*) + echo c1-convex-bsd + exit ;; + C2*:ConvexOS:*:* | convex:ConvexOS:C2*:*) + if getsysinfo -f scalar_acc + then echo c32-convex-bsd + else echo c2-convex-bsd + fi + exit ;; + C34*:ConvexOS:*:* | convex:ConvexOS:C34*:*) + echo c34-convex-bsd + exit ;; + C38*:ConvexOS:*:* | convex:ConvexOS:C38*:*) + echo c38-convex-bsd + exit ;; + C4*:ConvexOS:*:* | convex:ConvexOS:C4*:*) + echo c4-convex-bsd + exit ;; + CRAY*Y-MP:*:*:*) + echo ymp-cray-unicos"$UNAME_RELEASE" | sed -e 's/\.[^.]*$/.X/' + exit ;; + CRAY*[A-Z]90:*:*:*) + echo "$UNAME_MACHINE"-cray-unicos"$UNAME_RELEASE" \ + | sed -e 's/CRAY.*\([A-Z]90\)/\1/' \ + -e y/ABCDEFGHIJKLMNOPQRSTUVWXYZ/abcdefghijklmnopqrstuvwxyz/ \ + -e 's/\.[^.]*$/.X/' + exit ;; + CRAY*TS:*:*:*) + echo t90-cray-unicos"$UNAME_RELEASE" | sed -e 's/\.[^.]*$/.X/' + exit ;; + CRAY*T3E:*:*:*) + echo alphaev5-cray-unicosmk"$UNAME_RELEASE" | sed -e 's/\.[^.]*$/.X/' + exit ;; + CRAY*SV1:*:*:*) + echo sv1-cray-unicos"$UNAME_RELEASE" | sed -e 's/\.[^.]*$/.X/' + exit ;; + *:UNICOS/mp:*:*) + echo craynv-cray-unicosmp"$UNAME_RELEASE" | sed -e 's/\.[^.]*$/.X/' + exit ;; + F30[01]:UNIX_System_V:*:* | F700:UNIX_System_V:*:*) + FUJITSU_PROC=$(uname -m | tr ABCDEFGHIJKLMNOPQRSTUVWXYZ abcdefghijklmnopqrstuvwxyz) + FUJITSU_SYS=$(uname -p | tr ABCDEFGHIJKLMNOPQRSTUVWXYZ abcdefghijklmnopqrstuvwxyz | sed -e 's/\///') + FUJITSU_REL=$(echo "$UNAME_RELEASE" | sed -e 's/ /_/') + echo "${FUJITSU_PROC}-fujitsu-${FUJITSU_SYS}${FUJITSU_REL}" + exit ;; + 5000:UNIX_System_V:4.*:*) + FUJITSU_SYS=$(uname -p | tr ABCDEFGHIJKLMNOPQRSTUVWXYZ abcdefghijklmnopqrstuvwxyz | sed -e 's/\///') + FUJITSU_REL=$(echo "$UNAME_RELEASE" | tr ABCDEFGHIJKLMNOPQRSTUVWXYZ abcdefghijklmnopqrstuvwxyz | sed -e 's/ /_/') + echo "sparc-fujitsu-${FUJITSU_SYS}${FUJITSU_REL}" + exit ;; + i*86:BSD/386:*:* | i*86:BSD/OS:*:* | *:Ascend\ Embedded/OS:*:*) + echo "$UNAME_MACHINE"-pc-bsdi"$UNAME_RELEASE" + exit ;; + sparc*:BSD/OS:*:*) + echo sparc-unknown-bsdi"$UNAME_RELEASE" + exit ;; + *:BSD/OS:*:*) + echo "$UNAME_MACHINE"-unknown-bsdi"$UNAME_RELEASE" + exit ;; + arm:FreeBSD:*:*) + UNAME_PROCESSOR=$(uname -p) + set_cc_for_build + if echo __ARM_PCS_VFP | $CC_FOR_BUILD -E - 2>/dev/null \ + | grep -q __ARM_PCS_VFP + then + echo "${UNAME_PROCESSOR}"-unknown-freebsd"$(echo ${UNAME_RELEASE}|sed -e 's/[-(].*//')"-gnueabi + else + echo "${UNAME_PROCESSOR}"-unknown-freebsd"$(echo ${UNAME_RELEASE}|sed -e 
's/[-(].*//')"-gnueabihf + fi + exit ;; + *:FreeBSD:*:*) + UNAME_PROCESSOR=$(/usr/bin/uname -p) + case $UNAME_PROCESSOR in + amd64) + UNAME_PROCESSOR=x86_64 ;; + i386) + UNAME_PROCESSOR=i586 ;; + esac + echo "$UNAME_PROCESSOR"-unknown-freebsd"$(echo "$UNAME_RELEASE"|sed -e 's/[-(].*//')" + exit ;; + i*:CYGWIN*:*) + echo "$UNAME_MACHINE"-pc-cygwin + exit ;; + *:MINGW64*:*) + echo "$UNAME_MACHINE"-pc-mingw64 + exit ;; + *:MINGW*:*) + echo "$UNAME_MACHINE"-pc-mingw32 + exit ;; + *:MSYS*:*) + echo "$UNAME_MACHINE"-pc-msys + exit ;; + i*:PW*:*) + echo "$UNAME_MACHINE"-pc-pw32 + exit ;; + *:Interix*:*) + case $UNAME_MACHINE in + x86) + echo i586-pc-interix"$UNAME_RELEASE" + exit ;; + authenticamd | genuineintel | EM64T) + echo x86_64-unknown-interix"$UNAME_RELEASE" + exit ;; + IA64) + echo ia64-unknown-interix"$UNAME_RELEASE" + exit ;; + esac ;; + i*:UWIN*:*) + echo "$UNAME_MACHINE"-pc-uwin + exit ;; + amd64:CYGWIN*:*:* | x86_64:CYGWIN*:*:*) + echo x86_64-pc-cygwin + exit ;; + prep*:SunOS:5.*:*) + echo powerpcle-unknown-solaris2"$(echo "$UNAME_RELEASE"|sed -e 's/[^.]*//')" + exit ;; + *:GNU:*:*) + # the GNU system + echo "$(echo "$UNAME_MACHINE"|sed -e 's,[-/].*$,,')-unknown-$LIBC$(echo "$UNAME_RELEASE"|sed -e 's,/.*$,,')" + exit ;; + *:GNU/*:*:*) + # other systems with GNU libc and userland + echo "$UNAME_MACHINE-unknown-$(echo "$UNAME_SYSTEM" | sed 's,^[^/]*/,,' | tr "[:upper:]" "[:lower:]")$(echo "$UNAME_RELEASE"|sed -e 's/[-(].*//')-$LIBC" + exit ;; + *:Minix:*:*) + echo "$UNAME_MACHINE"-unknown-minix + exit ;; + aarch64:Linux:*:*) + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + aarch64_be:Linux:*:*) + UNAME_MACHINE=aarch64_be + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + alpha:Linux:*:*) + case $(sed -n '/^cpu model/s/^.*: \(.*\)/\1/p' /proc/cpuinfo 2>/dev/null) in + EV5) UNAME_MACHINE=alphaev5 ;; + EV56) UNAME_MACHINE=alphaev56 ;; + PCA56) UNAME_MACHINE=alphapca56 ;; + PCA57) UNAME_MACHINE=alphapca56 ;; + EV6) UNAME_MACHINE=alphaev6 ;; + EV67) UNAME_MACHINE=alphaev67 ;; + EV68*) UNAME_MACHINE=alphaev68 ;; + esac + objdump --private-headers /bin/sh | grep -q ld.so.1 + if test "$?" 
= 0 ; then LIBC=gnulibc1 ; fi + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + arc:Linux:*:* | arceb:Linux:*:* | arc64:Linux:*:*) + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + arm*:Linux:*:*) + set_cc_for_build + if echo __ARM_EABI__ | $CC_FOR_BUILD -E - 2>/dev/null \ + | grep -q __ARM_EABI__ + then + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + else + if echo __ARM_PCS_VFP | $CC_FOR_BUILD -E - 2>/dev/null \ + | grep -q __ARM_PCS_VFP + then + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC"eabi + else + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC"eabihf + fi + fi + exit ;; + avr32*:Linux:*:*) + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + cris:Linux:*:*) + echo "$UNAME_MACHINE"-axis-linux-"$LIBC" + exit ;; + crisv32:Linux:*:*) + echo "$UNAME_MACHINE"-axis-linux-"$LIBC" + exit ;; + e2k:Linux:*:*) + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + frv:Linux:*:*) + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + hexagon:Linux:*:*) + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + i*86:Linux:*:*) + echo "$UNAME_MACHINE"-pc-linux-"$LIBC" + exit ;; + ia64:Linux:*:*) + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + k1om:Linux:*:*) + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + loongarch32:Linux:*:* | loongarch64:Linux:*:* | loongarchx32:Linux:*:*) + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + m32r*:Linux:*:*) + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + m68*:Linux:*:*) + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + mips:Linux:*:* | mips64:Linux:*:*) + set_cc_for_build + IS_GLIBC=0 + test x"${LIBC}" = xgnu && IS_GLIBC=1 + sed 's/^ //' << EOF > "$dummy.c" + #undef CPU + #undef mips + #undef mipsel + #undef mips64 + #undef mips64el + #if ${IS_GLIBC} && defined(_ABI64) + LIBCABI=gnuabi64 + #else + #if ${IS_GLIBC} && defined(_ABIN32) + LIBCABI=gnuabin32 + #else + LIBCABI=${LIBC} + #endif + #endif + + #if ${IS_GLIBC} && defined(__mips64) && defined(__mips_isa_rev) && __mips_isa_rev>=6 + CPU=mipsisa64r6 + #else + #if ${IS_GLIBC} && !defined(__mips64) && defined(__mips_isa_rev) && __mips_isa_rev>=6 + CPU=mipsisa32r6 + #else + #if defined(__mips64) + CPU=mips64 + #else + CPU=mips + #endif + #endif + #endif + + #if defined(__MIPSEL__) || defined(__MIPSEL) || defined(_MIPSEL) || defined(MIPSEL) + MIPS_ENDIAN=el + #else + #if defined(__MIPSEB__) || defined(__MIPSEB) || defined(_MIPSEB) || defined(MIPSEB) + MIPS_ENDIAN= + #else + MIPS_ENDIAN= + #endif + #endif +EOF + eval "$($CC_FOR_BUILD -E "$dummy.c" 2>/dev/null | grep '^CPU\|^MIPS_ENDIAN\|^LIBCABI')" + test "x$CPU" != x && { echo "$CPU${MIPS_ENDIAN}-unknown-linux-$LIBCABI"; exit; } + ;; + mips64el:Linux:*:*) + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + openrisc*:Linux:*:*) + echo or1k-unknown-linux-"$LIBC" + exit ;; + or32:Linux:*:* | or1k*:Linux:*:*) + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + padre:Linux:*:*) + echo sparc-unknown-linux-"$LIBC" + exit ;; + parisc64:Linux:*:* | hppa64:Linux:*:*) + echo hppa64-unknown-linux-"$LIBC" + exit ;; + parisc:Linux:*:* | hppa:Linux:*:*) + # Look for CPU level + case $(grep '^cpu[^a-z]*:' /proc/cpuinfo 2>/dev/null | cut -d' ' -f2) in + PA7*) echo hppa1.1-unknown-linux-"$LIBC" ;; + PA8*) echo hppa2.0-unknown-linux-"$LIBC" ;; + *) echo hppa-unknown-linux-"$LIBC" ;; + esac + exit ;; + ppc64:Linux:*:*) + echo powerpc64-unknown-linux-"$LIBC" + exit ;; + ppc:Linux:*:*) + echo powerpc-unknown-linux-"$LIBC" + exit ;; + ppc64le:Linux:*:*) + echo powerpc64le-unknown-linux-"$LIBC" + exit ;; + 
ppcle:Linux:*:*) + echo powerpcle-unknown-linux-"$LIBC" + exit ;; + riscv32:Linux:*:* | riscv32be:Linux:*:* | riscv64:Linux:*:* | riscv64be:Linux:*:*) + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + s390:Linux:*:* | s390x:Linux:*:*) + echo "$UNAME_MACHINE"-ibm-linux-"$LIBC" + exit ;; + sh64*:Linux:*:*) + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + sh*:Linux:*:*) + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + sparc:Linux:*:* | sparc64:Linux:*:*) + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + tile*:Linux:*:*) + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + vax:Linux:*:*) + echo "$UNAME_MACHINE"-dec-linux-"$LIBC" + exit ;; + x86_64:Linux:*:*) + set_cc_for_build + LIBCABI=$LIBC + if test "$CC_FOR_BUILD" != no_compiler_found; then + if (echo '#ifdef __ILP32__'; echo IS_X32; echo '#endif') | \ + (CCOPTS="" $CC_FOR_BUILD -E - 2>/dev/null) | \ + grep IS_X32 >/dev/null + then + LIBCABI="$LIBC"x32 + fi + fi + echo "$UNAME_MACHINE"-pc-linux-"$LIBCABI" + exit ;; + xtensa*:Linux:*:*) + echo "$UNAME_MACHINE"-unknown-linux-"$LIBC" + exit ;; + i*86:DYNIX/ptx:4*:*) + # ptx 4.0 does uname -s correctly, with DYNIX/ptx in there. + # earlier versions are messed up and put the nodename in both + # sysname and nodename. + echo i386-sequent-sysv4 + exit ;; + i*86:UNIX_SV:4.2MP:2.*) + # Unixware is an offshoot of SVR4, but it has its own version + # number series starting with 2... + # I am not positive that other SVR4 systems won't match this, + # I just have to hope. -- rms. + # Use sysv4.2uw... so that sysv4* matches it. + echo "$UNAME_MACHINE"-pc-sysv4.2uw"$UNAME_VERSION" + exit ;; + i*86:OS/2:*:*) + # If we were able to find `uname', then EMX Unix compatibility + # is probably installed. + echo "$UNAME_MACHINE"-pc-os2-emx + exit ;; + i*86:XTS-300:*:STOP) + echo "$UNAME_MACHINE"-unknown-stop + exit ;; + i*86:atheos:*:*) + echo "$UNAME_MACHINE"-unknown-atheos + exit ;; + i*86:syllable:*:*) + echo "$UNAME_MACHINE"-pc-syllable + exit ;; + i*86:LynxOS:2.*:* | i*86:LynxOS:3.[01]*:* | i*86:LynxOS:4.[02]*:*) + echo i386-unknown-lynxos"$UNAME_RELEASE" + exit ;; + i*86:*DOS:*:*) + echo "$UNAME_MACHINE"-pc-msdosdjgpp + exit ;; + i*86:*:4.*:*) + UNAME_REL=$(echo "$UNAME_RELEASE" | sed 's/\/MP$//') + if grep Novell /usr/include/link.h >/dev/null 2>/dev/null; then + echo "$UNAME_MACHINE"-univel-sysv"$UNAME_REL" + else + echo "$UNAME_MACHINE"-pc-sysv"$UNAME_REL" + fi + exit ;; + i*86:*:5:[678]*) + # UnixWare 7.x, OpenUNIX and OpenServer 6. + case $(/bin/uname -X | grep "^Machine") in + *486*) UNAME_MACHINE=i486 ;; + *Pentium) UNAME_MACHINE=i586 ;; + *Pent*|*Celeron) UNAME_MACHINE=i686 ;; + esac + echo "$UNAME_MACHINE-unknown-sysv${UNAME_RELEASE}${UNAME_SYSTEM}${UNAME_VERSION}" + exit ;; + i*86:*:3.2:*) + if test -f /usr/options/cb.name; then + UNAME_REL=$(sed -n 's/.*Version //p' /dev/null >/dev/null ; then + UNAME_REL=$( (/bin/uname -X|grep Release|sed -e 's/.*= //')) + (/bin/uname -X|grep i80486 >/dev/null) && UNAME_MACHINE=i486 + (/bin/uname -X|grep '^Machine.*Pentium' >/dev/null) \ + && UNAME_MACHINE=i586 + (/bin/uname -X|grep '^Machine.*Pent *II' >/dev/null) \ + && UNAME_MACHINE=i686 + (/bin/uname -X|grep '^Machine.*Pentium Pro' >/dev/null) \ + && UNAME_MACHINE=i686 + echo "$UNAME_MACHINE"-pc-sco"$UNAME_REL" + else + echo "$UNAME_MACHINE"-pc-sysv32 + fi + exit ;; + pc:*:*:*) + # Left here for compatibility: + # uname -m prints for DJGPP always 'pc', but it prints nothing about + # the processor, so we play safe by assuming i586. 
+ # Note: whatever this is, it MUST be the same as what config.sub + # prints for the "djgpp" host, or else GDB configure will decide that + # this is a cross-build. + echo i586-pc-msdosdjgpp + exit ;; + Intel:Mach:3*:*) + echo i386-pc-mach3 + exit ;; + paragon:*:*:*) + echo i860-intel-osf1 + exit ;; + i860:*:4.*:*) # i860-SVR4 + if grep Stardent /usr/include/sys/uadmin.h >/dev/null 2>&1 ; then + echo i860-stardent-sysv"$UNAME_RELEASE" # Stardent Vistra i860-SVR4 + else # Add other i860-SVR4 vendors below as they are discovered. + echo i860-unknown-sysv"$UNAME_RELEASE" # Unknown i860-SVR4 + fi + exit ;; + mini*:CTIX:SYS*5:*) + # "miniframe" + echo m68010-convergent-sysv + exit ;; + mc68k:UNIX:SYSTEM5:3.51m) + echo m68k-convergent-sysv + exit ;; + M680?0:D-NIX:5.3:*) + echo m68k-diab-dnix + exit ;; + M68*:*:R3V[5678]*:*) + test -r /sysV68 && { echo 'm68k-motorola-sysv'; exit; } ;; + 3[345]??:*:4.0:3.0 | 3[34]??A:*:4.0:3.0 | 3[34]??,*:*:4.0:3.0 | 3[34]??/*:*:4.0:3.0 | 4400:*:4.0:3.0 | 4850:*:4.0:3.0 | SKA40:*:4.0:3.0 | SDS2:*:4.0:3.0 | SHG2:*:4.0:3.0 | S7501*:*:4.0:3.0) + OS_REL='' + test -r /etc/.relid \ + && OS_REL=.$(sed -n 's/[^ ]* [^ ]* \([0-9][0-9]\).*/\1/p' < /etc/.relid) + /bin/uname -p 2>/dev/null | grep 86 >/dev/null \ + && { echo i486-ncr-sysv4.3"$OS_REL"; exit; } + /bin/uname -p 2>/dev/null | /bin/grep entium >/dev/null \ + && { echo i586-ncr-sysv4.3"$OS_REL"; exit; } ;; + 3[34]??:*:4.0:* | 3[34]??,*:*:4.0:*) + /bin/uname -p 2>/dev/null | grep 86 >/dev/null \ + && { echo i486-ncr-sysv4; exit; } ;; + NCR*:*:4.2:* | MPRAS*:*:4.2:*) + OS_REL='.3' + test -r /etc/.relid \ + && OS_REL=.$(sed -n 's/[^ ]* [^ ]* \([0-9][0-9]\).*/\1/p' < /etc/.relid) + /bin/uname -p 2>/dev/null | grep 86 >/dev/null \ + && { echo i486-ncr-sysv4.3"$OS_REL"; exit; } + /bin/uname -p 2>/dev/null | /bin/grep entium >/dev/null \ + && { echo i586-ncr-sysv4.3"$OS_REL"; exit; } + /bin/uname -p 2>/dev/null | /bin/grep pteron >/dev/null \ + && { echo i586-ncr-sysv4.3"$OS_REL"; exit; } ;; + m68*:LynxOS:2.*:* | m68*:LynxOS:3.0*:*) + echo m68k-unknown-lynxos"$UNAME_RELEASE" + exit ;; + mc68030:UNIX_System_V:4.*:*) + echo m68k-atari-sysv4 + exit ;; + TSUNAMI:LynxOS:2.*:*) + echo sparc-unknown-lynxos"$UNAME_RELEASE" + exit ;; + rs6000:LynxOS:2.*:*) + echo rs6000-unknown-lynxos"$UNAME_RELEASE" + exit ;; + PowerPC:LynxOS:2.*:* | PowerPC:LynxOS:3.[01]*:* | PowerPC:LynxOS:4.[02]*:*) + echo powerpc-unknown-lynxos"$UNAME_RELEASE" + exit ;; + SM[BE]S:UNIX_SV:*:*) + echo mips-dde-sysv"$UNAME_RELEASE" + exit ;; + RM*:ReliantUNIX-*:*:*) + echo mips-sni-sysv4 + exit ;; + RM*:SINIX-*:*:*) + echo mips-sni-sysv4 + exit ;; + *:SINIX-*:*:*) + if uname -p 2>/dev/null >/dev/null ; then + UNAME_MACHINE=$( (uname -p) 2>/dev/null) + echo "$UNAME_MACHINE"-sni-sysv4 + else + echo ns32k-sni-sysv + fi + exit ;; + PENTIUM:*:4.0*:*) # Unisys `ClearPath HMP IX 4000' SVR4/MP effort + # says + echo i586-unisys-sysv4 + exit ;; + *:UNIX_System_V:4*:FTX*) + # From Gerald Hewes . + # How about differentiating between stratus architectures? -djm + echo hppa1.1-stratus-sysv4 + exit ;; + *:*:*:FTX*) + # From seanf@swdc.stratus.com. + echo i860-stratus-sysv4 + exit ;; + i*86:VOS:*:*) + # From Paul.Green@stratus.com. + echo "$UNAME_MACHINE"-stratus-vos + exit ;; + *:VOS:*:*) + # From Paul.Green@stratus.com. 
+ echo hppa1.1-stratus-vos + exit ;; + mc68*:A/UX:*:*) + echo m68k-apple-aux"$UNAME_RELEASE" + exit ;; + news*:NEWS-OS:6*:*) + echo mips-sony-newsos6 + exit ;; + R[34]000:*System_V*:*:* | R4000:UNIX_SYSV:*:* | R*000:UNIX_SV:*:*) + if test -d /usr/nec; then + echo mips-nec-sysv"$UNAME_RELEASE" + else + echo mips-unknown-sysv"$UNAME_RELEASE" + fi + exit ;; + BeBox:BeOS:*:*) # BeOS running on hardware made by Be, PPC only. + echo powerpc-be-beos + exit ;; + BeMac:BeOS:*:*) # BeOS running on Mac or Mac clone, PPC only. + echo powerpc-apple-beos + exit ;; + BePC:BeOS:*:*) # BeOS running on Intel PC compatible. + echo i586-pc-beos + exit ;; + BePC:Haiku:*:*) # Haiku running on Intel PC compatible. + echo i586-pc-haiku + exit ;; + x86_64:Haiku:*:*) + echo x86_64-unknown-haiku + exit ;; + SX-4:SUPER-UX:*:*) + echo sx4-nec-superux"$UNAME_RELEASE" + exit ;; + SX-5:SUPER-UX:*:*) + echo sx5-nec-superux"$UNAME_RELEASE" + exit ;; + SX-6:SUPER-UX:*:*) + echo sx6-nec-superux"$UNAME_RELEASE" + exit ;; + SX-7:SUPER-UX:*:*) + echo sx7-nec-superux"$UNAME_RELEASE" + exit ;; + SX-8:SUPER-UX:*:*) + echo sx8-nec-superux"$UNAME_RELEASE" + exit ;; + SX-8R:SUPER-UX:*:*) + echo sx8r-nec-superux"$UNAME_RELEASE" + exit ;; + SX-ACE:SUPER-UX:*:*) + echo sxace-nec-superux"$UNAME_RELEASE" + exit ;; + Power*:Rhapsody:*:*) + echo powerpc-apple-rhapsody"$UNAME_RELEASE" + exit ;; + *:Rhapsody:*:*) + echo "$UNAME_MACHINE"-apple-rhapsody"$UNAME_RELEASE" + exit ;; + arm64:Darwin:*:*) + echo aarch64-apple-darwin"$UNAME_RELEASE" + exit ;; + *:Darwin:*:*) + UNAME_PROCESSOR=$(uname -p) + case $UNAME_PROCESSOR in + unknown) UNAME_PROCESSOR=powerpc ;; + esac + if command -v xcode-select > /dev/null 2> /dev/null && \ + ! xcode-select --print-path > /dev/null 2> /dev/null ; then + # Avoid executing cc if there is no toolchain installed as + # cc will be a stub that puts up a graphical alert + # prompting the user to install developer tools. 
+ CC_FOR_BUILD=no_compiler_found + else + set_cc_for_build + fi + if test "$CC_FOR_BUILD" != no_compiler_found; then + if (echo '#ifdef __LP64__'; echo IS_64BIT_ARCH; echo '#endif') | \ + (CCOPTS="" $CC_FOR_BUILD -E - 2>/dev/null) | \ + grep IS_64BIT_ARCH >/dev/null + then + case $UNAME_PROCESSOR in + i386) UNAME_PROCESSOR=x86_64 ;; + powerpc) UNAME_PROCESSOR=powerpc64 ;; + esac + fi + # On 10.4-10.6 one might compile for PowerPC via gcc -arch ppc + if (echo '#ifdef __POWERPC__'; echo IS_PPC; echo '#endif') | \ + (CCOPTS="" $CC_FOR_BUILD -E - 2>/dev/null) | \ + grep IS_PPC >/dev/null + then + UNAME_PROCESSOR=powerpc + fi + elif test "$UNAME_PROCESSOR" = i386 ; then + # uname -m returns i386 or x86_64 + UNAME_PROCESSOR=$UNAME_MACHINE + fi + echo "$UNAME_PROCESSOR"-apple-darwin"$UNAME_RELEASE" + exit ;; + *:procnto*:*:* | *:QNX:[0123456789]*:*) + UNAME_PROCESSOR=$(uname -p) + if test "$UNAME_PROCESSOR" = x86; then + UNAME_PROCESSOR=i386 + UNAME_MACHINE=pc + fi + echo "$UNAME_PROCESSOR"-"$UNAME_MACHINE"-nto-qnx"$UNAME_RELEASE" + exit ;; + *:QNX:*:4*) + echo i386-pc-qnx + exit ;; + NEO-*:NONSTOP_KERNEL:*:*) + echo neo-tandem-nsk"$UNAME_RELEASE" + exit ;; + NSE-*:NONSTOP_KERNEL:*:*) + echo nse-tandem-nsk"$UNAME_RELEASE" + exit ;; + NSR-*:NONSTOP_KERNEL:*:*) + echo nsr-tandem-nsk"$UNAME_RELEASE" + exit ;; + NSV-*:NONSTOP_KERNEL:*:*) + echo nsv-tandem-nsk"$UNAME_RELEASE" + exit ;; + NSX-*:NONSTOP_KERNEL:*:*) + echo nsx-tandem-nsk"$UNAME_RELEASE" + exit ;; + *:NonStop-UX:*:*) + echo mips-compaq-nonstopux + exit ;; + BS2000:POSIX*:*:*) + echo bs2000-siemens-sysv + exit ;; + DS/*:UNIX_System_V:*:*) + echo "$UNAME_MACHINE"-"$UNAME_SYSTEM"-"$UNAME_RELEASE" + exit ;; + *:Plan9:*:*) + # "uname -m" is not consistent, so use $cputype instead. 386 + # is converted to i386 for consistency with other x86 + # operating systems. + if test "${cputype-}" = 386; then + UNAME_MACHINE=i386 + elif test "x${cputype-}" != x; then + UNAME_MACHINE="$cputype" + fi + echo "$UNAME_MACHINE"-unknown-plan9 + exit ;; + *:TOPS-10:*:*) + echo pdp10-unknown-tops10 + exit ;; + *:TENEX:*:*) + echo pdp10-unknown-tenex + exit ;; + KS10:TOPS-20:*:* | KL10:TOPS-20:*:* | TYPE4:TOPS-20:*:*) + echo pdp10-dec-tops20 + exit ;; + XKL-1:TOPS-20:*:* | TYPE5:TOPS-20:*:*) + echo pdp10-xkl-tops20 + exit ;; + *:TOPS-20:*:*) + echo pdp10-unknown-tops20 + exit ;; + *:ITS:*:*) + echo pdp10-unknown-its + exit ;; + SEI:*:*:SEIUX) + echo mips-sei-seiux"$UNAME_RELEASE" + exit ;; + *:DragonFly:*:*) + echo "$UNAME_MACHINE"-unknown-dragonfly"$(echo "$UNAME_RELEASE"|sed -e 's/[-(].*//')" + exit ;; + *:*VMS:*:*) + UNAME_MACHINE=$( (uname -p) 2>/dev/null) + case $UNAME_MACHINE in + A*) echo alpha-dec-vms ; exit ;; + I*) echo ia64-dec-vms ; exit ;; + V*) echo vax-dec-vms ; exit ;; + esac ;; + *:XENIX:*:SysV) + echo i386-pc-xenix + exit ;; + i*86:skyos:*:*) + echo "$UNAME_MACHINE"-pc-skyos"$(echo "$UNAME_RELEASE" | sed -e 's/ .*$//')" + exit ;; + i*86:rdos:*:*) + echo "$UNAME_MACHINE"-pc-rdos + exit ;; + *:AROS:*:*) + echo "$UNAME_MACHINE"-unknown-aros + exit ;; + x86_64:VMkernel:*:*) + echo "$UNAME_MACHINE"-unknown-esx + exit ;; + amd64:Isilon\ OneFS:*:*) + echo x86_64-unknown-onefs + exit ;; + *:Unleashed:*:*) + echo "$UNAME_MACHINE"-unknown-unleashed"$UNAME_RELEASE" + exit ;; +esac + +# No uname command or uname output not recognized. 
+set_cc_for_build +cat > "$dummy.c" < +#include +#endif +#if defined(ultrix) || defined(_ultrix) || defined(__ultrix) || defined(__ultrix__) +#if defined (vax) || defined (__vax) || defined (__vax__) || defined(mips) || defined(__mips) || defined(__mips__) || defined(MIPS) || defined(__MIPS__) +#include +#if defined(_SIZE_T_) || defined(SIGLOST) +#include +#endif +#endif +#endif +main () +{ +#if defined (sony) +#if defined (MIPSEB) + /* BFD wants "bsd" instead of "newsos". Perhaps BFD should be changed, + I don't know.... */ + printf ("mips-sony-bsd\n"); exit (0); +#else +#include + printf ("m68k-sony-newsos%s\n", +#ifdef NEWSOS4 + "4" +#else + "" +#endif + ); exit (0); +#endif +#endif + +#if defined (NeXT) +#if !defined (__ARCHITECTURE__) +#define __ARCHITECTURE__ "m68k" +#endif + int version; + version=$( (hostinfo | sed -n 's/.*NeXT Mach \([0-9]*\).*/\1/p') 2>/dev/null); + if (version < 4) + printf ("%s-next-nextstep%d\n", __ARCHITECTURE__, version); + else + printf ("%s-next-openstep%d\n", __ARCHITECTURE__, version); + exit (0); +#endif + +#if defined (MULTIMAX) || defined (n16) +#if defined (UMAXV) + printf ("ns32k-encore-sysv\n"); exit (0); +#else +#if defined (CMU) + printf ("ns32k-encore-mach\n"); exit (0); +#else + printf ("ns32k-encore-bsd\n"); exit (0); +#endif +#endif +#endif + +#if defined (__386BSD__) + printf ("i386-pc-bsd\n"); exit (0); +#endif + +#if defined (sequent) +#if defined (i386) + printf ("i386-sequent-dynix\n"); exit (0); +#endif +#if defined (ns32000) + printf ("ns32k-sequent-dynix\n"); exit (0); +#endif +#endif + +#if defined (_SEQUENT_) + struct utsname un; + + uname(&un); + if (strncmp(un.version, "V2", 2) == 0) { + printf ("i386-sequent-ptx2\n"); exit (0); + } + if (strncmp(un.version, "V1", 2) == 0) { /* XXX is V1 correct? */ + printf ("i386-sequent-ptx1\n"); exit (0); + } + printf ("i386-sequent-ptx\n"); exit (0); +#endif + +#if defined (vax) +#if !defined (ultrix) +#include +#if defined (BSD) +#if BSD == 43 + printf ("vax-dec-bsd4.3\n"); exit (0); +#else +#if BSD == 199006 + printf ("vax-dec-bsd4.3reno\n"); exit (0); +#else + printf ("vax-dec-bsd\n"); exit (0); +#endif +#endif +#else + printf ("vax-dec-bsd\n"); exit (0); +#endif +#else +#if defined(_SIZE_T_) || defined(SIGLOST) + struct utsname un; + uname (&un); + printf ("vax-dec-ultrix%s\n", un.release); exit (0); +#else + printf ("vax-dec-ultrix\n"); exit (0); +#endif +#endif +#endif +#if defined(ultrix) || defined(_ultrix) || defined(__ultrix) || defined(__ultrix__) +#if defined(mips) || defined(__mips) || defined(__mips__) || defined(MIPS) || defined(__MIPS__) +#if defined(_SIZE_T_) || defined(SIGLOST) + struct utsname *un; + uname (&un); + printf ("mips-dec-ultrix%s\n", un.release); exit (0); +#else + printf ("mips-dec-ultrix\n"); exit (0); +#endif +#endif +#endif + +#if defined (alliant) && defined (i860) + printf ("i860-alliant-bsd\n"); exit (0); +#endif + + exit (1); +} +EOF + +$CC_FOR_BUILD -o "$dummy" "$dummy.c" 2>/dev/null && SYSTEM_NAME=$($dummy) && + { echo "$SYSTEM_NAME"; exit; } + +# Apollos put the system type in the environment. +test -d /usr/apollo && { echo "$ISP-apollo-$SYSTYPE"; exit; } + +echo "$0: unable to guess system type" >&2 + +case $UNAME_MACHINE:$UNAME_SYSTEM in + mips:Linux | mips64:Linux) + # If we got here on MIPS GNU/Linux, output extra information. 
+ cat >&2 <&2 <&2 </dev/null || echo unknown) +uname -r = $( (uname -r) 2>/dev/null || echo unknown) +uname -s = $( (uname -s) 2>/dev/null || echo unknown) +uname -v = $( (uname -v) 2>/dev/null || echo unknown) + +/usr/bin/uname -p = $( (/usr/bin/uname -p) 2>/dev/null) +/bin/uname -X = $( (/bin/uname -X) 2>/dev/null) + +hostinfo = $( (hostinfo) 2>/dev/null) +/bin/universe = $( (/bin/universe) 2>/dev/null) +/usr/bin/arch -k = $( (/usr/bin/arch -k) 2>/dev/null) +/bin/arch = $( (/bin/arch) 2>/dev/null) +/usr/bin/oslevel = $( (/usr/bin/oslevel) 2>/dev/null) +/usr/convex/getsysinfo = $( (/usr/convex/getsysinfo) 2>/dev/null) + +UNAME_MACHINE = "$UNAME_MACHINE" +UNAME_RELEASE = "$UNAME_RELEASE" +UNAME_SYSTEM = "$UNAME_SYSTEM" +UNAME_VERSION = "$UNAME_VERSION" +EOF +fi + +exit 1 + +# Local variables: +# eval: (add-hook 'before-save-hook 'time-stamp) +# time-stamp-start: "timestamp='" +# time-stamp-format: "%:y-%02m-%02d" +# time-stamp-end: "'" +# End: diff --git a/third_party/sox/patch/config.sub b/third_party/sox/patch/config.sub new file mode 100644 index 0000000000000000000000000000000000000000..7384e9198b4051ce9e3e079d50d8adb5060b1771 --- /dev/null +++ b/third_party/sox/patch/config.sub @@ -0,0 +1,1864 @@ +#! /bin/sh +# Configuration validation subroutine script. +# Copyright 1992-2021 Free Software Foundation, Inc. + +timestamp='2021-04-30' + +# This file is free software; you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, see . +# +# As a special exception to the GNU General Public License, if you +# distribute this file as part of a program that contains a +# configuration script generated by Autoconf, you may include it under +# the same distribution terms that you use for the rest of that +# program. This Exception is an additional permission under section 7 +# of the GNU General Public License, version 3 ("GPLv3"). + + +# Please send patches to . +# +# Configuration subroutine to validate and canonicalize a configuration type. +# Supply the specified configuration type as an argument. +# If it is invalid, we print an error message on stderr and exit with code 1. +# Otherwise, we print the canonical config type on stdout and succeed. + +# You can get the latest version of this script from: +# https://git.savannah.gnu.org/cgit/config.git/plain/config.sub + +# This file is supposed to be the same for all GNU packages +# and recognize all the CPU types, system types and aliases +# that are meaningful with *any* GNU software. +# Each package is responsible for reporting which valid configurations +# it does not support. The user should be able to distinguish +# a failure to support a valid configuration from a meaningless +# configuration. + +# The goal of this file is to map all the various variations of a given +# machine specification into a single specification in the form: +# CPU_TYPE-MANUFACTURER-OPERATING_SYSTEM +# or in some cases, the newer four-part form: +# CPU_TYPE-MANUFACTURER-KERNEL-OPERATING_SYSTEM +# It is wrong to echo any other type of specification. 
+ +me=$(echo "$0" | sed -e 's,.*/,,') + +usage="\ +Usage: $0 [OPTION] CPU-MFR-OPSYS or ALIAS + +Canonicalize a configuration name. + +Options: + -h, --help print this help, then exit + -t, --time-stamp print date of last modification, then exit + -v, --version print version number, then exit + +Report bugs and patches to ." + +version="\ +GNU config.sub ($timestamp) + +Copyright 1992-2021 Free Software Foundation, Inc. + +This is free software; see the source for copying conditions. There is NO +warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE." + +help=" +Try \`$me --help' for more information." + +# Parse command line +while test $# -gt 0 ; do + case $1 in + --time-stamp | --time* | -t ) + echo "$timestamp" ; exit ;; + --version | -v ) + echo "$version" ; exit ;; + --help | --h* | -h ) + echo "$usage"; exit ;; + -- ) # Stop option processing + shift; break ;; + - ) # Use stdin as input. + break ;; + -* ) + echo "$me: invalid option $1$help" >&2 + exit 1 ;; + + *local*) + # First pass through any local machine types. + echo "$1" + exit ;; + + * ) + break ;; + esac +done + +case $# in + 0) echo "$me: missing argument$help" >&2 + exit 1;; + 1) ;; + *) echo "$me: too many arguments$help" >&2 + exit 1;; +esac + +# Split fields of configuration type +# shellcheck disable=SC2162 +IFS="-" read field1 field2 field3 field4 <&2 + exit 1 + ;; + *-*-*-*) + basic_machine=$field1-$field2 + basic_os=$field3-$field4 + ;; + *-*-*) + # Ambiguous whether COMPANY is present, or skipped and KERNEL-OS is two + # parts + maybe_os=$field2-$field3 + case $maybe_os in + nto-qnx* | linux-* | uclinux-uclibc* \ + | uclinux-gnu* | kfreebsd*-gnu* | knetbsd*-gnu* | netbsd*-gnu* \ + | netbsd*-eabi* | kopensolaris*-gnu* | cloudabi*-eabi* \ + | storm-chaos* | os2-emx* | rtmk-nova*) + basic_machine=$field1 + basic_os=$maybe_os + ;; + android-linux) + basic_machine=$field1-unknown + basic_os=linux-android + ;; + *) + basic_machine=$field1-$field2 + basic_os=$field3 + ;; + esac + ;; + *-*) + # A lone config we happen to match not fitting any pattern + case $field1-$field2 in + decstation-3100) + basic_machine=mips-dec + basic_os= + ;; + *-*) + # Second component is usually, but not always the OS + case $field2 in + # Prevent following clause from handling this valid os + sun*os*) + basic_machine=$field1 + basic_os=$field2 + ;; + # Manufacturers + dec* | mips* | sequent* | encore* | pc533* | sgi* | sony* \ + | att* | 7300* | 3300* | delta* | motorola* | sun[234]* \ + | unicom* | ibm* | next | hp | isi* | apollo | altos* \ + | convergent* | ncr* | news | 32* | 3600* | 3100* \ + | hitachi* | c[123]* | convex* | sun | crds | omron* | dg \ + | ultra | tti* | harris | dolphin | highlevel | gould \ + | cbm | ns | masscomp | apple | axis | knuth | cray \ + | microblaze* | sim | cisco \ + | oki | wec | wrs | winbond) + basic_machine=$field1-$field2 + basic_os= + ;; + *) + basic_machine=$field1 + basic_os=$field2 + ;; + esac + ;; + esac + ;; + *) + # Convert single-component short-hands not valid as part of + # multi-component configurations. 
+ case $field1 in + 386bsd) + basic_machine=i386-pc + basic_os=bsd + ;; + a29khif) + basic_machine=a29k-amd + basic_os=udi + ;; + adobe68k) + basic_machine=m68010-adobe + basic_os=scout + ;; + alliant) + basic_machine=fx80-alliant + basic_os= + ;; + altos | altos3068) + basic_machine=m68k-altos + basic_os= + ;; + am29k) + basic_machine=a29k-none + basic_os=bsd + ;; + amdahl) + basic_machine=580-amdahl + basic_os=sysv + ;; + amiga) + basic_machine=m68k-unknown + basic_os= + ;; + amigaos | amigados) + basic_machine=m68k-unknown + basic_os=amigaos + ;; + amigaunix | amix) + basic_machine=m68k-unknown + basic_os=sysv4 + ;; + apollo68) + basic_machine=m68k-apollo + basic_os=sysv + ;; + apollo68bsd) + basic_machine=m68k-apollo + basic_os=bsd + ;; + aros) + basic_machine=i386-pc + basic_os=aros + ;; + aux) + basic_machine=m68k-apple + basic_os=aux + ;; + balance) + basic_machine=ns32k-sequent + basic_os=dynix + ;; + blackfin) + basic_machine=bfin-unknown + basic_os=linux + ;; + cegcc) + basic_machine=arm-unknown + basic_os=cegcc + ;; + convex-c1) + basic_machine=c1-convex + basic_os=bsd + ;; + convex-c2) + basic_machine=c2-convex + basic_os=bsd + ;; + convex-c32) + basic_machine=c32-convex + basic_os=bsd + ;; + convex-c34) + basic_machine=c34-convex + basic_os=bsd + ;; + convex-c38) + basic_machine=c38-convex + basic_os=bsd + ;; + cray) + basic_machine=j90-cray + basic_os=unicos + ;; + crds | unos) + basic_machine=m68k-crds + basic_os= + ;; + da30) + basic_machine=m68k-da30 + basic_os= + ;; + decstation | pmax | pmin | dec3100 | decstatn) + basic_machine=mips-dec + basic_os= + ;; + delta88) + basic_machine=m88k-motorola + basic_os=sysv3 + ;; + dicos) + basic_machine=i686-pc + basic_os=dicos + ;; + djgpp) + basic_machine=i586-pc + basic_os=msdosdjgpp + ;; + ebmon29k) + basic_machine=a29k-amd + basic_os=ebmon + ;; + es1800 | OSE68k | ose68k | ose | OSE) + basic_machine=m68k-ericsson + basic_os=ose + ;; + gmicro) + basic_machine=tron-gmicro + basic_os=sysv + ;; + go32) + basic_machine=i386-pc + basic_os=go32 + ;; + h8300hms) + basic_machine=h8300-hitachi + basic_os=hms + ;; + h8300xray) + basic_machine=h8300-hitachi + basic_os=xray + ;; + h8500hms) + basic_machine=h8500-hitachi + basic_os=hms + ;; + harris) + basic_machine=m88k-harris + basic_os=sysv3 + ;; + hp300 | hp300hpux) + basic_machine=m68k-hp + basic_os=hpux + ;; + hp300bsd) + basic_machine=m68k-hp + basic_os=bsd + ;; + hppaosf) + basic_machine=hppa1.1-hp + basic_os=osf + ;; + hppro) + basic_machine=hppa1.1-hp + basic_os=proelf + ;; + i386mach) + basic_machine=i386-mach + basic_os=mach + ;; + isi68 | isi) + basic_machine=m68k-isi + basic_os=sysv + ;; + m68knommu) + basic_machine=m68k-unknown + basic_os=linux + ;; + magnum | m3230) + basic_machine=mips-mips + basic_os=sysv + ;; + merlin) + basic_machine=ns32k-utek + basic_os=sysv + ;; + mingw64) + basic_machine=x86_64-pc + basic_os=mingw64 + ;; + mingw32) + basic_machine=i686-pc + basic_os=mingw32 + ;; + mingw32ce) + basic_machine=arm-unknown + basic_os=mingw32ce + ;; + monitor) + basic_machine=m68k-rom68k + basic_os=coff + ;; + morphos) + basic_machine=powerpc-unknown + basic_os=morphos + ;; + moxiebox) + basic_machine=moxie-unknown + basic_os=moxiebox + ;; + msdos) + basic_machine=i386-pc + basic_os=msdos + ;; + msys) + basic_machine=i686-pc + basic_os=msys + ;; + mvs) + basic_machine=i370-ibm + basic_os=mvs + ;; + nacl) + basic_machine=le32-unknown + basic_os=nacl + ;; + ncr3000) + basic_machine=i486-ncr + basic_os=sysv4 + ;; + netbsd386) + basic_machine=i386-pc + basic_os=netbsd + ;; + 
netwinder) + basic_machine=armv4l-rebel + basic_os=linux + ;; + news | news700 | news800 | news900) + basic_machine=m68k-sony + basic_os=newsos + ;; + news1000) + basic_machine=m68030-sony + basic_os=newsos + ;; + necv70) + basic_machine=v70-nec + basic_os=sysv + ;; + nh3000) + basic_machine=m68k-harris + basic_os=cxux + ;; + nh[45]000) + basic_machine=m88k-harris + basic_os=cxux + ;; + nindy960) + basic_machine=i960-intel + basic_os=nindy + ;; + mon960) + basic_machine=i960-intel + basic_os=mon960 + ;; + nonstopux) + basic_machine=mips-compaq + basic_os=nonstopux + ;; + os400) + basic_machine=powerpc-ibm + basic_os=os400 + ;; + OSE68000 | ose68000) + basic_machine=m68000-ericsson + basic_os=ose + ;; + os68k) + basic_machine=m68k-none + basic_os=os68k + ;; + paragon) + basic_machine=i860-intel + basic_os=osf + ;; + parisc) + basic_machine=hppa-unknown + basic_os=linux + ;; + psp) + basic_machine=mipsallegrexel-sony + basic_os=psp + ;; + pw32) + basic_machine=i586-unknown + basic_os=pw32 + ;; + rdos | rdos64) + basic_machine=x86_64-pc + basic_os=rdos + ;; + rdos32) + basic_machine=i386-pc + basic_os=rdos + ;; + rom68k) + basic_machine=m68k-rom68k + basic_os=coff + ;; + sa29200) + basic_machine=a29k-amd + basic_os=udi + ;; + sei) + basic_machine=mips-sei + basic_os=seiux + ;; + sequent) + basic_machine=i386-sequent + basic_os= + ;; + sps7) + basic_machine=m68k-bull + basic_os=sysv2 + ;; + st2000) + basic_machine=m68k-tandem + basic_os= + ;; + stratus) + basic_machine=i860-stratus + basic_os=sysv4 + ;; + sun2) + basic_machine=m68000-sun + basic_os= + ;; + sun2os3) + basic_machine=m68000-sun + basic_os=sunos3 + ;; + sun2os4) + basic_machine=m68000-sun + basic_os=sunos4 + ;; + sun3) + basic_machine=m68k-sun + basic_os= + ;; + sun3os3) + basic_machine=m68k-sun + basic_os=sunos3 + ;; + sun3os4) + basic_machine=m68k-sun + basic_os=sunos4 + ;; + sun4) + basic_machine=sparc-sun + basic_os= + ;; + sun4os3) + basic_machine=sparc-sun + basic_os=sunos3 + ;; + sun4os4) + basic_machine=sparc-sun + basic_os=sunos4 + ;; + sun4sol2) + basic_machine=sparc-sun + basic_os=solaris2 + ;; + sun386 | sun386i | roadrunner) + basic_machine=i386-sun + basic_os= + ;; + sv1) + basic_machine=sv1-cray + basic_os=unicos + ;; + symmetry) + basic_machine=i386-sequent + basic_os=dynix + ;; + t3e) + basic_machine=alphaev5-cray + basic_os=unicos + ;; + t90) + basic_machine=t90-cray + basic_os=unicos + ;; + toad1) + basic_machine=pdp10-xkl + basic_os=tops20 + ;; + tpf) + basic_machine=s390x-ibm + basic_os=tpf + ;; + udi29k) + basic_machine=a29k-amd + basic_os=udi + ;; + ultra3) + basic_machine=a29k-nyu + basic_os=sym1 + ;; + v810 | necv810) + basic_machine=v810-nec + basic_os=none + ;; + vaxv) + basic_machine=vax-dec + basic_os=sysv + ;; + vms) + basic_machine=vax-dec + basic_os=vms + ;; + vsta) + basic_machine=i386-pc + basic_os=vsta + ;; + vxworks960) + basic_machine=i960-wrs + basic_os=vxworks + ;; + vxworks68) + basic_machine=m68k-wrs + basic_os=vxworks + ;; + vxworks29k) + basic_machine=a29k-wrs + basic_os=vxworks + ;; + xbox) + basic_machine=i686-pc + basic_os=mingw32 + ;; + ymp) + basic_machine=ymp-cray + basic_os=unicos + ;; + *) + basic_machine=$1 + basic_os= + ;; + esac + ;; +esac + +# Decode 1-component or ad-hoc basic machines +case $basic_machine in + # Here we handle the default manufacturer of certain CPU types. It is in + # some cases the only manufacturer, in others, it is the most popular. 
+ w89k) + cpu=hppa1.1 + vendor=winbond + ;; + op50n) + cpu=hppa1.1 + vendor=oki + ;; + op60c) + cpu=hppa1.1 + vendor=oki + ;; + ibm*) + cpu=i370 + vendor=ibm + ;; + orion105) + cpu=clipper + vendor=highlevel + ;; + mac | mpw | mac-mpw) + cpu=m68k + vendor=apple + ;; + pmac | pmac-mpw) + cpu=powerpc + vendor=apple + ;; + + # Recognize the various machine names and aliases which stand + # for a CPU type and a company and sometimes even an OS. + 3b1 | 7300 | 7300-att | att-7300 | pc7300 | safari | unixpc) + cpu=m68000 + vendor=att + ;; + 3b*) + cpu=we32k + vendor=att + ;; + bluegene*) + cpu=powerpc + vendor=ibm + basic_os=cnk + ;; + decsystem10* | dec10*) + cpu=pdp10 + vendor=dec + basic_os=tops10 + ;; + decsystem20* | dec20*) + cpu=pdp10 + vendor=dec + basic_os=tops20 + ;; + delta | 3300 | motorola-3300 | motorola-delta \ + | 3300-motorola | delta-motorola) + cpu=m68k + vendor=motorola + ;; + dpx2*) + cpu=m68k + vendor=bull + basic_os=sysv3 + ;; + encore | umax | mmax) + cpu=ns32k + vendor=encore + ;; + elxsi) + cpu=elxsi + vendor=elxsi + basic_os=${basic_os:-bsd} + ;; + fx2800) + cpu=i860 + vendor=alliant + ;; + genix) + cpu=ns32k + vendor=ns + ;; + h3050r* | hiux*) + cpu=hppa1.1 + vendor=hitachi + basic_os=hiuxwe2 + ;; + hp3k9[0-9][0-9] | hp9[0-9][0-9]) + cpu=hppa1.0 + vendor=hp + ;; + hp9k2[0-9][0-9] | hp9k31[0-9]) + cpu=m68000 + vendor=hp + ;; + hp9k3[2-9][0-9]) + cpu=m68k + vendor=hp + ;; + hp9k6[0-9][0-9] | hp6[0-9][0-9]) + cpu=hppa1.0 + vendor=hp + ;; + hp9k7[0-79][0-9] | hp7[0-79][0-9]) + cpu=hppa1.1 + vendor=hp + ;; + hp9k78[0-9] | hp78[0-9]) + # FIXME: really hppa2.0-hp + cpu=hppa1.1 + vendor=hp + ;; + hp9k8[67]1 | hp8[67]1 | hp9k80[24] | hp80[24] | hp9k8[78]9 | hp8[78]9 | hp9k893 | hp893) + # FIXME: really hppa2.0-hp + cpu=hppa1.1 + vendor=hp + ;; + hp9k8[0-9][13679] | hp8[0-9][13679]) + cpu=hppa1.1 + vendor=hp + ;; + hp9k8[0-9][0-9] | hp8[0-9][0-9]) + cpu=hppa1.0 + vendor=hp + ;; + i*86v32) + cpu=$(echo "$1" | sed -e 's/86.*/86/') + vendor=pc + basic_os=sysv32 + ;; + i*86v4*) + cpu=$(echo "$1" | sed -e 's/86.*/86/') + vendor=pc + basic_os=sysv4 + ;; + i*86v) + cpu=$(echo "$1" | sed -e 's/86.*/86/') + vendor=pc + basic_os=sysv + ;; + i*86sol2) + cpu=$(echo "$1" | sed -e 's/86.*/86/') + vendor=pc + basic_os=solaris2 + ;; + j90 | j90-cray) + cpu=j90 + vendor=cray + basic_os=${basic_os:-unicos} + ;; + iris | iris4d) + cpu=mips + vendor=sgi + case $basic_os in + irix*) + ;; + *) + basic_os=irix4 + ;; + esac + ;; + miniframe) + cpu=m68000 + vendor=convergent + ;; + *mint | mint[0-9]* | *MiNT | *MiNT[0-9]*) + cpu=m68k + vendor=atari + basic_os=mint + ;; + news-3600 | risc-news) + cpu=mips + vendor=sony + basic_os=newsos + ;; + next | m*-next) + cpu=m68k + vendor=next + case $basic_os in + openstep*) + ;; + nextstep*) + ;; + ns2*) + basic_os=nextstep2 + ;; + *) + basic_os=nextstep3 + ;; + esac + ;; + np1) + cpu=np1 + vendor=gould + ;; + op50n-* | op60c-*) + cpu=hppa1.1 + vendor=oki + basic_os=proelf + ;; + pa-hitachi) + cpu=hppa1.1 + vendor=hitachi + basic_os=hiuxwe2 + ;; + pbd) + cpu=sparc + vendor=tti + ;; + pbb) + cpu=m68k + vendor=tti + ;; + pc532) + cpu=ns32k + vendor=pc532 + ;; + pn) + cpu=pn + vendor=gould + ;; + power) + cpu=power + vendor=ibm + ;; + ps2) + cpu=i386 + vendor=ibm + ;; + rm[46]00) + cpu=mips + vendor=siemens + ;; + rtpc | rtpc-*) + cpu=romp + vendor=ibm + ;; + sde) + cpu=mipsisa32 + vendor=sde + basic_os=${basic_os:-elf} + ;; + simso-wrs) + cpu=sparclite + vendor=wrs + basic_os=vxworks + ;; + tower | tower-32) + cpu=m68k + vendor=ncr + ;; + vpp*|vx|vx-*) + cpu=f301 
+ vendor=fujitsu + ;; + w65) + cpu=w65 + vendor=wdc + ;; + w89k-*) + cpu=hppa1.1 + vendor=winbond + basic_os=proelf + ;; + none) + cpu=none + vendor=none + ;; + leon|leon[3-9]) + cpu=sparc + vendor=$basic_machine + ;; + leon-*|leon[3-9]-*) + cpu=sparc + vendor=$(echo "$basic_machine" | sed 's/-.*//') + ;; + + *-*) + # shellcheck disable=SC2162 + IFS="-" read cpu vendor <&2 + exit 1 + ;; + esac + ;; +esac + +# Here we canonicalize certain aliases for manufacturers. +case $vendor in + digital*) + vendor=dec + ;; + commodore*) + vendor=cbm + ;; + *) + ;; +esac + +# Decode manufacturer-specific aliases for certain operating systems. + +if test x$basic_os != x +then + +# First recognize some ad-hoc caes, or perhaps split kernel-os, or else just +# set os. +case $basic_os in + gnu/linux*) + kernel=linux + os=$(echo $basic_os | sed -e 's|gnu/linux|gnu|') + ;; + os2-emx) + kernel=os2 + os=$(echo $basic_os | sed -e 's|os2-emx|emx|') + ;; + nto-qnx*) + kernel=nto + os=$(echo $basic_os | sed -e 's|nto-qnx|qnx|') + ;; + *-*) + # shellcheck disable=SC2162 + IFS="-" read kernel os <&2 + exit 1 + ;; +esac + +# As a final step for OS-related things, validate the OS-kernel combination +# (given a valid OS), if there is a kernel. +case $kernel-$os in + linux-gnu* | linux-dietlibc* | linux-android* | linux-newlib* | linux-musl* | linux-uclibc* ) + ;; + uclinux-uclibc* ) + ;; + -dietlibc* | -newlib* | -musl* | -uclibc* ) + # These are just libc implementations, not actual OSes, and thus + # require a kernel. + echo "Invalid configuration \`$1': libc \`$os' needs explicit kernel." 1>&2 + exit 1 + ;; + kfreebsd*-gnu* | kopensolaris*-gnu*) + ;; + vxworks-simlinux | vxworks-simwindows | vxworks-spe) + ;; + nto-qnx*) + ;; + os2-emx) + ;; + *-eabi* | *-gnueabi*) + ;; + -*) + # Blank kernel with real OS is always fine. + ;; + *-*) + echo "Invalid configuration \`$1': Kernel \`$kernel' not known to work with OS \`$os'." 1>&2 + exit 1 + ;; +esac + +# Here we handle the case where we know the os, and the CPU type, but not the +# manufacturer. We pick the logical manufacturer. 
+case $vendor in + unknown) + case $cpu-$os in + *-riscix*) + vendor=acorn + ;; + *-sunos*) + vendor=sun + ;; + *-cnk* | *-aix*) + vendor=ibm + ;; + *-beos*) + vendor=be + ;; + *-hpux*) + vendor=hp + ;; + *-mpeix*) + vendor=hp + ;; + *-hiux*) + vendor=hitachi + ;; + *-unos*) + vendor=crds + ;; + *-dgux*) + vendor=dg + ;; + *-luna*) + vendor=omron + ;; + *-genix*) + vendor=ns + ;; + *-clix*) + vendor=intergraph + ;; + *-mvs* | *-opened*) + vendor=ibm + ;; + *-os400*) + vendor=ibm + ;; + s390-* | s390x-*) + vendor=ibm + ;; + *-ptx*) + vendor=sequent + ;; + *-tpf*) + vendor=ibm + ;; + *-vxsim* | *-vxworks* | *-windiss*) + vendor=wrs + ;; + *-aux*) + vendor=apple + ;; + *-hms*) + vendor=hitachi + ;; + *-mpw* | *-macos*) + vendor=apple + ;; + *-*mint | *-mint[0-9]* | *-*MiNT | *-MiNT[0-9]*) + vendor=atari + ;; + *-vos*) + vendor=stratus + ;; + esac + ;; +esac + +echo "$cpu-$vendor-${kernel:+$kernel-}$os" +exit + +# Local variables: +# eval: (add-hook 'before-save-hook 'time-stamp) +# time-stamp-start: "timestamp='" +# time-stamp-format: "%:y-%02m-%02d" +# time-stamp-end: "'" +# End: diff --git a/third_party/sox/patch/libmad.patch b/third_party/sox/patch/libmad.patch new file mode 100644 index 0000000000000000000000000000000000000000..a805787831f48ecde0eebc9468440ee179f55c75 --- /dev/null +++ b/third_party/sox/patch/libmad.patch @@ -0,0 +1,86 @@ +See the followings for the origin of this patch +http://www.linuxfromscratch.org/blfs/view/svn/multimedia/libmad.html +http://www.linuxfromscratch.org/patches/blfs/svn/libmad-0.15.1b-fixes-1.patch +--- src/libmad/configure 2004-02-05 09:34:07.000000000 +0000 ++++ src/libmad/configure.new 2020-06-30 21:10:28.528018931 +0000 +@@ -19083,71 +19083,7 @@ + + if test "$GCC" = yes + then +- if test -z "$arch" +- then +- case "$host" in +- i386-*) ;; +- i?86-*) arch="-march=i486" ;; +- arm*-empeg-*) arch="-march=armv4 -mtune=strongarm1100" ;; +- armv4*-*) arch="-march=armv4 -mtune=strongarm" ;; +- powerpc-*) ;; +- mips*-agenda-*) arch="-mcpu=vr4100" ;; +- mips*-luxsonor-*) arch="-mips1 -mcpu=r3000 -Wa,-m4010" ;; +- esac +- fi +- +- case "$optimize" in +- -O|"-O "*) +- optimize="-O" +- optimize="$optimize -fforce-mem" +- optimize="$optimize -fforce-addr" +- : #x optimize="$optimize -finline-functions" +- : #- optimize="$optimize -fstrength-reduce" +- optimize="$optimize -fthread-jumps" +- optimize="$optimize -fcse-follow-jumps" +- optimize="$optimize -fcse-skip-blocks" +- : #x optimize="$optimize -frerun-cse-after-loop" +- : #x optimize="$optimize -frerun-loop-opt" +- : #x optimize="$optimize -fgcse" +- optimize="$optimize -fexpensive-optimizations" +- optimize="$optimize -fregmove" +- : #* optimize="$optimize -fdelayed-branch" +- : #x optimize="$optimize -fschedule-insns" +- optimize="$optimize -fschedule-insns2" +- : #? optimize="$optimize -ffunction-sections" +- : #? optimize="$optimize -fcaller-saves" +- : #> optimize="$optimize -funroll-loops" +- : #> optimize="$optimize -funroll-all-loops" +- : #x optimize="$optimize -fmove-all-movables" +- : #x optimize="$optimize -freduce-all-givs" +- : #? 
optimize="$optimize -fstrict-aliasing" +- : #* optimize="$optimize -fstructure-noalias" +- +- case "$host" in +- arm*-*) +- optimize="$optimize -fstrength-reduce" +- ;; +- mips*-*) +- optimize="$optimize -fstrength-reduce" +- optimize="$optimize -finline-functions" +- ;; +- i?86-*) +- optimize="$optimize -fstrength-reduce" +- ;; +- powerpc-apple-*) +- # this triggers an internal compiler error with gcc2 +- : #optimize="$optimize -fstrength-reduce" +- +- # this is really only beneficial with gcc3 +- : #optimize="$optimize -finline-functions" +- ;; +- *) +- # this sometimes provokes bugs in gcc 2.95.2 +- : #optimize="$optimize -fstrength-reduce" +- ;; +- esac +- ;; +- esac ++ optimize="-O2" + fi + + case "$host" in +@@ -21497,6 +21433,7 @@ + then + case "$host" in + i?86-*) FPM="INTEL" ;; ++ x86_64*) FPM="64BIT" ;; + arm*-*) FPM="ARM" ;; + mips*-*) FPM="MIPS" ;; + sparc*-*) FPM="SPARC" ;; diff --git a/third_party/sox/patch/sox.patch b/third_party/sox/patch/sox.patch new file mode 100644 index 0000000000000000000000000000000000000000..fe8df945c078045f58dc661a5a02d8c5f38599ca --- /dev/null +++ b/third_party/sox/patch/sox.patch @@ -0,0 +1,16 @@ +See https://github.com/pytorch/audio/pull/1297 +diff -ru sox/src/formats.c sox/src/formats.c +--- sox/src/formats.c 2014-10-26 19:55:50.000000000 -0700 ++++ sox/src/formats.c 2021-02-22 16:01:02.833144070 -0800 +@@ -333,6 +333,10 @@ + assert(ft); + if (!ft->fp) + return sox_false; +- fstat(fileno((FILE*)ft->fp), &st); ++ int fd = fileno((FILE*)ft->fp); ++ if (fd < 0) ++ return sox_false; ++ if (fstat(fd, &st) < 0) ++ return sox_false; + return ((st.st_mode & S_IFMT) == S_IFREG); + } diff --git a/torchaudio/__init__.py b/torchaudio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d492dc7fe2790ab74359fb0caebca7aeba7cfc42 --- /dev/null +++ b/torchaudio/__init__.py @@ -0,0 +1,38 @@ +from torchaudio import _extension # noqa: F401 +from torchaudio import ( + compliance, + datasets, + functional, + models, + pipelines, + kaldi_io, + utils, + sox_effects, + transforms, +) + +from torchaudio.backend import ( + list_audio_backends, + get_audio_backend, + set_audio_backend, +) + +try: + from .version import __version__, git_version # noqa: F401 +except ImportError: + pass + +__all__ = [ + 'compliance', + 'datasets', + 'functional', + 'models', + 'pipelines', + 'kaldi_io', + 'utils', + 'sox_effects', + 'transforms', + 'list_audio_backends', + 'get_audio_backend', + 'set_audio_backend', +] diff --git a/torchaudio/_extension.py b/torchaudio/_extension.py new file mode 100644 index 0000000000000000000000000000000000000000..6bb217c6844c10b97f9830c0234da6aba8a4148a --- /dev/null +++ b/torchaudio/_extension.py @@ -0,0 +1,27 @@ +import os +import warnings +from pathlib import Path + +import torch +from torchaudio._internal import module_utils as _mod_utils # noqa: F401 + + +def _init_extension(): + if not _mod_utils.is_module_available('torchaudio._torchaudio'): + warnings.warn('torchaudio C++ extension is not available.') + return + + suffix = 'pyd' if os.name == 'nt' else 'so' + path = Path(__file__).parent / 'lib' / f'libtorchaudio.{suffix}' + # In case `torchaudio` is deployed with `pex` format, this file does not exist. + # In this case, we expect that `libtorchaudio` is available somewhere + # in the search path of dynamic loading mechanism, and importing `_torchaudio`, + # which depends on `libtorchaudio` and dynamic loader will handle it for us. 
+ if path.exists(): + torch.ops.load_library(path) + torch.classes.load_library(path) + # This import is for initializing the methods registered via PyBind11 + from torchaudio import _torchaudio # noqa + + +_init_extension() diff --git a/torchaudio/_internal/__init__.py b/torchaudio/_internal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/torchaudio/_internal/module_utils.py b/torchaudio/_internal/module_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eecc7cf6d5fe9bcaece5fe60a2e134fd18456ce1 --- /dev/null +++ b/torchaudio/_internal/module_utils.py @@ -0,0 +1,125 @@ +import warnings +import importlib.util +from typing import Optional +from functools import wraps + +import torch + + +def is_module_available(*modules: str) -> bool: + r"""Returns if a top-level module with :attr:`name` exists *without** + importing it. This is generally safer than try-catch block around a + `import X`. It avoids third party libraries breaking assumptions of some of + our tests, e.g., setting multiprocessing start method when imported + (see librosa/#747, torchvision/#544). + """ + return all(importlib.util.find_spec(m) is not None for m in modules) + + +def requires_module(*modules: str): + """Decorate function to give error message if invoked without required optional modules. + + This decorator is to give better error message to users rather + than raising ``NameError: name 'module' is not defined`` at random places. + """ + missing = [m for m in modules if not is_module_available(m)] + + if not missing: + # fall through. If all the modules are available, no need to decorate + def decorator(func): + return func + else: + req = f'module: {missing[0]}' if len(missing) == 1 else f'modules: {missing}' + + def decorator(func): + @wraps(func) + def wrapped(*args, **kwargs): + raise RuntimeError(f'{func.__module__}.{func.__name__} requires {req}') + return wrapped + return decorator + + +def deprecated(direction: str, version: Optional[str] = None): + """Decorator to add deprecation message + + Args: + direction (str): Migration steps to be given to users. + version (str or int): The version when the object will be removed + """ + def decorator(func): + + @wraps(func) + def wrapped(*args, **kwargs): + message = ( + f'{func.__module__}.{func.__name__} has been deprecated ' + f'and will be removed from {"future" if version is None else version} release. ' + f'{direction}') + warnings.warn(message, stacklevel=2) + return func(*args, **kwargs) + return wrapped + return decorator + + +def is_kaldi_available(): + return is_module_available('torchaudio._torchaudio') and torch.ops.torchaudio.is_kaldi_available() + + +def requires_kaldi(): + if is_kaldi_available(): + def decorator(func): + return func + else: + def decorator(func): + @wraps(func) + def wrapped(*args, **kwargs): + raise RuntimeError(f'{func.__module__}.{func.__name__} requires kaldi') + return wrapped + return decorator + + +def _check_soundfile_importable(): + if not is_module_available('soundfile'): + return False + try: + import soundfile # noqa: F401 + return True + except Exception: + warnings.warn("Failed to import soundfile. 
'soundfile' backend is not available.") + return False + + +_is_soundfile_importable = _check_soundfile_importable() + + +def is_soundfile_available(): + return _is_soundfile_importable + + +def requires_soundfile(): + if is_soundfile_available(): + def decorator(func): + return func + else: + def decorator(func): + @wraps(func) + def wrapped(*args, **kwargs): + raise RuntimeError(f'{func.__module__}.{func.__name__} requires soundfile') + return wrapped + return decorator + + +def is_sox_available(): + return is_module_available('torchaudio._torchaudio') and torch.ops.torchaudio.is_sox_available() + + +def requires_sox(): + if is_sox_available(): + def decorator(func): + return func + else: + def decorator(func): + @wraps(func) + def wrapped(*args, **kwargs): + raise RuntimeError(f'{func.__module__}.{func.__name__} requires sox') + return wrapped + return decorator diff --git a/torchaudio/backend/__init__.py b/torchaudio/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c3fdf0b439811ec524d02d991b6272b7c07ae01c --- /dev/null +++ b/torchaudio/backend/__init__.py @@ -0,0 +1,10 @@ +# flake8: noqa +from . import utils +from .utils import ( + list_audio_backends, + get_audio_backend, + set_audio_backend, +) + + +utils._init_audio_backend() diff --git a/torchaudio/backend/common.py b/torchaudio/backend/common.py new file mode 100644 index 0000000000000000000000000000000000000000..f944b23d74bbb34e354ab628fffc7238c49e0b2a --- /dev/null +++ b/torchaudio/backend/common.py @@ -0,0 +1,51 @@ +class AudioMetaData: + """Return type of ``torchaudio.info`` function. + + This class is used by :ref:`"sox_io" backend` and + :ref:`"soundfile" backend with the new interface`. + + :ivar int sample_rate: Sample rate + :ivar int num_frames: The number of frames + :ivar int num_channels: The number of channels + :ivar int bits_per_sample: The number of bits per sample. This is 0 for lossy formats, + or when it cannot be accurately inferred. 
+ :ivar str encoding: Audio encoding + The values encoding can take are one of the following: + + * ``PCM_S``: Signed integer linear PCM + * ``PCM_U``: Unsigned integer linear PCM + * ``PCM_F``: Floating point linear PCM + * ``FLAC``: Flac, Free Lossless Audio Codec + * ``ULAW``: Mu-law + * ``ALAW``: A-law + * ``MP3`` : MP3, MPEG-1 Audio Layer III + * ``VORBIS``: OGG Vorbis + * ``AMR_WB``: Adaptive Multi-Rate + * ``AMR_NB``: Adaptive Multi-Rate Wideband + * ``OPUS``: Opus + * ``UNKNOWN`` : None of above + """ + def __init__( + self, + sample_rate: int, + num_frames: int, + num_channels: int, + bits_per_sample: int, + encoding: str, + ): + self.sample_rate = sample_rate + self.num_frames = num_frames + self.num_channels = num_channels + self.bits_per_sample = bits_per_sample + self.encoding = encoding + + def __str__(self): + return ( + f"AudioMetaData(" + f"sample_rate={self.sample_rate}, " + f"num_frames={self.num_frames}, " + f"num_channels={self.num_channels}, " + f"bits_per_sample={self.bits_per_sample}, " + f"encoding={self.encoding}" + f")" + ) diff --git a/torchaudio/backend/no_backend.py b/torchaudio/backend/no_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..7160c89f69b0672fa0a044260275bd208eb067f5 --- /dev/null +++ b/torchaudio/backend/no_backend.py @@ -0,0 +1,22 @@ +from pathlib import Path +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor + + +def load(filepath: Union[str, Path], + out: Optional[Tensor] = None, + normalization: Union[bool, float, Callable] = True, + channels_first: bool = True, + num_frames: int = 0, + offset: int = 0, + filetype: Optional[str] = None) -> Tuple[Tensor, int]: + raise RuntimeError('No audio I/O backend is available.') + + +def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None: + raise RuntimeError('No audio I/O backend is available.') + + +def info(filepath: str) -> None: + raise RuntimeError('No audio I/O backend is available.') diff --git a/torchaudio/backend/soundfile_backend.py b/torchaudio/backend/soundfile_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..8b0b27c5cd5463fe28a651cd694f846e9549a707 --- /dev/null +++ b/torchaudio/backend/soundfile_backend.py @@ -0,0 +1,433 @@ +"""The new soundfile backend which will become default in 0.8.0 onward""" +from typing import Tuple, Optional +import warnings + +import torch +from torchaudio._internal import module_utils as _mod_utils +from .common import AudioMetaData + + +if _mod_utils.is_soundfile_available(): + import soundfile + +# Mapping from soundfile subtype to number of bits per sample. +# This is mostly heuristical and the value is set to 0 when it is irrelevant +# (lossy formats) or when it can't be inferred. +# For ADPCM (and G72X) subtypes, it's hard to infer the bit depth because it's not part of the standard: +# According to https://en.wikipedia.org/wiki/Adaptive_differential_pulse-code_modulation#In_telephony, +# the default seems to be 8 bits but it can be compressed further to 4 bits. 
+# The dict is inspired from +# https://github.com/bastibe/python-soundfile/blob/744efb4b01abc72498a96b09115b42a4cabd85e4/soundfile.py#L66-L94 +_SUBTYPE_TO_BITS_PER_SAMPLE = { + 'PCM_S8': 8, # Signed 8 bit data + 'PCM_16': 16, # Signed 16 bit data + 'PCM_24': 24, # Signed 24 bit data + 'PCM_32': 32, # Signed 32 bit data + 'PCM_U8': 8, # Unsigned 8 bit data (WAV and RAW only) + 'FLOAT': 32, # 32 bit float data + 'DOUBLE': 64, # 64 bit float data + 'ULAW': 8, # U-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types + 'ALAW': 8, # A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types + 'IMA_ADPCM': 0, # IMA ADPCM. + 'MS_ADPCM': 0, # Microsoft ADPCM. + 'GSM610': 0, # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate) + 'VOX_ADPCM': 0, # OKI / Dialogix ADPCM + 'G721_32': 0, # 32kbs G721 ADPCM encoding. + 'G723_24': 0, # 24kbs G723 ADPCM encoding. + 'G723_40': 0, # 40kbs G723 ADPCM encoding. + 'DWVW_12': 12, # 12 bit Delta Width Variable Word encoding. + 'DWVW_16': 16, # 16 bit Delta Width Variable Word encoding. + 'DWVW_24': 24, # 24 bit Delta Width Variable Word encoding. + 'DWVW_N': 0, # N bit Delta Width Variable Word encoding. + 'DPCM_8': 8, # 8 bit differential PCM (XI only) + 'DPCM_16': 16, # 16 bit differential PCM (XI only) + 'VORBIS': 0, # Xiph Vorbis encoding. (lossy) + 'ALAC_16': 16, # Apple Lossless Audio Codec (16 bit). + 'ALAC_20': 20, # Apple Lossless Audio Codec (20 bit). + 'ALAC_24': 24, # Apple Lossless Audio Codec (24 bit). + 'ALAC_32': 32, # Apple Lossless Audio Codec (32 bit). +} + + +def _get_bit_depth(subtype): + if subtype not in _SUBTYPE_TO_BITS_PER_SAMPLE: + warnings.warn( + f"The {subtype} subtype is unknown to TorchAudio. As a result, the bits_per_sample " + "attribute will be set to 0. If you are seeing this warning, please " + "report by opening an issue on github (after checking for existing/closed ones). " + "You may otherwise ignore this warning." + ) + return _SUBTYPE_TO_BITS_PER_SAMPLE.get(subtype, 0) + + +_SUBTYPE_TO_ENCODING = { + 'PCM_S8': 'PCM_S', + 'PCM_16': 'PCM_S', + 'PCM_24': 'PCM_S', + 'PCM_32': 'PCM_S', + 'PCM_U8': 'PCM_U', + 'FLOAT': 'PCM_F', + 'DOUBLE': 'PCM_F', + 'ULAW': 'ULAW', + 'ALAW': 'ALAW', + 'VORBIS': 'VORBIS', +} + + +def _get_encoding(format: str, subtype: str): + if format == 'FLAC': + return 'FLAC' + return _SUBTYPE_TO_ENCODING.get(subtype, 'UNKNOWN') + + +@_mod_utils.requires_soundfile() +def info(filepath: str, format: Optional[str] = None) -> AudioMetaData: + """Get signal information of an audio file. + + Note: + ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts + ``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend, + which has a restriction on type annotation due to TorchScript compiler compatiblity. + + Args: + filepath (path-like object or file-like object): + Source of audio data. + format (str or None, optional): + Not used. PySoundFile does not accept format hint. + + Returns: + AudioMetaData: meta data of the given audio. 
+ + """ + sinfo = soundfile.info(filepath) + return AudioMetaData( + sinfo.samplerate, + sinfo.frames, + sinfo.channels, + bits_per_sample=_get_bit_depth(sinfo.subtype), + encoding=_get_encoding(sinfo.format, sinfo.subtype), + ) + + +_SUBTYPE2DTYPE = { + "PCM_S8": "int8", + "PCM_U8": "uint8", + "PCM_16": "int16", + "PCM_32": "int32", + "FLOAT": "float32", + "DOUBLE": "float64", +} + + +@_mod_utils.requires_soundfile() +def load( + filepath: str, + frame_offset: int = 0, + num_frames: int = -1, + normalize: bool = True, + channels_first: bool = True, + format: Optional[str] = None, +) -> Tuple[torch.Tensor, int]: + """Load audio data from file. + + Note: + The formats this function can handle depend on the soundfile installation. + This function is tested on the following formats; + + * WAV + + * 32-bit floating-point + * 32-bit signed integer + * 16-bit signed integer + * 8-bit unsigned integer + + * FLAC + * OGG/VORBIS + * SPHERE + + By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with + ``float32`` dtype and the shape of `[channel, time]`. + The samples are normalized to fit in the range of ``[-1.0, 1.0]``. + + When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit + signed integer and 8-bit unsigned integer (24-bit signed integer is not supported), + by providing ``normalize=False``, this function can return integer Tensor, where the samples + are expressed within the whole range of the corresponding dtype, that is, ``int32`` tensor + for 32-bit signed PCM, ``int16`` for 16-bit signed PCM and ``uint8`` for 8-bit unsigned PCM. + + ``normalize`` parameter has no effect on 32-bit floating-point WAV and other formats, such as + ``flac`` and ``mp3``. + For these formats, this function always returns ``float32`` Tensor with values normalized to + ``[-1.0, 1.0]``. + + Note: + ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts + ``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend, + which has a restriction on type annotation due to TorchScript compiler compatiblity. + + Args: + filepath (path-like object or file-like object): + Source of audio data. + frame_offset (int, optional): + Number of frames to skip before start reading data. + num_frames (int, optional): + Maximum number of frames to read. ``-1`` reads all the remaining samples, + starting from ``frame_offset``. + This function may return the less number of frames if there is not enough + frames in the given file. + normalize (bool, optional): + When ``True``, this function always return ``float32``, and sample values are + normalized to ``[-1.0, 1.0]``. + If input file is integer WAV, giving ``False`` will change the resulting Tensor type to + integer type. + This argument has no effect for formats other than integer WAV type. + channels_first (bool, optional): + When True, the returned Tensor has dimension `[channel, time]`. + Otherwise, the returned Tensor's dimension is `[time, channel]`. + format (str or None, optional): + Not used. PySoundFile does not accept format hint. + + Returns: + (torch.Tensor, int): Resulting Tensor and sample rate. + If the input file has integer wav format and normalization is off, then it has + integer type, else ``float32`` type. If ``channels_first=True``, it has + `[channel, time]` else `[time, channel]`. 
+ """ + with soundfile.SoundFile(filepath, "r") as file_: + if file_.format != "WAV" or normalize: + dtype = "float32" + elif file_.subtype not in _SUBTYPE2DTYPE: + raise ValueError(f"Unsupported subtype: {file_.subtype}") + else: + dtype = _SUBTYPE2DTYPE[file_.subtype] + + frames = file_._prepare_read(frame_offset, None, num_frames) + waveform = file_.read(frames, dtype, always_2d=True) + sample_rate = file_.samplerate + + waveform = torch.from_numpy(waveform) + if channels_first: + waveform = waveform.t() + return waveform, sample_rate + + +def _get_subtype_for_wav( + dtype: torch.dtype, + encoding: str, + bits_per_sample: int): + if not encoding: + if not bits_per_sample: + subtype = { + torch.uint8: "PCM_U8", + torch.int16: "PCM_16", + torch.int32: "PCM_32", + torch.float32: "FLOAT", + torch.float64: "DOUBLE", + }.get(dtype) + if not subtype: + raise ValueError(f"Unsupported dtype for wav: {dtype}") + return subtype + if bits_per_sample == 8: + return "PCM_U8" + return f"PCM_{bits_per_sample}" + if encoding == "PCM_S": + if not bits_per_sample: + return "PCM_32" + if bits_per_sample == 8: + raise ValueError("wav does not support 8-bit signed PCM encoding.") + return f"PCM_{bits_per_sample}" + if encoding == "PCM_U": + if bits_per_sample in (None, 8): + return "PCM_U8" + raise ValueError("wav only supports 8-bit unsigned PCM encoding.") + if encoding == "PCM_F": + if bits_per_sample in (None, 32): + return "FLOAT" + if bits_per_sample == 64: + return "DOUBLE" + raise ValueError("wav only supports 32/64-bit float PCM encoding.") + if encoding == "ULAW": + if bits_per_sample in (None, 8): + return "ULAW" + raise ValueError("wav only supports 8-bit mu-law encoding.") + if encoding == "ALAW": + if bits_per_sample in (None, 8): + return "ALAW" + raise ValueError("wav only supports 8-bit a-law encoding.") + raise ValueError(f"wav does not support {encoding}.") + + +def _get_subtype_for_sphere(encoding: str, bits_per_sample: int): + if encoding in (None, "PCM_S"): + return f"PCM_{bits_per_sample}" if bits_per_sample else "PCM_32" + if encoding in ("PCM_U", "PCM_F"): + raise ValueError(f"sph does not support {encoding} encoding.") + if encoding == "ULAW": + if bits_per_sample in (None, 8): + return "ULAW" + raise ValueError("sph only supports 8-bit for mu-law encoding.") + if encoding == "ALAW": + return "ALAW" + raise ValueError(f"sph does not support {encoding}.") + + +def _get_subtype( + dtype: torch.dtype, + format: str, + encoding: str, + bits_per_sample: int): + if format == "wav": + return _get_subtype_for_wav(dtype, encoding, bits_per_sample) + if format == "flac": + if encoding: + raise ValueError("flac does not support encoding.") + if not bits_per_sample: + return "PCM_16" + if bits_per_sample > 24: + raise ValueError("flac does not support bits_per_sample > 24.") + return "PCM_S8" if bits_per_sample == 8 else f"PCM_{bits_per_sample}" + if format in ("ogg", "vorbis"): + if encoding or bits_per_sample: + raise ValueError( + "ogg/vorbis does not support encoding/bits_per_sample.") + return "VORBIS" + if format == "sph": + return _get_subtype_for_sphere(encoding, bits_per_sample) + if format in ("nis", "nist"): + return "PCM_16" + raise ValueError(f"Unsupported format: {format}") + + +@_mod_utils.requires_soundfile() +def save( + filepath: str, + src: torch.Tensor, + sample_rate: int, + channels_first: bool = True, + compression: Optional[float] = None, + format: Optional[str] = None, + encoding: Optional[str] = None, + bits_per_sample: Optional[int] = None, +): + """Save audio data to 
file. + + Note: + The formats this function can handle depend on the soundfile installation. + This function is tested on the following formats; + + * WAV + + * 32-bit floating-point + * 32-bit signed integer + * 16-bit signed integer + * 8-bit unsigned integer + + * FLAC + * OGG/VORBIS + * SPHERE + + Note: + ``filepath`` argument is intentionally annotated as ``str`` only, even though it accepts + ``pathlib.Path`` object as well. This is for the consistency with ``"sox_io"`` backend, + which has a restriction on type annotation due to TorchScript compiler compatiblity. + + Args: + filepath (str or pathlib.Path): Path to audio file. + src (torch.Tensor): Audio data to save. must be 2D tensor. + sample_rate (int): sampling rate + channels_first (bool, optional): If ``True``, the given tensor is interpreted as `[channel, time]`, + otherwise `[time, channel]`. + compression (float of None, optional): Not used. + It is here only for interface compatibility reson with "sox_io" backend. + format (str or None, optional): Override the audio format. + When ``filepath`` argument is path-like object, audio format is + inferred from file extension. If the file extension is missing or + different, you can specify the correct format with this argument. + + When ``filepath`` argument is file-like object, + this argument is required. + + Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``, + ``"flac"`` and ``"sph"``. + encoding (str or None, optional): Changes the encoding for supported formats. + This argument is effective only for supported formats, sush as + ``"wav"``, ``""flac"`` and ``"sph"``. Valid values are; + + - ``"PCM_S"`` (signed integer Linear PCM) + - ``"PCM_U"`` (unsigned integer Linear PCM) + - ``"PCM_F"`` (floating point PCM) + - ``"ULAW"`` (mu-law) + - ``"ALAW"`` (a-law) + + bits_per_sample (int or None, optional): Changes the bit depth for the + supported formats. + When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``, + you can change the bit depth. + Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``. + + Supported formats/encodings/bit depth/compression are: + + ``"wav"`` + - 32-bit floating-point PCM + - 32-bit signed integer PCM + - 24-bit signed integer PCM + - 16-bit signed integer PCM + - 8-bit unsigned integer PCM + - 8-bit mu-law + - 8-bit a-law + + Note: Default encoding/bit depth is determined by the dtype of + the input Tensor. + + ``"flac"`` + - 8-bit + - 16-bit (default) + - 24-bit + + ``"ogg"``, ``"vorbis"`` + - Doesn't accept changing configuration. + + ``"sph"`` + - 8-bit signed integer PCM + - 16-bit signed integer PCM + - 24-bit signed integer PCM + - 32-bit signed integer PCM (default) + - 8-bit mu-law + - 8-bit a-law + - 16-bit a-law + - 24-bit a-law + - 32-bit a-law + + """ + if src.ndim != 2: + raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.") + if compression is not None: + warnings.warn( + '`save` function of "soundfile" backend does not support "compression" parameter. ' + "The argument is silently ignored." + ) + if hasattr(filepath, 'write'): + if format is None: + raise RuntimeError('`format` is required when saving to file object.') + ext = format.lower() + else: + ext = str(filepath).split(".")[-1].lower() + + if bits_per_sample not in (None, 8, 16, 24, 32, 64): + raise ValueError("Invalid bits_per_sample.") + if bits_per_sample == 24: + warnings.warn("Saving audio with 24 bits per sample might warp samples near -1. 
" + "Using 16 bits per sample might be able to avoid this.") + subtype = _get_subtype(src.dtype, ext, encoding, bits_per_sample) + + # sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format, + # so we extend the extensions manually here + if ext in ["nis", "nist", "sph"] and format is None: + format = "NIST" + + if channels_first: + src = src.t() + + soundfile.write( + file=filepath, data=src, samplerate=sample_rate, subtype=subtype, format=format + ) diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..b725bcc96802d80ef6585acb6e0c3d57c6793e61 --- /dev/null +++ b/torchaudio/backend/sox_io_backend.py @@ -0,0 +1,317 @@ +import os +from typing import Tuple, Optional + +import torch +from torchaudio._internal import ( + module_utils as _mod_utils, +) + +import torchaudio +from .common import AudioMetaData + + +@_mod_utils.requires_sox() +def info( + filepath: str, + format: Optional[str] = None, +) -> AudioMetaData: + """Get signal information of an audio file. + + Args: + filepath (path-like object or file-like object): + Source of audio data. When the function is not compiled by TorchScript, + (e.g. ``torch.jit.script``), the following types are accepted; + + * ``path-like``: file path + * ``file-like``: Object with ``read(size: int) -> bytes`` method, + which returns byte string of at most ``size`` length. + + When the function is compiled by TorchScript, only ``str`` type is allowed. + + Note: + + * When the input type is file-like object, this function cannot + get the correct length (``num_samples``) for certain formats, + such as ``mp3`` and ``vorbis``. + In this case, the value of ``num_samples`` is ``0``. + * This argument is intentionally annotated as ``str`` only due to + TorchScript compiler compatibility. + + format (str or None, optional): + Override the format detection with the given format. + Providing the argument might help when libsox can not infer the format + from header or extension, + + Returns: + AudioMetaData: Metadata of the given audio. + """ + if not torch.jit.is_scripting(): + if hasattr(filepath, 'read'): + sinfo = torchaudio._torchaudio.get_info_fileobj(filepath, format) + return AudioMetaData(*sinfo) + filepath = os.fspath(filepath) + sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format) + return AudioMetaData(*sinfo) + + +@_mod_utils.requires_sox() +def load( + filepath: str, + frame_offset: int = 0, + num_frames: int = -1, + normalize: bool = True, + channels_first: bool = True, + format: Optional[str] = None, +) -> Tuple[torch.Tensor, int]: + """Load audio data from file. + + Note: + This function can handle all the codecs that underlying libsox can handle, + however it is tested on the following formats; + + * WAV, AMB + + * 32-bit floating-point + * 32-bit signed integer + * 24-bit signed integer + * 16-bit signed integer + * 8-bit unsigned integer (WAV only) + + * MP3 + * FLAC + * OGG/VORBIS + * OPUS + * SPHERE + * AMR-NB + + To load ``MP3``, ``FLAC``, ``OGG/VORBIS``, ``OPUS`` and other codecs ``libsox`` does not + handle natively, your installation of ``torchaudio`` has to be linked to ``libsox`` + and corresponding codec libraries such as ``libmad`` or ``libmp3lame`` etc. + + By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with + ``float32`` dtype and the shape of `[channel, time]`. + The samples are normalized to fit in the range of ``[-1.0, 1.0]``. 
+ + When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit + signed integer, 24-bit signed integer, and 8-bit unsigned integer, by providing ``normalize=False``, + this function can return integer Tensor, where the samples are expressed within the whole range + of the corresponding dtype, that is, ``int32`` tensor for 32-bit signed PCM, + ``int16`` for 16-bit signed PCM and ``uint8`` for 8-bit unsigned PCM. Since torch does not + support ``int24`` dtype, 24-bit signed PCM are converted to ``int32`` tensors. + + ``normalize`` parameter has no effect on 32-bit floating-point WAV and other formats, such as + ``flac`` and ``mp3``. + For these formats, this function always returns ``float32`` Tensor with values normalized to + ``[-1.0, 1.0]``. + + Args: + filepath (path-like object or file-like object): + Source of audio data. When the function is not compiled by TorchScript, + (e.g. ``torch.jit.script``), the following types are accepted; + + * ``path-like``: file path + * ``file-like``: Object with ``read(size: int) -> bytes`` method, + which returns byte string of at most ``size`` length. + + When the function is compiled by TorchScript, only ``str`` type is allowed. + + Note: This argument is intentionally annotated as ``str`` only due to + TorchScript compiler compatibility. + frame_offset (int): + Number of frames to skip before start reading data. + num_frames (int, optional): + Maximum number of frames to read. ``-1`` reads all the remaining samples, + starting from ``frame_offset``. + This function may return the less number of frames if there is not enough + frames in the given file. + normalize (bool, optional): + When ``True``, this function always return ``float32``, and sample values are + normalized to ``[-1.0, 1.0]``. + If input file is integer WAV, giving ``False`` will change the resulting Tensor type to + integer type. + This argument has no effect for formats other than integer WAV type. + channels_first (bool, optional): + When True, the returned Tensor has dimension `[channel, time]`. + Otherwise, the returned Tensor's dimension is `[time, channel]`. + format (str or None, optional): + Override the format detection with the given format. + Providing the argument might help when libsox can not infer the format + from header or extension, + + Returns: + (torch.Tensor, int): Resulting Tensor and sample rate. + If the input file has integer wav format and normalization is off, then it has + integer type, else ``float32`` type. If ``channels_first=True``, it has + `[channel, time]` else `[time, channel]`. + """ + if not torch.jit.is_scripting(): + if hasattr(filepath, 'read'): + return torchaudio._torchaudio.load_audio_fileobj( + filepath, frame_offset, num_frames, normalize, channels_first, format) + filepath = os.fspath(filepath) + return torch.ops.torchaudio.sox_io_load_audio_file( + filepath, frame_offset, num_frames, normalize, channels_first, format) + + +@_mod_utils.requires_sox() +def save( + filepath: str, + src: torch.Tensor, + sample_rate: int, + channels_first: bool = True, + compression: Optional[float] = None, + format: Optional[str] = None, + encoding: Optional[str] = None, + bits_per_sample: Optional[int] = None, +): + """Save audio data to file. + + Args: + filepath (str or pathlib.Path): Path to save file. + This function also handles ``pathlib.Path`` objects, but is annotated + as ``str`` for TorchScript compiler compatibility. + src (torch.Tensor): Audio data to save. must be 2D tensor. 
+ sample_rate (int): sampling rate + channels_first (bool, optional): If ``True``, the given tensor is interpreted as `[channel, time]`, + otherwise `[time, channel]`. + compression (float or None, optional): Used for formats other than WAV. + This corresponds to ``-C`` option of ``sox`` command. + + ``"mp3"`` + Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or + VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``. + + ``"flac"`` + Whole number from ``0`` to ``8``. ``8`` is default and highest compression. + + ``"ogg"``, ``"vorbis"`` + Number from ``-1`` to ``10``; ``-1`` is the highest compression + and lowest quality. Default: ``3``. + + See the detail at http://sox.sourceforge.net/soxformat.html. + format (str or None, optional): Override the audio format. + When ``filepath`` argument is path-like object, audio format is infered from + file extension. If file extension is missing or different, you can specify the + correct format with this argument. + + When ``filepath`` argument is file-like object, this argument is required. + + Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``, + ``"amb"``, ``"flac"``, ``"sph"``, ``"gsm"``, and ``"htk"``. + + encoding (str or None, optional): Changes the encoding for the supported formats. + This argument is effective only for supported formats, such as ``"wav"``, ``""amb"`` + and ``"sph"``. Valid values are; + + - ``"PCM_S"`` (signed integer Linear PCM) + - ``"PCM_U"`` (unsigned integer Linear PCM) + - ``"PCM_F"`` (floating point PCM) + - ``"ULAW"`` (mu-law) + - ``"ALAW"`` (a-law) + + Default values + If not provided, the default value is picked based on ``format`` and ``bits_per_sample``. + + ``"wav"``, ``"amb"`` + - | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the + | Tensor is used to determine the default value. + - ``"PCM_U"`` if dtype is ``uint8`` + - ``"PCM_S"`` if dtype is ``int16`` or ``int32` + - ``"PCM_F"`` if dtype is ``float32`` + + - ``"PCM_U"`` if ``bits_per_sample=8`` + - ``"PCM_S"`` otherwise + + ``"sph"`` format; + - the default value is ``"PCM_S"`` + + bits_per_sample (int or None, optional): Changes the bit depth for the supported formats. + When ``format`` is one of ``"wav"``, ``"flac"``, ``"sph"``, or ``"amb"``, you can change the + bit depth. Valid values are ``8``, ``16``, ``32`` and ``64``. + + Default Value; + If not provided, the default values are picked based on ``format`` and ``"encoding"``; + + ``"wav"``, ``"amb"``; + - | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the + | Tensor is used. + - ``8`` if dtype is ``uint8`` + - ``16`` if dtype is ``int16`` + - ``32`` if dtype is ``int32`` or ``float32`` + + - ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"`` + - ``16`` if ``encoding`` is ``"PCM_S"`` + - ``32`` if ``encoding`` is ``"PCM_F"`` + + ``"flac"`` format; + - the default value is ``24`` + + ``"sph"`` format; + - ``16`` if ``encoding`` is ``"PCM_U"``, ``"PCM_S"``, ``"PCM_F"`` or not provided. + - ``8`` if ``encoding`` is ``"ULAW"`` or ``"ALAW"`` + + ``"amb"`` format; + - ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"`` + - ``16`` if ``encoding`` is ``"PCM_S"`` or not provided. 
+ - ``32`` if ``encoding`` is ``"PCM_F"`` + + Supported formats/encodings/bit depth/compression are; + + ``"wav"``, ``"amb"`` + - 32-bit floating-point PCM + - 32-bit signed integer PCM + - 24-bit signed integer PCM + - 16-bit signed integer PCM + - 8-bit unsigned integer PCM + - 8-bit mu-law + - 8-bit a-law + + Note: Default encoding/bit depth is determined by the dtype of the input Tensor. + + ``"mp3"`` + Fixed bit rate (such as 128kHz) and variable bit rate compression. + Default: VBR with high quality. + + ``"flac"`` + - 8-bit + - 16-bit + - 24-bit (default) + + ``"ogg"``, ``"vorbis"`` + - Different quality level. Default: approx. 112kbps + + ``"sph"`` + - 8-bit signed integer PCM + - 16-bit signed integer PCM + - 24-bit signed integer PCM + - 32-bit signed integer PCM (default) + - 8-bit mu-law + - 8-bit a-law + - 16-bit a-law + - 24-bit a-law + - 32-bit a-law + + ``"amr-nb"`` + Bitrate ranging from 4.75 kbit/s to 12.2 kbit/s. Default: 4.75 kbit/s + + ``"gsm"`` + Lossy Speech Compression, CPU intensive. + + ``"htk"`` + Uses a default single-channel 16-bit PCM format. + + Note: + To save into formats that ``libsox`` does not handle natively, (such as ``"mp3"``, + ``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has + to be linked to ``libsox`` and corresponding codec libraries such as ``libmad`` + or ``libmp3lame`` etc. + """ + if not torch.jit.is_scripting(): + if hasattr(filepath, 'write'): + torchaudio._torchaudio.save_audio_fileobj( + filepath, src, sample_rate, channels_first, compression, + format, encoding, bits_per_sample) + return + filepath = os.fspath(filepath) + torch.ops.torchaudio.sox_io_save_audio_file( + filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample) diff --git a/torchaudio/backend/utils.py b/torchaudio/backend/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ab670e0c03076722d7c221edddf73ba43511c773 --- /dev/null +++ b/torchaudio/backend/utils.py @@ -0,0 +1,83 @@ +"""Defines utilities for switching audio backends""" +import warnings +from typing import Optional, List + +import torchaudio +from torchaudio._internal import module_utils as _mod_utils +from . import ( + no_backend, + sox_io_backend, + soundfile_backend, +) + +__all__ = [ + 'list_audio_backends', + 'get_audio_backend', + 'set_audio_backend', +] + + +def list_audio_backends() -> List[str]: + """List available backends + + Returns: + List[str]: The list of available backends. + """ + backends = [] + if _mod_utils.is_module_available('soundfile'): + backends.append('soundfile') + if _mod_utils.is_sox_available(): + backends.append('sox_io') + return backends + + +def set_audio_backend(backend: Optional[str]): + """Set the backend for I/O operation + + Args: + backend (str or None): Name of the backend. + One of ``"sox_io"`` or ``"soundfile"`` based on availability + of the system. If ``None`` is provided the current backend is unassigned. 
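+ Example: a minimal sketch of switching backends (assumes the ``"soundfile"``
+ backend is available on the system)::
+
+ >>> torchaudio.set_audio_backend("soundfile")
+ >>> torchaudio.get_audio_backend()
+ 'soundfile'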
+ """ + if backend is not None and backend not in list_audio_backends(): + raise RuntimeError( + f'Backend "{backend}" is not one of ' + f'available backends: {list_audio_backends()}.') + + if backend is None: + module = no_backend + elif backend == 'sox_io': + module = sox_io_backend + elif backend == 'soundfile': + module = soundfile_backend + else: + raise NotImplementedError(f'Unexpected backend "{backend}"') + + for func in ['save', 'load', 'info']: + setattr(torchaudio, func, getattr(module, func)) + + +def _init_audio_backend(): + backends = list_audio_backends() + if 'sox_io' in backends: + set_audio_backend('sox_io') + elif 'soundfile' in backends: + set_audio_backend('soundfile') + else: + warnings.warn('No audio backend is available.') + set_audio_backend(None) + + +def get_audio_backend() -> Optional[str]: + """Get the name of the current backend + + Returns: + Optional[str]: The name of the current backend or ``None`` if no backend is assigned. + """ + if torchaudio.load == no_backend.load: + return None + if torchaudio.load == sox_io_backend.load: + return 'sox_io' + if torchaudio.load == soundfile_backend.load: + return 'soundfile' + raise ValueError('Unknown backend.') diff --git a/torchaudio/compliance/__init__.py b/torchaudio/compliance/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..795065dc8d4becb5a3ad8a65c652804b0422514c --- /dev/null +++ b/torchaudio/compliance/__init__.py @@ -0,0 +1,5 @@ +from . import kaldi + +__all__ = [ + 'kaldi', +] diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py new file mode 100644 index 0000000000000000000000000000000000000000..60442f0d609ba70327786070564b78173ef2914a --- /dev/null +++ b/torchaudio/compliance/kaldi.py @@ -0,0 +1,750 @@ +from typing import Tuple + +import math +import torch +from torch import Tensor + +import torchaudio + +__all__ = [ + 'get_mel_banks', + 'inverse_mel_scale', + 'inverse_mel_scale_scalar', + 'mel_scale', + 'mel_scale_scalar', + 'spectrogram', + 'fbank', + 'mfcc', + 'vtln_warp_freq', + 'vtln_warp_mel_freq', +] + +# numeric_limits::epsilon() 1.1920928955078125e-07 +EPSILON = torch.tensor(torch.finfo(torch.float).eps) +# 1 milliseconds = 0.001 seconds +MILLISECONDS_TO_SECONDS = 0.001 + +# window types +HAMMING = 'hamming' +HANNING = 'hanning' +POVEY = 'povey' +RECTANGULAR = 'rectangular' +BLACKMAN = 'blackman' +WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN] + + +def _get_epsilon(device, dtype): + return EPSILON.to(device=device, dtype=dtype) + + +def _next_power_of_2(x: int) -> int: + r"""Returns the smallest power of 2 that is greater than x + """ + return 1 if x == 0 else 2 ** (x - 1).bit_length() + + +def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor: + r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``) + representing how the window is shifted along the waveform. Each row is a frame. + + Args: + waveform (Tensor): Tensor of size ``num_samples`` + window_size (int): Frame length + window_shift (int): Frame shift + snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. 
+ + Returns: + Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame + """ + assert waveform.dim() == 1 + num_samples = waveform.size(0) + strides = (window_shift * waveform.stride(0), waveform.stride(0)) + + if snip_edges: + if num_samples < window_size: + return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device) + else: + m = 1 + (num_samples - window_size) // window_shift + else: + reversed_waveform = torch.flip(waveform, [0]) + m = (num_samples + (window_shift // 2)) // window_shift + pad = window_size // 2 - window_shift // 2 + pad_right = reversed_waveform + if pad > 0: + # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect' + # but we want [2, 1, 0, 0, 1, 2] + pad_left = reversed_waveform[-pad:] + waveform = torch.cat((pad_left, waveform, pad_right), dim=0) + else: + # pad is negative so we want to trim the waveform at the front + waveform = torch.cat((waveform[-pad:], pad_right), dim=0) + + sizes = (m, window_size) + return waveform.as_strided(sizes, strides) + + +def _feature_window_function(window_type: str, + window_size: int, + blackman_coeff: float, + device: torch.device, + dtype: int, + ) -> Tensor: + r"""Returns a window function with the given type and size + """ + if window_type == HANNING: + return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype) + elif window_type == HAMMING: + return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype) + elif window_type == POVEY: + # like hanning but goes to zero at edges + return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85) + elif window_type == RECTANGULAR: + return torch.ones(window_size, device=device, dtype=dtype) + elif window_type == BLACKMAN: + a = 2 * math.pi / (window_size - 1) + window_function = torch.arange(window_size, device=device, dtype=dtype) + # can't use torch.blackman_window as they use different coefficients + return (blackman_coeff - 0.5 * torch.cos(a * window_function) + + (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)).to(device=device, dtype=dtype) + else: + raise Exception('Invalid window type ' + window_type) + + +def _get_log_energy(strided_input: Tensor, + epsilon: Tensor, + energy_floor: float) -> Tensor: + r"""Returns the log energy of size (m) for a strided_input (m,*) + """ + device, dtype = strided_input.device, strided_input.dtype + log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m) + if energy_floor == 0.0: + return log_energy + return torch.max( + log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype)) + + +def _get_waveform_and_window_properties(waveform: Tensor, + channel: int, + sample_frequency: float, + frame_shift: float, + frame_length: float, + round_to_power_of_two: bool, + preemphasis_coefficient: float) -> Tuple[Tensor, int, int, int]: + r"""Gets the waveform and window properties + """ + channel = max(channel, 0) + assert channel < waveform.size(0), ('Invalid channel {} for size {}'.format(channel, waveform.size(0))) + waveform = waveform[channel, :] # size (n) + window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS) + window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS) + padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size + + assert 2 <= window_size <= len( + waveform), ('choose a window size {} that is [2, {}]' + .format(window_size, len(waveform))) + assert 0 < window_shift, 
'`window_shift` must be greater than 0' + assert padded_window_size % 2 == 0, 'the padded `window_size` must be divisible by two.' \ + ' use `round_to_power_of_two` or change `frame_length`' + assert 0. <= preemphasis_coefficient <= 1.0, '`preemphasis_coefficient` must be between [0,1]' + assert sample_frequency > 0, '`sample_frequency` must be greater than zero' + return waveform, window_shift, window_size, padded_window_size + + +def _get_window(waveform: Tensor, + padded_window_size: int, + window_size: int, + window_shift: int, + window_type: str, + blackman_coeff: float, + snip_edges: bool, + raw_energy: bool, + energy_floor: float, + dither: float, + remove_dc_offset: bool, + preemphasis_coefficient: float) -> Tuple[Tensor, Tensor]: + r"""Gets a window and its log energy + + Returns: + (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m) + """ + device, dtype = waveform.device, waveform.dtype + epsilon = _get_epsilon(device, dtype) + + # size (m, window_size) + strided_input = _get_strided(waveform, window_size, window_shift, snip_edges) + + if dither != 0.0: + # Returns a random number strictly between 0 and 1 + x = torch.max(epsilon, torch.rand(strided_input.shape, device=device, dtype=dtype)) + rand_gauss = torch.sqrt(-2 * x.log()) * torch.cos(2 * math.pi * x) + strided_input = strided_input + rand_gauss * dither + + if remove_dc_offset: + # Subtract each row/frame by its mean + row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1) + strided_input = strided_input - row_means + + if raw_energy: + # Compute the log energy of each row/frame before applying preemphasis and + # window function + signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m) + + if preemphasis_coefficient != 0.0: + # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j + offset_strided_input = torch.nn.functional.pad( + strided_input.unsqueeze(0), (1, 0), mode='replicate').squeeze(0) # size (m, window_size + 1) + strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1] + + # Apply window_function to each row/frame + window_function = _feature_window_function( + window_type, window_size, blackman_coeff, device, dtype).unsqueeze(0) # size (1, window_size) + strided_input = strided_input * window_function # size (m, window_size) + + # Pad columns with zero until we reach size (m, padded_window_size) + if padded_window_size != window_size: + padding_right = padded_window_size - window_size + strided_input = torch.nn.functional.pad( + strided_input.unsqueeze(0), (0, padding_right), mode='constant', value=0).squeeze(0) + + # Compute energy after window function (not the raw one) + if not raw_energy: + signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m) + + return strided_input, signal_log_energy + + +def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor: + # subtracts the column mean of the tensor size (m, n) if subtract_mean=True + # it returns size (m, n) + if subtract_mean: + col_means = torch.mean(tensor, dim=0).unsqueeze(0) + tensor = tensor - col_means + return tensor + + +def spectrogram(waveform: Tensor, + blackman_coeff: float = 0.42, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + min_duration: float = 0.0, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + 
round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + subtract_mean: bool = False, + window_type: str = POVEY) -> Tensor: + r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's + compute-spectrogram-feats. + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``) + channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) + dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) + energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: + this floor is applied to the zeroth component, representing the total signal energy. The floor on the + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) + frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) + min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) + round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input + to FFT. (Default: ``True``) + sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if + specified there) (Default: ``16000.0``) + snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) + subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do + it this way. (Default: ``False``) + window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') + (Default: ``'povey'``) + + Returns: + Tensor: A spectrogram identical to what Kaldi would output. 
The shape is + (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided + """ + device, dtype = waveform.device, waveform.dtype + epsilon = _get_epsilon(device, dtype) + + waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( + waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient) + + if len(waveform) < min_duration * sample_frequency: + # signal is too short + return torch.empty(0) + + strided_input, signal_log_energy = _get_window( + waveform, padded_window_size, window_size, window_shift, window_type, blackman_coeff, + snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient) + + # size (m, padded_window_size // 2 + 1, 2) + fft = torch.fft.rfft(strided_input) + + # Convert the FFT into a power spectrum + power_spectrum = torch.max(fft.abs().pow(2.), epsilon).log() # size (m, padded_window_size // 2 + 1) + power_spectrum[:, 0] = signal_log_energy + + power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean) + return power_spectrum + + +def inverse_mel_scale_scalar(mel_freq: float) -> float: + return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0) + + +def inverse_mel_scale(mel_freq: Tensor) -> Tensor: + return 700.0 * ((mel_freq / 1127.0).exp() - 1.0) + + +def mel_scale_scalar(freq: float) -> float: + return 1127.0 * math.log(1.0 + freq / 700.0) + + +def mel_scale(freq: Tensor) -> Tensor: + return 1127.0 * (1.0 + freq / 700.0).log() + + +def vtln_warp_freq(vtln_low_cutoff: float, + vtln_high_cutoff: float, + low_freq: float, + high_freq: float, + vtln_warp_factor: float, + freq: Tensor) -> Tensor: + r"""This computes a VTLN warping function that is not the same as HTK's one, + but has similar inputs (this function has the advantage of never producing + empty bins). + + This function computes a warp function F(freq), defined between low_freq + and high_freq inclusive, with the following properties: + F(low_freq) == low_freq + F(high_freq) == high_freq + The function is continuous and piecewise linear with two inflection + points. + The lower inflection point (measured in terms of the unwarped + frequency) is at frequency l, determined as described below. + The higher inflection point is at a frequency h, determined as + described below. + If l <= f <= h, then F(f) = f/vtln_warp_factor. + If the higher inflection point (measured in terms of the unwarped + frequency) is at h, then max(h, F(h)) == vtln_high_cutoff. + Since (by the last point) F(h) == h/vtln_warp_factor, then + max(h, h/vtln_warp_factor) == vtln_high_cutoff, so + h = vtln_high_cutoff / max(1, 1/vtln_warp_factor). + = vtln_high_cutoff * min(1, vtln_warp_factor). 
+ If the lower inflection point (measured in terms of the unwarped + frequency) is at l, then min(l, F(l)) == vtln_low_cutoff + This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor) + = vtln_low_cutoff * max(1, vtln_warp_factor) + Args: + vtln_low_cutoff (float): Lower frequency cutoffs for VTLN + vtln_high_cutoff (float): Upper frequency cutoffs for VTLN + low_freq (float): Lower frequency cutoffs in mel computation + high_freq (float): Upper frequency cutoffs in mel computation + vtln_warp_factor (float): Vtln warp factor + freq (Tensor): given frequency in Hz + + Returns: + Tensor: Freq after vtln warp + """ + assert vtln_low_cutoff > low_freq, 'be sure to set the vtln_low option higher than low_freq' + assert vtln_high_cutoff < high_freq, 'be sure to set the vtln_high option lower than high_freq [or negative]' + l = vtln_low_cutoff * max(1.0, vtln_warp_factor) + h = vtln_high_cutoff * min(1.0, vtln_warp_factor) + scale = 1.0 / vtln_warp_factor + Fl = scale * l # F(l) + Fh = scale * h # F(h) + assert l > low_freq and h < high_freq + # slope of left part of the 3-piece linear function + scale_left = (Fl - low_freq) / (l - low_freq) + # [slope of center part is just "scale"] + + # slope of right part of the 3-piece linear function + scale_right = (high_freq - Fh) / (high_freq - h) + + res = torch.empty_like(freq) + + outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq) # freq < low_freq || freq > high_freq + before_l = torch.lt(freq, l) # freq < l + before_h = torch.lt(freq, h) # freq < h + after_h = torch.ge(freq, h) # freq >= h + + # order of operations matter here (since there is overlapping frequency regions) + res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq) + res[before_h] = scale * freq[before_h] + res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq) + res[outside_low_high_freq] = freq[outside_low_high_freq] + + return res + + +def vtln_warp_mel_freq(vtln_low_cutoff: float, + vtln_high_cutoff: float, + low_freq, high_freq: float, + vtln_warp_factor: float, + mel_freq: Tensor) -> Tensor: + r""" + Args: + vtln_low_cutoff (float): Lower frequency cutoffs for VTLN + vtln_high_cutoff (float): Upper frequency cutoffs for VTLN + low_freq (float): Lower frequency cutoffs in mel computation + high_freq (float): Upper frequency cutoffs in mel computation + vtln_warp_factor (float): Vtln warp factor + mel_freq (Tensor): Given frequency in Mel + + Returns: + Tensor: ``mel_freq`` after vtln warp + """ + return mel_scale(vtln_warp_freq(vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, + vtln_warp_factor, inverse_mel_scale(mel_freq))) + + +def get_mel_banks(num_bins: int, + window_length_padded: int, + sample_freq: float, + low_freq: float, + high_freq: float, + vtln_low: float, + vtln_high: float, + vtln_warp_factor: float) -> Tuple[Tensor, Tensor]: + """ + Returns: + (Tensor, Tensor): The tuple consists of ``bins`` (which is + melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is + center frequencies of bins of size (``num_bins``)). + """ + assert num_bins > 3, 'Must have at least 3 mel bins' + assert window_length_padded % 2 == 0 + num_fft_bins = window_length_padded / 2 + nyquist = 0.5 * sample_freq + + if high_freq <= 0.0: + high_freq += nyquist + + assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), \ + ('Bad values in options: low-freq {} and high-freq {} vs. 
nyquist {}'.format(low_freq, high_freq, nyquist)) + + # fft-bin width [think of it as Nyquist-freq / half-window-length] + fft_bin_width = sample_freq / window_length_padded + mel_low_freq = mel_scale_scalar(low_freq) + mel_high_freq = mel_scale_scalar(high_freq) + + # divide by num_bins+1 in next line because of end-effects where the bins + # spread out to the sides. + mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1) + + if vtln_high < 0.0: + vtln_high += nyquist + + assert vtln_warp_factor == 1.0 or ((low_freq < vtln_low < high_freq) and + (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)), \ + ('Bad values in options: vtln-low {} and vtln-high {}, versus ' + 'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq)) + + bin = torch.arange(num_bins).unsqueeze(1) + left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1) + center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # size(num_bins, 1) + right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1) + + if vtln_warp_factor != 1.0: + left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel) + center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel) + right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel) + + center_freqs = inverse_mel_scale(center_mel) # size (num_bins) + # size(1, num_fft_bins) + mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0) + + # size (num_bins, num_fft_bins) + up_slope = (mel - left_mel) / (center_mel - left_mel) + down_slope = (right_mel - mel) / (right_mel - center_mel) + + if vtln_warp_factor == 1.0: + # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values + bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope)) + else: + # warping can move the order of left_mel, center_mel, right_mel anywhere + bins = torch.zeros_like(up_slope) + up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel) # left_mel < mel <= center_mel + down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel) # center_mel < mel < right_mel + bins[up_idx] = up_slope[up_idx] + bins[down_idx] = down_slope[down_idx] + + return bins, center_freqs + + +def fbank(waveform: Tensor, + blackman_coeff: float = 0.42, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + high_freq: float = 0.0, + htk_compat: bool = False, + low_freq: float = 20.0, + min_duration: float = 0.0, + num_mel_bins: int = 23, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + subtract_mean: bool = False, + use_energy: bool = False, + use_log_fbank: bool = True, + use_power: bool = True, + vtln_high: float = -500.0, + vtln_low: float = 100.0, + vtln_warp: float = 1.0, + window_type: str = POVEY) -> Tensor: + r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's + compute-fbank-feats. + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. 
(Default: ``0.42``) + channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) + dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) + energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: + this floor is applied to the zeroth component, representing the total signal energy. The floor on the + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) + frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) + high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist) + (Default: ``0.0``) + htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features + (need to change other parameters). (Default: ``False``) + low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``) + min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``) + preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) + round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input + to FFT. (Default: ``True``) + sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if + specified there) (Default: ``16000.0``) + snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) + subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do + it this way. (Default: ``False``) + use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``) + use_log_fbank (bool, optional):If true, produce log-filterbank, else produce linear. (Default: ``True``) + use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``) + vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if + negative, offset from high-mel-freq (Default: ``-500.0``) + vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``) + vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``) + window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') + (Default: ``'povey'``) + + Returns: + Tensor: A fbank identical to what Kaldi would output. 
The shape is (m, ``num_mel_bins + use_energy``) + where m is calculated in _get_strided + """ + device, dtype = waveform.device, waveform.dtype + + waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( + waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient) + + if len(waveform) < min_duration * sample_frequency: + # signal is too short + return torch.empty(0, device=device, dtype=dtype) + + # strided_input, size (m, padded_window_size) and signal_log_energy, size (m) + strided_input, signal_log_energy = _get_window( + waveform, padded_window_size, window_size, window_shift, window_type, blackman_coeff, + snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient) + + # size (m, padded_window_size // 2 + 1) + spectrum = torch.fft.rfft(strided_input).abs() + if use_power: + spectrum = spectrum.pow(2.) + + # size (num_mel_bins, padded_window_size // 2) + mel_energies, _ = get_mel_banks(num_mel_bins, padded_window_size, sample_frequency, + low_freq, high_freq, vtln_low, vtln_high, vtln_warp) + mel_energies = mel_energies.to(device=device, dtype=dtype) + + # pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1) + mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode='constant', value=0) + + # sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins) + mel_energies = torch.mm(spectrum, mel_energies.T) + if use_log_fbank: + # avoid log of zero (which should be prevented anyway by dithering) + mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log() + + # if use_energy then add it as the last column for htk_compat == true else first column + if use_energy: + signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1) + # returns size (m, num_mel_bins + 1) + if htk_compat: + mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1) + else: + mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1) + + mel_energies = _subtract_column_mean(mel_energies, subtract_mean) + return mel_energies + + +def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor: + # returns a dct matrix of size (num_mel_bins, num_ceps) + # size (num_mel_bins, num_mel_bins) + dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, 'ortho') + # kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins) + # this would be the first column in the dct_matrix for torchaudio as it expects a + # right multiply (which would be the first column of the kaldi's dct_matrix as kaldi + # expects a left multiply e.g. dct_matrix * vector). + dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins)) + dct_matrix = dct_matrix[:, :num_ceps] + return dct_matrix + + +def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor: + # returns size (num_ceps) + # Compute liftering coefficients (scaling on cepstral coeffs) + # coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected. 
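+ # With Q = cepstral_lifter, the returned coefficients are
+ # lifter[i] = 1 + (Q / 2) * sin(pi * i / Q) for i = 0, ..., num_ceps - 1,
+ # so the i = 0 entry (C0) gets a factor of exactly 1 and is left unscaled.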
+ i = torch.arange(num_ceps) + return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter) + + +def mfcc( + waveform: Tensor, + blackman_coeff: float = 0.42, + cepstral_lifter: float = 22.0, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + high_freq: float = 0.0, + htk_compat: bool = False, + low_freq: float = 20.0, + num_ceps: int = 13, + min_duration: float = 0.0, + num_mel_bins: int = 23, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + subtract_mean: bool = False, + use_energy: bool = False, + vtln_high: float = -500.0, + vtln_low: float = 100.0, + vtln_warp: float = 1.0, + window_type: str = POVEY) -> Tensor: + r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's + compute-mfcc-feats. + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``) + cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``) + channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) + dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) + energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: + this floor is applied to the zeroth component, representing the total signal energy. The floor on the + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) + frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) + high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist) + (Default: ``0.0``) + htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible + features (need to change other parameters). (Default: ``False``) + low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``) + num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``) + min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``) + preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) + round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input + to FFT. (Default: ``True``) + sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if + specified there) (Default: ``16000.0``) + snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. 
If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) + subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do + it this way. (Default: ``False``) + use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``) + vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if + negative, offset from high-mel-freq (Default: ``-500.0``) + vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``) + vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``) + window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') + (Default: ``"povey"``) + + Returns: + Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``) + where m is calculated in _get_strided + """ + assert num_ceps <= num_mel_bins, 'num_ceps cannot be larger than num_mel_bins: %d vs %d' % (num_ceps, num_mel_bins) + + device, dtype = waveform.device, waveform.dtype + + # The mel_energies should not be squared (use_power=True), not have mean subtracted + # (subtract_mean=False), and use log (use_log_fbank=True). + # size (m, num_mel_bins + use_energy) + feature = fbank(waveform=waveform, blackman_coeff=blackman_coeff, channel=channel, + dither=dither, energy_floor=energy_floor, frame_length=frame_length, + frame_shift=frame_shift, high_freq=high_freq, htk_compat=htk_compat, + low_freq=low_freq, min_duration=min_duration, num_mel_bins=num_mel_bins, + preemphasis_coefficient=preemphasis_coefficient, raw_energy=raw_energy, + remove_dc_offset=remove_dc_offset, round_to_power_of_two=round_to_power_of_two, + sample_frequency=sample_frequency, snip_edges=snip_edges, subtract_mean=False, + use_energy=use_energy, use_log_fbank=True, use_power=True, + vtln_high=vtln_high, vtln_low=vtln_low, vtln_warp=vtln_warp, window_type=window_type) + + if use_energy: + # size (m) + signal_log_energy = feature[:, num_mel_bins if htk_compat else 0] + # offset is 0 if htk_compat==True else 1 + mel_offset = int(not htk_compat) + feature = feature[:, mel_offset:(num_mel_bins + mel_offset)] + + # size (num_mel_bins, num_ceps) + dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device) + + # size (m, num_ceps) + feature = feature.matmul(dct_matrix) + + if cepstral_lifter != 0.0: + # size (1, num_ceps) + lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0) + feature *= lifter_coeffs.to(device=device, dtype=dtype) + + # if use_energy then replace the last column for htk_compat == true else first column + if use_energy: + feature[:, 0] = signal_log_energy + + if htk_compat: + energy = feature[:, 0].unsqueeze(1) # size (m, 1) + feature = feature[:, 1:] # size (m, num_ceps - 1) + if not use_energy: + # scale on C0 (actually removing a scale we previously added that's + # part of one common definition of the cosine transform.) 
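+ # The 'ortho' DCT above weights C0 by sqrt(1/num_mel_bins) rather than
+ # sqrt(2/num_mel_bins); multiplying by sqrt(2) removes that extra 1/sqrt(2)
+ # so C0 follows the HTK convention.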
+ energy *= math.sqrt(2) + + feature = torch.cat((feature, energy), dim=1) + + feature = _subtract_column_mean(feature, subtract_mean) + return feature diff --git a/torchaudio/csrc/CMakeLists.txt b/torchaudio/csrc/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..2187abafa1baba72b442b6e3364325c0ecfec99e --- /dev/null +++ b/torchaudio/csrc/CMakeLists.txt @@ -0,0 +1,171 @@ +get_property(TORCHAUDIO_THIRD_PARTIES GLOBAL PROPERTY TORCHAUDIO_THIRD_PARTIES) + +################################################################################ +# libtorchaudio +################################################################################ +set( + LIBTORCHAUDIO_SOURCES + lfilter.cpp + overdrive.cpp + utils.cpp + ) + +if(BUILD_RNNT) + list( + APPEND + LIBTORCHAUDIO_SOURCES + rnnt/cpu/compute_alphas.cpp + rnnt/cpu/compute_betas.cpp + rnnt/cpu/compute.cpp + rnnt/compute_alphas.cpp + rnnt/compute_betas.cpp + rnnt/compute.cpp + rnnt/autograd.cpp + ) + if (USE_CUDA) + list( + APPEND + LIBTORCHAUDIO_SOURCES + rnnt/gpu/compute_alphas.cu + rnnt/gpu/compute_betas.cu + rnnt/gpu/compute.cu + ) + endif() +endif() + +if(BUILD_KALDI) + list(APPEND LIBTORCHAUDIO_SOURCES kaldi.cpp) +endif() + +if(BUILD_SOX) + list( + APPEND + LIBTORCHAUDIO_SOURCES + sox/io.cpp + sox/utils.cpp + sox/effects.cpp + sox/effects_chain.cpp + sox/types.cpp + ) +endif() + +add_library( + libtorchaudio + SHARED + ${LIBTORCHAUDIO_SOURCES} + ) +set_target_properties(libtorchaudio PROPERTIES PREFIX "") + +target_include_directories( + libtorchaudio + PRIVATE + ${PROJECT_SOURCE_DIR} + ) + +target_link_libraries( + libtorchaudio + torch + ${TORCHAUDIO_THIRD_PARTIES} + ) + +if (BUILD_SOX) + target_compile_definitions(libtorchaudio PUBLIC INCLUDE_SOX) +endif() + +if (BUILD_KALDI) + target_compile_definitions(libtorchaudio PUBLIC INCLUDE_KALDI) +endif() + +if(USE_CUDA) + target_compile_definitions(libtorchaudio PRIVATE USE_CUDA) + target_include_directories( + libtorchaudio + PRIVATE + ${CUDA_TOOLKIT_INCLUDE} + ) + target_link_libraries( + libtorchaudio + ${C10_CUDA_LIBRARY} + ${CUDA_CUDART_LIBRARY} + ) +endif() + +if (MSVC) + set_target_properties(libtorchaudio PROPERTIES SUFFIX ".pyd") +endif(MSVC) + +install( + TARGETS libtorchaudio + LIBRARY DESTINATION lib + RUNTIME DESTINATION lib # For Windows + ) + +if (APPLE) + set(TORCHAUDIO_LIBRARY libtorchaudio CACHE INTERNAL "") +else() + set(TORCHAUDIO_LIBRARY -Wl,--no-as-needed libtorchaudio -Wl,--as-needed CACHE INTERNAL "") +endif() + +################################################################################ +# _torchaudio.so +################################################################################ +if (BUILD_TORCHAUDIO_PYTHON_EXTENSION) + set( + EXTENSION_SOURCES + pybind/pybind.cpp + ) + if(BUILD_SOX) + list( + APPEND + EXTENSION_SOURCES + pybind/sox/effects.cpp + pybind/sox/effects_chain.cpp + pybind/sox/io.cpp + pybind/sox/utils.cpp + ) + endif() + add_library( + _torchaudio + SHARED + ${EXTENSION_SOURCES} + ) + + set_target_properties(_torchaudio PROPERTIES PREFIX "") + if (MSVC) + set_target_properties(_torchaudio PROPERTIES SUFFIX ".pyd") + endif(MSVC) + + if (APPLE) + # https://github.com/facebookarchive/caffe2/issues/854#issuecomment-364538485 + # https://github.com/pytorch/pytorch/commit/73f6715f4725a0723d8171d3131e09ac7abf0666 + set_target_properties(_torchaudio PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") + endif() + + target_include_directories( + _torchaudio + PRIVATE + ${PROJECT_SOURCE_DIR} + ${Python_INCLUDE_DIR} + ) + + # See 
https://github.com/pytorch/pytorch/issues/38122 + find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib") + + if (WIN32) + find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development) + set(ADDITIONAL_ITEMS Python3::Python) + endif() + + target_link_libraries( + _torchaudio + libtorchaudio + ${TORCH_PYTHON_LIBRARY} + ${ADDITIONAL_ITEMS} + ) + + install( + TARGETS _torchaudio + LIBRARY DESTINATION . + RUNTIME DESTINATION . # For Windows + ) +endif() diff --git a/torchaudio/csrc/kaldi.cpp b/torchaudio/csrc/kaldi.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6f2b36c28fe4256c2ba54e512f0b884717572c4f --- /dev/null +++ b/torchaudio/csrc/kaldi.cpp @@ -0,0 +1,93 @@ +#include +#include "feat/pitch-functions.h" + +namespace torchaudio { +namespace kaldi { + +namespace { + +torch::Tensor denormalize(const torch::Tensor& t) { + auto ret = t; + auto pos = t > 0, neg = t < 0; + ret.index_put({pos}, t.index({pos}) * 32767); + ret.index_put({neg}, t.index({neg}) * 32768); + return ret; +} + +torch::Tensor compute_kaldi_pitch( + const torch::Tensor& wave, + const ::kaldi::PitchExtractionOptions& opts) { + ::kaldi::VectorBase<::kaldi::BaseFloat> input(wave); + ::kaldi::Matrix<::kaldi::BaseFloat> output; + ::kaldi::ComputeKaldiPitch(opts, input, &output); + return output.tensor_; +} + +} // namespace + +torch::Tensor ComputeKaldiPitch( + const torch::Tensor& wave, + double sample_frequency, + double frame_length, + double frame_shift, + double min_f0, + double max_f0, + double soft_min_f0, + double penalty_factor, + double lowpass_cutoff, + double resample_frequency, + double delta_pitch, + double nccf_ballast, + int64_t lowpass_filter_width, + int64_t upsample_filter_width, + int64_t max_frames_latency, + int64_t frames_per_chunk, + bool simulate_first_pass_online, + int64_t recompute_frame, + bool snip_edges) { + TORCH_CHECK(wave.ndimension() == 2, "Input tensor must be 2 dimentional."); + TORCH_CHECK(wave.device().is_cpu(), "Input tensor must be on CPU."); + TORCH_CHECK( + wave.dtype() == torch::kFloat32, "Input tensor must be float32 type."); + + ::kaldi::PitchExtractionOptions opts; + opts.samp_freq = static_cast<::kaldi::BaseFloat>(sample_frequency); + opts.frame_shift_ms = static_cast<::kaldi::BaseFloat>(frame_shift); + opts.frame_length_ms = static_cast<::kaldi::BaseFloat>(frame_length); + opts.min_f0 = static_cast<::kaldi::BaseFloat>(min_f0); + opts.max_f0 = static_cast<::kaldi::BaseFloat>(max_f0); + opts.soft_min_f0 = static_cast<::kaldi::BaseFloat>(soft_min_f0); + opts.penalty_factor = static_cast<::kaldi::BaseFloat>(penalty_factor); + opts.lowpass_cutoff = static_cast<::kaldi::BaseFloat>(lowpass_cutoff); + opts.resample_freq = static_cast<::kaldi::BaseFloat>(resample_frequency); + opts.delta_pitch = static_cast<::kaldi::BaseFloat>(delta_pitch); + opts.lowpass_filter_width = static_cast<::kaldi::int32>(lowpass_filter_width); + opts.upsample_filter_width = + static_cast<::kaldi::int32>(upsample_filter_width); + opts.max_frames_latency = static_cast<::kaldi::int32>(max_frames_latency); + opts.frames_per_chunk = static_cast<::kaldi::int32>(frames_per_chunk); + opts.simulate_first_pass_online = simulate_first_pass_online; + opts.recompute_frame = static_cast<::kaldi::int32>(recompute_frame); + opts.snip_edges = snip_edges; + + // Kaldi's float type expects value range of int16 expressed as float + torch::Tensor wave_ = denormalize(wave); + + auto batch_size = wave_.size(0); + std::vector results(batch_size); + at::parallel_for(0, 
batch_size, 1, [&](int64_t begin, int64_t end) { + for (auto i = begin; i < end; ++i) { + results[i] = compute_kaldi_pitch(wave_.index({i}), opts); + } + }); + return torch::stack(results, 0); +} + +TORCH_LIBRARY_FRAGMENT(torchaudio, m) { + m.def( + "torchaudio::kaldi_ComputeKaldiPitch", + &torchaudio::kaldi::ComputeKaldiPitch); +} + +} // namespace kaldi +} // namespace torchaudio diff --git a/torchaudio/csrc/lfilter.cpp b/torchaudio/csrc/lfilter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eec34d1b88873e107f17978e4b4d7d2d5dbb062e --- /dev/null +++ b/torchaudio/csrc/lfilter.cpp @@ -0,0 +1,283 @@ +#include +#include + +namespace { + +template +void host_lfilter_core_loop( + const torch::Tensor& input_signal_windows, + const torch::Tensor& a_coeff_flipped, + torch::Tensor& padded_output_waveform) { + int64_t n_batch = input_signal_windows.size(0); + int64_t n_channel = input_signal_windows.size(1); + int64_t n_samples_input = input_signal_windows.size(2); + int64_t n_samples_output = padded_output_waveform.size(2); + int64_t n_order = a_coeff_flipped.size(1); + scalar_t* output_data = padded_output_waveform.data_ptr(); + const scalar_t* input_data = input_signal_windows.data_ptr(); + const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr(); + + at::parallel_for(0, n_channel * n_batch, 1, [&](int64_t begin, int64_t end) { + for (auto i = begin; i < end; i++) { + int64_t offset_input = i * n_samples_input; + int64_t offset_output = i * n_samples_output; + int64_t i_channel = i % n_channel; + for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) { + scalar_t a0 = input_data[offset_input + i_sample]; + for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) { + a0 -= output_data[offset_output + i_sample + i_coeff] * + a_coeff_flipped_data[i_coeff + i_channel * n_order]; + } + output_data[offset_output + i_sample + n_order - 1] = a0; + } + } + }); +} + +void cpu_lfilter_core_loop( + const torch::Tensor& input_signal_windows, + const torch::Tensor& a_coeff_flipped, + torch::Tensor& padded_output_waveform) { + TORCH_CHECK( + input_signal_windows.device().is_cpu() && + a_coeff_flipped.device().is_cpu() && + padded_output_waveform.device().is_cpu()); + + TORCH_CHECK( + input_signal_windows.is_contiguous() && a_coeff_flipped.is_contiguous() && + padded_output_waveform.is_contiguous()); + + TORCH_CHECK( + (input_signal_windows.dtype() == torch::kFloat32 || + input_signal_windows.dtype() == torch::kFloat64) && + (a_coeff_flipped.dtype() == torch::kFloat32 || + a_coeff_flipped.dtype() == torch::kFloat64) && + (padded_output_waveform.dtype() == torch::kFloat32 || + padded_output_waveform.dtype() == torch::kFloat64)); + + TORCH_CHECK(input_signal_windows.size(0) == padded_output_waveform.size(0)); + TORCH_CHECK(input_signal_windows.size(1) == padded_output_waveform.size(1)); + + TORCH_CHECK( + input_signal_windows.size(2) + a_coeff_flipped.size(1) - 1 == + padded_output_waveform.size(2)); + + AT_DISPATCH_FLOATING_TYPES( + input_signal_windows.scalar_type(), "lfilter_core_loop", [&] { + host_lfilter_core_loop( + input_signal_windows, a_coeff_flipped, padded_output_waveform); + }); +} + +void lfilter_core_generic_loop( + const torch::Tensor& input_signal_windows, + const torch::Tensor& a_coeff_flipped, + torch::Tensor& padded_output_waveform) { + int64_t n_samples_input = input_signal_windows.size(2); + int64_t n_order = a_coeff_flipped.size(1); + auto coeff = a_coeff_flipped.unsqueeze(2); + for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) 
{ + auto windowed_output_signal = + padded_output_waveform + .index( + {torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(i_sample, i_sample + n_order)}) + .transpose(0, 1); + auto o0 = + input_signal_windows.index( + {torch::indexing::Slice(), torch::indexing::Slice(), i_sample}) - + at::matmul(windowed_output_signal, coeff).squeeze(2).transpose(0, 1); + padded_output_waveform.index_put_( + {torch::indexing::Slice(), + torch::indexing::Slice(), + i_sample + n_order - 1}, + o0); + } +} + +class DifferentiableIIR : public torch::autograd::Function { + public: + static torch::Tensor forward( + torch::autograd::AutogradContext* ctx, + const torch::Tensor& waveform, + const torch::Tensor& a_coeffs_normalized) { + auto device = waveform.device(); + auto dtype = waveform.dtype(); + int64_t n_batch = waveform.size(0); + int64_t n_channel = waveform.size(1); + int64_t n_sample = waveform.size(2); + int64_t n_order = a_coeffs_normalized.size(1); + int64_t n_sample_padded = n_sample + n_order - 1; + + auto a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous(); + + auto options = torch::TensorOptions().dtype(dtype).device(device); + auto padded_output_waveform = + torch::zeros({n_batch, n_channel, n_sample_padded}, options); + + if (device.is_cpu()) { + cpu_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform); + } else { + lfilter_core_generic_loop( + waveform, a_coeff_flipped, padded_output_waveform); + } + + auto output = padded_output_waveform.index( + {torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(n_order - 1, torch::indexing::None)}); + + ctx->save_for_backward({waveform, a_coeffs_normalized, output}); + return output; + } + + static torch::autograd::tensor_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::tensor_list grad_outputs) { + auto saved = ctx->get_saved_variables(); + auto x = saved[0]; + auto a_coeffs_normalized = saved[1]; + auto y = saved[2]; + + int64_t n_batch = x.size(0); + int64_t n_channel = x.size(1); + int64_t n_order = a_coeffs_normalized.size(1); + + auto dx = torch::Tensor(); + auto da = torch::Tensor(); + auto dy = grad_outputs[0]; + + namespace F = torch::nn::functional; + + if (a_coeffs_normalized.requires_grad()) { + auto dyda = F::pad( + DifferentiableIIR::apply(-y, a_coeffs_normalized), + F::PadFuncOptions({n_order - 1, 0})); + + da = F::conv1d( + dyda.view({1, n_batch * n_channel, -1}), + dy.view({n_batch * n_channel, 1, -1}), + F::Conv1dFuncOptions().groups(n_batch * n_channel)) + .view({n_batch, n_channel, -1}) + .sum(0) + .flip(1); + } + + if (x.requires_grad()) { + dx = DifferentiableIIR::apply(dy.flip(2), a_coeffs_normalized).flip(2); + } + + return {dx, da}; + } +}; + +class DifferentiableFIR : public torch::autograd::Function { + public: + static torch::Tensor forward( + torch::autograd::AutogradContext* ctx, + const torch::Tensor& waveform, + const torch::Tensor& b_coeffs) { + int64_t n_order = b_coeffs.size(1); + int64_t n_channel = b_coeffs.size(0); + + namespace F = torch::nn::functional; + auto b_coeff_flipped = b_coeffs.flip(1).contiguous(); + auto padded_waveform = + F::pad(waveform, F::PadFuncOptions({n_order - 1, 0})); + + auto output = F::conv1d( + padded_waveform, + b_coeff_flipped.unsqueeze(1), + F::Conv1dFuncOptions().groups(n_channel)); + + ctx->save_for_backward({waveform, b_coeffs, output}); + return output; + } + + static torch::autograd::tensor_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::tensor_list 
grad_outputs) { + auto saved = ctx->get_saved_variables(); + auto x = saved[0]; + auto b_coeffs = saved[1]; + auto y = saved[2]; + + int64_t n_batch = x.size(0); + int64_t n_channel = x.size(1); + int64_t n_order = b_coeffs.size(1); + + auto dx = torch::Tensor(); + auto db = torch::Tensor(); + auto dy = grad_outputs[0]; + + namespace F = torch::nn::functional; + + if (b_coeffs.requires_grad()) { + db = F::conv1d( + F::pad(x, F::PadFuncOptions({n_order - 1, 0})) + .view({1, n_batch * n_channel, -1}), + dy.view({n_batch * n_channel, 1, -1}), + F::Conv1dFuncOptions().groups(n_batch * n_channel)) + .view({n_batch, n_channel, -1}) + .sum(0) + .flip(1); + } + + if (x.requires_grad()) { + dx = F::conv1d( + F::pad(dy, F::PadFuncOptions({0, n_order - 1})), + b_coeffs.unsqueeze(1), + F::Conv1dFuncOptions().groups(n_channel)); + } + + return {dx, db}; + } +}; + +torch::Tensor lfilter_core( + const torch::Tensor& waveform, + const torch::Tensor& a_coeffs, + const torch::Tensor& b_coeffs) { + TORCH_CHECK(waveform.device() == a_coeffs.device()); + TORCH_CHECK(b_coeffs.device() == a_coeffs.device()); + TORCH_CHECK(a_coeffs.sizes() == b_coeffs.sizes()); + + TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 3); + TORCH_INTERNAL_ASSERT(a_coeffs.sizes().size() == 2); + TORCH_INTERNAL_ASSERT(a_coeffs.size(0) == waveform.size(1)); + + int64_t n_order = b_coeffs.size(1); + + TORCH_INTERNAL_ASSERT(n_order > 0); + + auto filtered_waveform = DifferentiableFIR::apply( + waveform, + b_coeffs / + a_coeffs.index( + {torch::indexing::Slice(), torch::indexing::Slice(0, 1)})); + + auto output = DifferentiableIIR::apply( + filtered_waveform, + a_coeffs / + a_coeffs.index( + {torch::indexing::Slice(), torch::indexing::Slice(0, 1)})); + return output; +} + +} // namespace + +// Note: We want to avoid using "catch-all" kernel. +// The following registration should be replaced with CPU specific registration. 
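+// Summary of the composite op above: lfilter_core normalizes both coefficient
+// sets by a[:, 0:1], applies the FIR stage (DifferentiableFIR) as a grouped
+// conv1d over the flipped b coefficients, and then runs the IIR stage
+// (DifferentiableIIR), whose recursion is roughly
+//   y[n] = x[n] - sum_{k>=1} a[k] * y[n - k]   (with a[0] normalized to 1),
+// using cpu_lfilter_core_loop on CPU and the generic indexing-based loop on
+// other devices.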
+TORCH_LIBRARY_FRAGMENT(torchaudio, m) { + m.def("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop); +} + +TORCH_LIBRARY(torchaudio, m) { + m.def( + "torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) { + m.impl("torchaudio::_lfilter", lfilter_core); +} diff --git a/torchaudio/csrc/overdrive.cpp b/torchaudio/csrc/overdrive.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4954271e411d029134f0d72572bee2008b698310 --- /dev/null +++ b/torchaudio/csrc/overdrive.cpp @@ -0,0 +1,52 @@ +#include +#include + +namespace { + +template +void overdrive_cpu_kernel( + at::TensorAccessor waveform_accessor, + at::TensorAccessor temp_accessor, + at::TensorAccessor last_in_accessor, + at::TensorAccessor last_out_accessor, + at::TensorAccessor output_waveform_accessor) { + int64_t n_frames = waveform_accessor.size(1); + int64_t n_channels = waveform_accessor.size(0); + + at::parallel_for(0, n_channels, 1, [&](int64_t begin, int64_t end) { + for (int64_t i_channel = begin; i_channel < end; ++i_channel) { + for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) { + last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] - + last_in_accessor[i_channel] + 0.995 * last_out_accessor[i_channel]; + last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame]; + output_waveform_accessor[i_channel][i_frame] = + waveform_accessor[i_channel][i_frame] * 0.5 + + last_out_accessor[i_channel] * 0.75; + } + } + }); +} + +void overdrive_core_loop_cpu( + at::Tensor& waveform, + at::Tensor& temp, + at::Tensor& last_in, + at::Tensor& last_out, + at::Tensor& output_waveform) { + AT_DISPATCH_FLOATING_TYPES(waveform.scalar_type(), "overdrive_cpu", ([&] { + overdrive_cpu_kernel( + waveform.accessor(), + temp.accessor(), + last_in.accessor(), + last_out.accessor(), + output_waveform.accessor()); + })); +} + +} // namespace + +// Note: We want to avoid using "catch-all" kernel. +// The following registration should be replaced with CPU specific registration. 
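+// For reference, the state update applied per channel by overdrive_cpu_kernel
+// above is:
+//   last_out = temp[n] - last_in + 0.995 * last_out
+//   last_in  = temp[n]
+//   out[n]   = 0.5 * waveform[n] + 0.75 * last_out
+// Channels are independent, which is why at::parallel_for splits the work over
+// channels only.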
+TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+  m.def("torchaudio::_overdrive_core_loop", &overdrive_core_loop_cpu);
+}
diff --git a/torchaudio/csrc/pybind/pybind.cpp b/torchaudio/csrc/pybind/pybind.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..583ee415155d6dbbe174ae61162198a7b6541334
--- /dev/null
+++ b/torchaudio/csrc/pybind/pybind.cpp
@@ -0,0 +1,27 @@
+#include
+
+#ifdef INCLUDE_SOX
+#include
+#include
+#endif
+
+PYBIND11_MODULE(_torchaudio, m) {
+#ifdef INCLUDE_SOX
+  m.def(
+      "get_info_fileobj",
+      &torchaudio::sox_io::get_info_fileobj,
+      "Get metadata of audio in file object.");
+  m.def(
+      "load_audio_fileobj",
+      &torchaudio::sox_io::load_audio_fileobj,
+      "Load audio from file object.");
+  m.def(
+      "save_audio_fileobj",
+      &torchaudio::sox_io::save_audio_fileobj,
+      "Save audio to file obj.");
+  m.def(
+      "apply_effects_fileobj",
+      &torchaudio::sox_effects::apply_effects_fileobj,
+      "Decode audio data from file-like obj and apply effects.");
+#endif
+}
diff --git a/torchaudio/csrc/pybind/sox/effects.cpp b/torchaudio/csrc/pybind/sox/effects.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..43c3b08d2717055a90a9e3ba0e34a43fd8a8f3fc
--- /dev/null
+++ b/torchaudio/csrc/pybind/sox/effects.cpp
@@ -0,0 +1,117 @@
+#include
+#include
+#include
+
+using namespace torchaudio::sox_utils;
+
+namespace torchaudio::sox_effects {
+
+// Streaming decoding over a file-like object is tricky because libsox operates
+// on a FILE pointer. The following is what the `sox` and `play` commands do:
+// - file input -> FILE pointer
+// - URL input -> call wget in a subprocess and pipe the data -> FILE pointer
+// - stdin -> FILE pointer
+//
+// We want to, instead, fetch byte strings chunk by chunk, consume them, and
+// discard them.
+//
+// Here is the approach:
+// 1. Initialize sox_format_t using sox_open_mem_read, providing the initial
+//    chunk of byte string.
+//    This will perform header-based format detection, if necessary, then fill
+//    the metadata of sox_format_t. Internally, sox_open_mem_read uses
+//    fmemopen, which returns a FILE* that points to the buffer of the provided
+//    byte string.
+// 2. Each time sox reads a chunk from the FILE*, we update the underlying
+//    buffer so that it starts with the unseen data, and append the new data
+//    read from the given fileobj. This tricks libsox into behaving as if it
+//    keeps reading from the FILE* continuously.
+//    For Step 2, see the `fileobj_input_drain` function in effects_chain.cpp.
+auto apply_effects_fileobj(
+    py::object fileobj,
+    const std::vector>& effects,
+    c10::optional normalize,
+    c10::optional channels_first,
+    c10::optional format) -> std::tuple {
+  // Prepare the buffer used throughout the lifecycle of SoxEffectChain.
+  //
+  // For certain formats (such as FLAC), libsox keeps reading the content
+  // during initialization until it reaches EOF, even when the header has been
+  // properly parsed. (Making the buffer size 8192, which is way bigger than
+  // the header, resulted in libsox consuming all the buffer content at the
+  // time it opens the file.) Therefore the buffer always has to contain valid
+  // data, except after EOF. We default to `sox_get_globals()->bufsiz`* for the
+  // buffer size and first check whether there is enough data to fill the
+  // buffer. `read_fileobj` repeatedly calls the `read` method until it
+  // receives the requested number of bytes or reaches EOF. If we receive fewer
+  // bytes than requested, that means the whole audio data has been fetched.
+  //
+  // * This can be changed with `torchaudio.utils.sox_utils.set_buffer_size`.
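+  // The only thing assumed of `fileobj` here is a `read(n)` method that
+  // returns at most `n` bytes and an empty byte string at EOF; `read_fileobj`
+  // keeps calling it until the requested size is reached or EOF is hit.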
+ const auto capacity = [&]() { + // NOTE: + // Use the abstraction provided by `libtorchaudio` to access the global + // config defined by libsox. Directly using `sox_get_globals` function will + // end up retrieving the static variable defined in `_torchaudio`, which is + // not correct. + const auto bufsiz = get_buffer_size(); + const int64_t kDefaultCapacityInBytes = 256; + return (bufsiz > kDefaultCapacityInBytes) ? bufsiz + : kDefaultCapacityInBytes; + }(); + std::string buffer(capacity, '\0'); + auto* in_buf = const_cast(buffer.data()); + auto num_read = read_fileobj(&fileobj, capacity, in_buf); + // If the file is shorter than 256, then libsox cannot read the header. + auto in_buffer_size = (num_read > 256) ? num_read : 256; + + // Open file (this starts reading the header) + // When opening a file there are two functions that can touches FILE*. + // * `auto_detect_format` + // https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L43 + // * `startread` handler of detected format. + // https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L574 + // To see the handler of a particular format, go to + // https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/.c + // For example, voribs can be found + // https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/vorbis.c#L97-L158 + SoxFormat sf(sox_open_mem_read( + in_buf, + in_buffer_size, + /*signal=*/nullptr, + /*encoding=*/nullptr, + /*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); + + // In case of streamed data, length can be 0 + validate_input_memfile(sf); + + // Prepare output buffer + std::vector out_buffer; + out_buffer.reserve(sf->signal.length); + + // Create and run SoxEffectsChain + const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision); + torchaudio::sox_effects_chain::SoxEffectsChainPyBind chain( + /*input_encoding=*/sf->encoding, + /*output_encoding=*/get_tensor_encodinginfo(dtype)); + chain.addInputFileObj(sf, in_buf, in_buffer_size, &fileobj); + for (const auto& effect : effects) { + chain.addEffect(effect); + } + chain.addOutputBuffer(&out_buffer); + chain.run(); + + // Create tensor from buffer + bool channels_first_ = channels_first.value_or(true); + auto tensor = convert_to_tensor( + /*buffer=*/out_buffer.data(), + /*num_samples=*/out_buffer.size(), + /*num_channels=*/chain.getOutputNumChannels(), + dtype, + normalize.value_or(true), + channels_first_); + + return std::make_tuple( + tensor, static_cast(chain.getOutputSampleRate())); +} + +} // namespace torchaudio::sox_effects diff --git a/torchaudio/csrc/pybind/sox/effects.h b/torchaudio/csrc/pybind/sox/effects.h new file mode 100644 index 0000000000000000000000000000000000000000..5406ba24c6d9b8640399a09e39b821d1160bee3c --- /dev/null +++ b/torchaudio/csrc/pybind/sox/effects.h @@ -0,0 +1,17 @@ +#ifndef TORCHAUDIO_PYBIND_SOX_EFFECTS_H +#define TORCHAUDIO_PYBIND_SOX_EFFECTS_H + +#include + +namespace torchaudio::sox_effects { + +auto apply_effects_fileobj( + py::object fileobj, + const std::vector>& effects, + c10::optional normalize, + c10::optional channels_first, + c10::optional format) -> std::tuple; + +} // namespace torchaudio::sox_effects + +#endif diff --git a/torchaudio/csrc/pybind/sox/effects_chain.cpp b/torchaudio/csrc/pybind/sox/effects_chain.cpp new file mode 100644 index 0000000000000000000000000000000000000000..12f8d31d6650db4bb98ade63869647a60a98eaed --- /dev/null +++ 
b/torchaudio/csrc/pybind/sox/effects_chain.cpp @@ -0,0 +1,227 @@ +#include +#include +#include + +using namespace torchaudio::sox_utils; + +namespace torchaudio::sox_effects_chain { + +namespace { + +/// helper classes for passing file-like object to SoxEffectChain +struct FileObjInputPriv { + sox_format_t* sf; + py::object* fileobj; + bool eof_reached; + char* buffer; + uint64_t buffer_size; +}; + +struct FileObjOutputPriv { + sox_format_t* sf; + py::object* fileobj; + char** buffer; + size_t* buffer_size; +}; + +/// Callback function to feed byte string +/// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/sox.h#L1268-L1278 +auto fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) + -> int { + auto priv = static_cast(effp->priv); + auto sf = priv->sf; + auto buffer = priv->buffer; + + // 1. Refresh the buffer + // + // NOTE: + // Since the underlying FILE* was opened with `fmemopen`, the only way + // libsox detect EOF is reaching the end of the buffer. (null byte won't + // help) Therefore we need to align the content at the end of buffer, + // otherwise, libsox will keep reading the content beyond intended length. + // + // Before: + // + // |<-------consumed------>|<---remaining--->| + // |***********************|-----------------| + // ^ ftell + // + // After: + // + // |<-offset->|<---remaining--->|<-new data->| + // |**********|-----------------|++++++++++++| + // ^ ftell + + // NOTE: + // Do not use `sf->tell_off` here. Presumably, `tell_off` and `fseek` are + // supposed to be in sync, but there are cases (Vorbis) they are not + // in sync and `tell_off` has seemingly uninitialized value, which + // leads num_remain to be negative and cause segmentation fault + // in `memmove`. + const auto tell = ftell((FILE*)sf->fp); + if (tell < 0) { + throw std::runtime_error("Internal Error: ftell failed."); + } + const auto num_consumed = static_cast(tell); + if (num_consumed > priv->buffer_size) { + throw std::runtime_error("Internal Error: buffer overrun."); + } + + const auto num_remain = priv->buffer_size - num_consumed; + + // 1.1. Fetch the data to see if there is data to fill the buffer + size_t num_refill = 0; + std::string chunk(num_consumed, '\0'); + if (num_consumed && !priv->eof_reached) { + num_refill = read_fileobj( + priv->fileobj, num_consumed, const_cast(chunk.data())); + if (num_refill < num_consumed) { + priv->eof_reached = true; + } + } + const auto offset = num_consumed - num_refill; + + // 1.2. Move the unconsumed data towards the beginning of buffer. + if (num_remain) { + auto src = static_cast(buffer + num_consumed); + auto dst = static_cast(buffer + offset); + memmove(dst, src, num_remain); + } + + // 1.3. Refill the remaining buffer. + if (num_refill) { + auto src = static_cast(const_cast(chunk.c_str())); + auto dst = buffer + offset + num_remain; + memcpy(dst, src, num_refill); + } + + // 1.4. Set the file pointer to the new offset + sf->tell_off = offset; + fseek((FILE*)sf->fp, offset, SEEK_SET); + + // 2. 
Perform decoding operation + // The following part is practically same as "input" effect + // https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/input.c#L30-L48 + + // Ensure that it's a multiple of the number of channels + *osamp -= *osamp % effp->out_signal.channels; + + // Read up to *osamp samples into obuf; + // store the actual number read back to *osamp + *osamp = sox_read(sf, obuf, *osamp); + + // Decoding is finished when fileobject is exhausted and sox can no longer + // decode a sample. + return (priv->eof_reached && !*osamp) ? SOX_EOF : SOX_SUCCESS; +} + +auto fileobj_output_flow( + sox_effect_t* effp, + sox_sample_t const* ibuf, + sox_sample_t* obuf LSX_UNUSED, + size_t* isamp, + size_t* osamp) -> int { + *osamp = 0; + if (*isamp) { + auto priv = static_cast(effp->priv); + auto sf = priv->sf; + auto fp = static_cast(sf->fp); + auto fileobj = priv->fileobj; + auto buffer = priv->buffer; + + // Encode chunk + auto num_samples_written = sox_write(sf, ibuf, *isamp); + fflush(fp); + + // Copy the encoded chunk to python object. + fileobj->attr("write")(py::bytes(*buffer, ftell(fp))); + + // Reset FILE* + sf->tell_off = 0; + fseek(fp, 0, SEEK_SET); + + if (num_samples_written != *isamp) { + if (sf->sox_errno) { + std::ostringstream stream; + stream << sf->sox_errstr << " " << sox_strerror(sf->sox_errno) << " " + << sf->filename; + throw std::runtime_error(stream.str()); + } + return SOX_EOF; + } + } + return SOX_SUCCESS; +} + +auto get_fileobj_input_handler() -> sox_effect_handler_t* { + static sox_effect_handler_t handler{ + /*name=*/"input_fileobj_object", + /*usage=*/nullptr, + /*flags=*/SOX_EFF_MCHAN, + /*getopts=*/nullptr, + /*start=*/nullptr, + /*flow=*/nullptr, + /*drain=*/fileobj_input_drain, + /*stop=*/nullptr, + /*kill=*/nullptr, + /*priv_size=*/sizeof(FileObjInputPriv)}; + return &handler; +} + +auto get_fileobj_output_handler() -> sox_effect_handler_t* { + static sox_effect_handler_t handler{ + /*name=*/"output_fileobj_object", + /*usage=*/nullptr, + /*flags=*/SOX_EFF_MCHAN, + /*getopts=*/nullptr, + /*start=*/nullptr, + /*flow=*/fileobj_output_flow, + /*drain=*/nullptr, + /*stop=*/nullptr, + /*kill=*/nullptr, + /*priv_size=*/sizeof(FileObjOutputPriv)}; + return &handler; +} + +} // namespace + +void SoxEffectsChainPyBind::addInputFileObj( + sox_format_t* sf, + char* buffer, + uint64_t buffer_size, + py::object* fileobj) { + in_sig_ = sf->signal; + interm_sig_ = in_sig_; + + SoxEffect e(sox_create_effect(get_fileobj_input_handler())); + auto priv = static_cast(e->priv); + priv->sf = sf; + priv->fileobj = fileobj; + priv->eof_reached = false; + priv->buffer = buffer; + priv->buffer_size = buffer_size; + if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { + throw std::runtime_error( + "Internal Error: Failed to add effect: input fileobj"); + } +} + +void SoxEffectsChainPyBind::addOutputFileObj( + sox_format_t* sf, + char** buffer, + size_t* buffer_size, + py::object* fileobj) { + out_sig_ = sf->signal; + SoxEffect e(sox_create_effect(get_fileobj_output_handler())); + auto priv = static_cast(e->priv); + priv->sf = sf; + priv->fileobj = fileobj; + priv->buffer = buffer; + priv->buffer_size = buffer_size; + if (sox_add_effect(sec_, e, &interm_sig_, &out_sig_) != SOX_SUCCESS) { + throw std::runtime_error( + "Internal Error: Failed to add effect: output fileobj"); + } +} + +} // namespace torchaudio::sox_effects_chain diff --git a/torchaudio/csrc/pybind/sox/effects_chain.h b/torchaudio/csrc/pybind/sox/effects_chain.h new file 
mode 100644 index 0000000000000000000000000000000000000000..7e3c0267dbd8d313b81ed916ba43be7bef3ac034 --- /dev/null +++ b/torchaudio/csrc/pybind/sox/effects_chain.h @@ -0,0 +1,28 @@ +#ifndef TORCHAUDIO_PYBIND_SOX_EFFECTS_CHAIN_H +#define TORCHAUDIO_PYBIND_SOX_EFFECTS_CHAIN_H + +#include +#include + +namespace torchaudio::sox_effects_chain { + +class SoxEffectsChainPyBind : public SoxEffectsChain { + using SoxEffectsChain::SoxEffectsChain; + + public: + void addInputFileObj( + sox_format_t* sf, + char* buffer, + uint64_t buffer_size, + py::object* fileobj); + + void addOutputFileObj( + sox_format_t* sf, + char** buffer, + size_t* buffer_size, + py::object* fileobj); +}; + +} // namespace torchaudio::sox_effects_chain + +#endif diff --git a/torchaudio/csrc/pybind/sox/io.cpp b/torchaudio/csrc/pybind/sox/io.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2935575b38e79aa97da40b10addc7d7af9373595 --- /dev/null +++ b/torchaudio/csrc/pybind/sox/io.cpp @@ -0,0 +1,190 @@ +#include +#include +#include +#include +#include +#include + +#include + +using namespace torchaudio::sox_utils; + +namespace torchaudio::sox_io { + +auto get_info_fileobj(py::object fileobj, c10::optional format) + -> std::tuple { + // Prepare in-memory file object + // When libsox opens a file, it also reads the header. + // When opening a file there are two functions that might touch FILE* (and the + // underlying buffer). + // * `auto_detect_format` + // https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L43 + // * `startread` handler of detected format. + // https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L574 + // To see the handler of a particular format, go to + // https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/.c + // For example, voribs can be found + // https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/vorbis.c#L97-L158 + // + // `auto_detect_format` function only requires 256 bytes, but format-dependent + // `startread` handler might require more data. In case of vorbis, the size of + // header is unbounded, but typically 4kB maximum. + // + // "The header size is unbounded, although for streaming a rule-of-thumb of + // 4kB or less is recommended (and Xiph.Org's Vorbis encoder follows this + // suggestion)." + // + // See: + // https://xiph.org/vorbis/doc/Vorbis_I_spec.html + const auto capacity = [&]() { + // NOTE: + // Use the abstraction provided by `libtorchaudio` to access the global + // config defined by libsox. Directly using `sox_get_globals` function will + // end up retrieving the static variable defined in `_torchaudio`, which is + // not correct. + const auto bufsiz = get_buffer_size(); + const int64_t kDefaultCapacityInBytes = 4096; + return (bufsiz > kDefaultCapacityInBytes) ? bufsiz + : kDefaultCapacityInBytes; + }(); + std::string buffer(capacity, '\0'); + auto* buf = const_cast(buffer.data()); + auto num_read = read_fileobj(&fileobj, capacity, buf); + // If the file is shorter than 256, then libsox cannot read the header. + auto buf_size = (num_read > 256) ? num_read : 256; + + SoxFormat sf(sox_open_mem_read( + buf, + buf_size, + /*signal=*/nullptr, + /*encoding=*/nullptr, + /*filetype=*/format.has_value() ? 
format.value().c_str() : nullptr)); + + // In case of streamed data, length can be 0 + validate_input_memfile(sf); + + return std::make_tuple( + static_cast(sf->signal.rate), + static_cast(sf->signal.length / sf->signal.channels), + static_cast(sf->signal.channels), + static_cast(sf->encoding.bits_per_sample), + get_encoding(sf->encoding.encoding)); +} + +auto load_audio_fileobj( + py::object fileobj, + c10::optional frame_offset, + c10::optional num_frames, + c10::optional normalize, + c10::optional channels_first, + c10::optional format) -> std::tuple { + auto effects = get_effects(frame_offset, num_frames); + return torchaudio::sox_effects::apply_effects_fileobj( + std::move(fileobj), + effects, + normalize, + channels_first, + std::move(format)); +} + +namespace { + +// helper class to automatically release buffer, to be used by +// save_audio_fileobj +struct AutoReleaseBuffer { + char* ptr; + size_t size; + + AutoReleaseBuffer() : ptr(nullptr), size(0) {} + AutoReleaseBuffer(const AutoReleaseBuffer& other) = delete; + AutoReleaseBuffer(AutoReleaseBuffer&& other) = delete; + auto operator=(const AutoReleaseBuffer& other) -> AutoReleaseBuffer& = delete; + auto operator=(AutoReleaseBuffer&& other) -> AutoReleaseBuffer& = delete; + ~AutoReleaseBuffer() { + if (ptr) { + free(ptr); + } + } +}; + +} // namespace + +void save_audio_fileobj( + py::object fileobj, + torch::Tensor tensor, + int64_t sample_rate, + bool channels_first, + c10::optional compression, + c10::optional format, + c10::optional encoding, + c10::optional bits_per_sample) { + validate_input_tensor(tensor); + + if (!format.has_value()) { + throw std::runtime_error( + "`format` is required when saving to file object."); + } + const auto filetype = format.value(); + + if (filetype == "amr-nb") { + const auto num_channels = tensor.size(channels_first ? 0 : 1); + if (num_channels != 1) { + throw std::runtime_error( + "amr-nb format only supports single channel audio."); + } + } else if (filetype == "htk") { + const auto num_channels = tensor.size(channels_first ? 0 : 1); + if (num_channels != 1) { + throw std::runtime_error( + "htk format only supports single channel audio."); + } + } else if (filetype == "gsm") { + const auto num_channels = tensor.size(channels_first ? 
0 : 1); + if (num_channels != 1) { + throw std::runtime_error( + "gsm format only supports single channel audio."); + } + if (sample_rate != 8000) { + throw std::runtime_error( + "gsm format only supports a sampling rate of 8kHz."); + } + } + const auto signal_info = + get_signalinfo(&tensor, sample_rate, filetype, channels_first); + const auto encoding_info = get_encodinginfo_for_save( + filetype, + tensor.dtype(), + compression, + std::move(encoding), + bits_per_sample); + + AutoReleaseBuffer buffer; + + SoxFormat sf(sox_open_memstream_write( + &buffer.ptr, + &buffer.size, + &signal_info, + &encoding_info, + filetype.c_str(), + /*oob=*/nullptr)); + + if (static_cast(sf) == nullptr) { + throw std::runtime_error( + "Error saving audio file: failed to open memory stream."); + } + + torchaudio::sox_effects_chain::SoxEffectsChainPyBind chain( + /*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()), + /*output_encoding=*/sf->encoding); + chain.addInputTensor(&tensor, sample_rate, channels_first); + chain.addOutputFileObj(sf, &buffer.ptr, &buffer.size, &fileobj); + chain.run(); + + // Closing the sox_format_t is necessary for flushing the last chunk to the + // buffer + sf.close(); + + fileobj.attr("write")(py::bytes(buffer.ptr, buffer.size)); +} + +} // namespace torchaudio::sox_io diff --git a/torchaudio/csrc/pybind/sox/io.h b/torchaudio/csrc/pybind/sox/io.h new file mode 100644 index 0000000000000000000000000000000000000000..4059bdc356960f3e01c2c2dba0e79cded401caa8 --- /dev/null +++ b/torchaudio/csrc/pybind/sox/io.h @@ -0,0 +1,31 @@ +#ifndef TORCHAUDIO_PYBIND_SOX_IO_H +#define TORCHAUDIO_PYBIND_SOX_IO_H + +#include + +namespace torchaudio::sox_io { + +auto get_info_fileobj(py::object fileobj, c10::optional format) + -> std::tuple; + +auto load_audio_fileobj( + py::object fileobj, + c10::optional frame_offset, + c10::optional num_frames, + c10::optional normalize, + c10::optional channels_first, + c10::optional format) -> std::tuple; + +void save_audio_fileobj( + py::object fileobj, + torch::Tensor tensor, + int64_t sample_rate, + bool channels_first, + c10::optional compression, + c10::optional format, + c10::optional encoding, + c10::optional bits_per_sample); + +} // namespace torchaudio::sox_io + +#endif diff --git a/torchaudio/csrc/pybind/sox/utils.cpp b/torchaudio/csrc/pybind/sox/utils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..39669b02506a00512268ecb2469c8ec555689269 --- /dev/null +++ b/torchaudio/csrc/pybind/sox/utils.cpp @@ -0,0 +1,31 @@ +#include + +namespace torchaudio::sox_utils { + +auto read_fileobj(py::object* fileobj, const uint64_t size, char* buffer) + -> uint64_t { + uint64_t num_read = 0; + while (num_read < size) { + auto request = size - num_read; + auto chunk = static_cast( + static_cast(fileobj->attr("read")(request))); + auto chunk_len = chunk.length(); + if (chunk_len == 0) { + break; + } + if (chunk_len > request) { + std::ostringstream message; + message + << "Requested up to " << request << " bytes but, " + << "received " << chunk_len << " bytes. 
" + << "The given object does not confirm to read protocol of file object."; + throw std::runtime_error(message.str()); + } + memcpy(buffer, chunk.data(), chunk_len); + buffer += chunk_len; + num_read += chunk_len; + } + return num_read; +} + +} // namespace torchaudio::sox_utils diff --git a/torchaudio/csrc/pybind/sox/utils.h b/torchaudio/csrc/pybind/sox/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..f13944f5f1638efbbe7b011d61882f87a1cbd3a5 --- /dev/null +++ b/torchaudio/csrc/pybind/sox/utils.h @@ -0,0 +1,12 @@ +#ifndef TORCHAUDIO_PYBIND_SOX_UTILS_H +#define TORCHAUDIO_PYBIND_SOX_UTILS_H + +#include + +namespace torchaudio::sox_utils { + +auto read_fileobj(py::object* fileobj, uint64_t size, char* buffer) -> uint64_t; + +} // namespace torchaudio::sox_utils + +#endif diff --git a/torchaudio/csrc/rnnt/autograd.cpp b/torchaudio/csrc/rnnt/autograd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4fa6ba7996ce76eefa587a50894ae14cadef825b --- /dev/null +++ b/torchaudio/csrc/rnnt/autograd.cpp @@ -0,0 +1,56 @@ +#include +#include + +namespace torchaudio { +namespace rnnt { + +class RNNTLossFunction : public torch::autograd::Function { + public: + static torch::autograd::tensor_list forward( + torch::autograd::AutogradContext* ctx, + torch::Tensor& logits, + const torch::Tensor& targets, + const torch::Tensor& logit_lengths, + const torch::Tensor& target_lengths, + int64_t blank, + double clamp) { + torch::Tensor undef; + auto result = + rnnt_loss(logits, targets, logit_lengths, target_lengths, blank, clamp); + auto costs = std::get<0>(result); + auto grads = std::get<1>(result).value_or(undef); + ctx->save_for_backward({grads}); + return {costs, grads}; + } + + static torch::autograd::tensor_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::tensor_list grad_outputs) { + auto saved = ctx->get_saved_variables(); + auto grad = saved[0]; + auto grad_out = grad_outputs[0].view({-1, 1, 1, 1}); + auto result = grad * grad_out; + torch::Tensor undef; + return {result, undef, undef, undef, undef, undef, undef, undef}; + } +}; + +std::tuple> rnnt_loss_autograd( + torch::Tensor& logits, + const torch::Tensor& targets, + const torch::Tensor& logit_lengths, + const torch::Tensor& target_lengths, + int64_t blank, + double clamp) { + at::AutoDispatchBelowADInplaceOrView guard; + auto results = RNNTLossFunction::apply( + logits, targets, logit_lengths, target_lengths, blank, clamp); + return std::make_tuple(results[0], results[1]); +} + +TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) { + m.impl("rnnt_loss", rnnt_loss_autograd); +} + +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/compute.cpp b/torchaudio/csrc/rnnt/compute.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9c3bf84a055e1741b8f41ba70a16a7f505dae3aa --- /dev/null +++ b/torchaudio/csrc/rnnt/compute.cpp @@ -0,0 +1,25 @@ +#include +#include + +std::tuple> rnnt_loss( + torch::Tensor& logits, + const torch::Tensor& targets, + const torch::Tensor& logit_lengths, + const torch::Tensor& target_lengths, + int64_t blank, + double clamp) { + static auto op = torch::Dispatcher::singleton() + .findSchemaOrThrow("torchaudio::rnnt_loss", "") + .typed(); + return op.call(logits, targets, logit_lengths, target_lengths, blank, clamp); +} + +TORCH_LIBRARY_FRAGMENT(torchaudio, m) { + m.def( + "rnnt_loss(Tensor logits," + "Tensor targets," + "Tensor logit_lengths," + "Tensor target_lengths," + "int blank," + "float clamp) -> (Tensor, 
Tensor?)"); +} diff --git a/torchaudio/csrc/rnnt/compute.h b/torchaudio/csrc/rnnt/compute.h new file mode 100644 index 0000000000000000000000000000000000000000..eea16a5feed8c7a764cda9b27e336f8887c951ff --- /dev/null +++ b/torchaudio/csrc/rnnt/compute.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +std::tuple> rnnt_loss( + torch::Tensor& logits, + const torch::Tensor& targets, + const torch::Tensor& logit_lengths, + const torch::Tensor& target_lengths, + int64_t blank, + double clamp); diff --git a/torchaudio/csrc/rnnt/compute_alphas.cpp b/torchaudio/csrc/rnnt/compute_alphas.cpp new file mode 100644 index 0000000000000000000000000000000000000000..adbcc1c8e7401d015527e2a7fbeb5161ff4ee94a --- /dev/null +++ b/torchaudio/csrc/rnnt/compute_alphas.cpp @@ -0,0 +1,11 @@ +#include + +TORCH_LIBRARY_FRAGMENT(torchaudio, m) { + m.def( + "rnnt_loss_alphas(Tensor logits," + "Tensor targets," + "Tensor logit_lengths," + "Tensor target_lengths," + "int blank," + "float clamp) -> Tensor"); +} diff --git a/torchaudio/csrc/rnnt/compute_betas.cpp b/torchaudio/csrc/rnnt/compute_betas.cpp new file mode 100644 index 0000000000000000000000000000000000000000..772883813770d74ebc39f2bc9ab59c602127b339 --- /dev/null +++ b/torchaudio/csrc/rnnt/compute_betas.cpp @@ -0,0 +1,11 @@ +#include + +TORCH_LIBRARY_FRAGMENT(torchaudio, m) { + m.def( + "rnnt_loss_betas(Tensor logits," + "Tensor targets," + "Tensor logit_lengths," + "Tensor target_lengths," + "int blank," + "float clamp) -> Tensor"); +} diff --git a/torchaudio/csrc/rnnt/cpu/compute.cpp b/torchaudio/csrc/rnnt/cpu/compute.cpp new file mode 100644 index 0000000000000000000000000000000000000000..088f68be528d503f9a4d3a6ce0fbfa2d8d40c270 --- /dev/null +++ b/torchaudio/csrc/rnnt/cpu/compute.cpp @@ -0,0 +1,148 @@ +#include +#include + +namespace torchaudio { +namespace rnnt { +namespace cpu { + +// Entry point into RNNT Loss +std::tuple> compute( + torch::Tensor& logits, + const torch::Tensor& targets, + const torch::Tensor& logit_lengths, + const torch::Tensor& target_lengths, + int64_t blank, + double clamp) { + TORCH_CHECK( + logits.device().type() == targets.device().type(), + "logits and targets must be on the same device"); + TORCH_CHECK( + logits.device().type() == logit_lengths.device().type(), + "logits and logit_lengths must be on the same device"); + TORCH_CHECK( + logits.device().type() == target_lengths.device().type(), + "logits and target_lengths must be on the same device"); + + TORCH_CHECK( + logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16, + "logits must be float32 or float16 (half) type"); + TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type"); + TORCH_CHECK( + logit_lengths.dtype() == torch::kInt32, + "logit_lengths must be int32 type"); + TORCH_CHECK( + target_lengths.dtype() == torch::kInt32, + "target_lengths must be int32 type"); + + TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous"); + TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); + TORCH_CHECK( + logit_lengths.is_contiguous(), "logit_lengths must be contiguous"); + TORCH_CHECK( + target_lengths.is_contiguous(), "target_lengths must be contiguous"); + + TORCH_CHECK( + logits.dim() == 4, "logits must be 4-D (batch, time, target, class)"); + TORCH_CHECK( + targets.dim() == 2, "targets must be 2-D (batch, max target length)"); + TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D"); + TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D"); + + TORCH_CHECK( + logit_lengths.size(0) == 
logits.size(0), + "batch dimension mismatch between logits and logit_lengths"); + TORCH_CHECK( + target_lengths.size(0) == logits.size(0), + "batch dimension mismatch between logits and target_lengths"); + TORCH_CHECK( + targets.size(0) == logits.size(0), + "batch dimension mismatch between logits and targets"); + + TORCH_CHECK( + blank >= 0 && blank < logits.size(-1), + "blank must be within [0, logits.shape[-1])"); + + TORCH_CHECK( + logits.size(1) == at::max(logit_lengths).item().toInt(), + "input length mismatch"); + TORCH_CHECK( + logits.size(2) == at::max(target_lengths).item().toInt() + 1, + "output length mismatch"); + TORCH_CHECK( + targets.size(1) == at::max(target_lengths).item().toInt(), + "target length mismatch"); + + Options options; + options.batchSize_ = logit_lengths.size(0); + options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); + options.maxSrcLen_ = logits.size(1); + options.maxTgtLen_ = logits.size(2); + options.numTargets_ = logits.size(3); + options.blank_ = blank; + options.clamp_ = clamp; + + CHECK_EQ(logits.device().type(), torch::DeviceType::CPU); + options.device_ = CPU; + + torch::Tensor costs = torch::empty( + options.batchSize_ * options.nHypos_, + torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + c10::optional gradients = torch::zeros_like(logits); + + torch::Tensor int_workspace = torch::empty( + IntWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Int)); + + torch::Tensor float_workspace = torch::empty( + DtypeWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Float)); + + Workspace workspace( + /*options=*/options, + /*dtype_data=*/float_workspace.data_ptr(), + /*dtype_size=*/float_workspace.numel(), + /*int_data=*/int_workspace.data_ptr(), + /*int_size=*/int_workspace.numel()); + + switch (logits.scalar_type()) { + case torch::ScalarType::Float: { + Compute( + /*workspace=*/workspace, + /*logits=*/logits.data_ptr(), + /*targets=*/targets.data_ptr(), + /*logit_lengths=*/logit_lengths.data_ptr(), + /*target_lengths=*/target_lengths.data_ptr(), + /*costs=*/costs.data_ptr(), + /*gradients=*/gradients->data_ptr()); + break; + } + case torch::ScalarType::Half: { + Compute( + /*workspace=*/workspace, + /*logits=*/logits.data_ptr(), + /*targets=*/targets.data_ptr(), + /*logit_lengths=*/logit_lengths.data_ptr(), + /*target_lengths=*/target_lengths.data_ptr(), + /*costs=*/costs.data_ptr(), + /*gradients=*/gradients->data_ptr()); + break; + } + default: { + break; + } + }; + + return std::make_tuple(costs, gradients); +} + +TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("rnnt_loss", &compute); +} + +} // namespace cpu +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/cpu/compute_alphas.cpp b/torchaudio/csrc/rnnt/cpu/compute_alphas.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6a1fb0f881d75941b0f10cc26ee2f0c55b797eed --- /dev/null +++ b/torchaudio/csrc/rnnt/cpu/compute_alphas.cpp @@ -0,0 +1,70 @@ +#include +#include + +namespace torchaudio { +namespace rnnt { +namespace cpu { + +torch::Tensor compute_alphas( + const torch::Tensor& logits, + const torch::Tensor& targets, + const torch::Tensor& logit_lengths, + const torch::Tensor& target_lengths, + int64_t blank, + double clamp) { + Options options; + options.batchSize_ = logit_lengths.size(0); + options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); + 
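+  // Presumably nHypos_ is the number of hypotheses per logit sequence, hence
+  // the ratio above; target_lengths is expected to hold a whole multiple of
+  // the batch size entries.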
options.maxSrcLen_ = logits.size(1); + options.maxTgtLen_ = logits.size(2); + options.numTargets_ = logits.size(3); + options.blank_ = blank; + options.clamp_ = clamp; + + CHECK_EQ(logits.device().type(), torch::DeviceType::CPU); + options.device_ = CPU; + + torch::Tensor alphas = torch::zeros( + {options.batchSize_ * options.nHypos_, + options.maxSrcLen_, + options.maxTgtLen_}, + torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + + torch::Tensor int_workspace = torch::empty( + IntWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Int)); + + torch::Tensor float_workspace = torch::empty( + DtypeWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Float)); + + Workspace workspace( + /*options=*/options, + /*dtype_data=*/float_workspace.data_ptr(), + /*dtype_size=*/float_workspace.numel(), + /*int_data=*/int_workspace.data_ptr(), + /*int_size=*/int_workspace.numel()); + + // Only support float, this is mainly to enable easy + // unit-testing + ComputeAlphas( + /*workspace=*/workspace, + /*logits=*/logits.data_ptr(), + /*targets=*/targets.data_ptr(), + /*logit_lengths=*/logit_lengths.data_ptr(), + /*target_lengths=*/target_lengths.data_ptr(), + /*alphas=*/alphas.data_ptr()); + return alphas; +} + +TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("rnnt_loss_alphas", &compute_alphas); +} + +} // namespace cpu +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/cpu/compute_betas.cpp b/torchaudio/csrc/rnnt/cpu/compute_betas.cpp new file mode 100644 index 0000000000000000000000000000000000000000..51e738d8b46bcb7986dd5cf76fa276c51e2a8de2 --- /dev/null +++ b/torchaudio/csrc/rnnt/cpu/compute_betas.cpp @@ -0,0 +1,75 @@ +#include +#include + +namespace torchaudio { +namespace rnnt { +namespace cpu { + +torch::Tensor compute_betas( + const torch::Tensor& logits, + const torch::Tensor& targets, + const torch::Tensor& logit_lengths, + const torch::Tensor& target_lengths, + int64_t blank, + double clamp) { + Options options; + options.batchSize_ = logit_lengths.size(0); + options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); + options.maxSrcLen_ = logits.size(1); + options.maxTgtLen_ = logits.size(2); + options.numTargets_ = logits.size(3); + options.blank_ = blank; + options.clamp_ = clamp; + + CHECK_EQ(logits.device().type(), torch::DeviceType::CPU); + options.device_ = CPU; + + torch::Tensor costs = torch::empty( + target_lengths.size(0), + torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + + torch::Tensor betas = torch::zeros( + {options.batchSize_ * options.nHypos_, + options.maxSrcLen_, + options.maxTgtLen_}, + torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + + torch::Tensor int_workspace = torch::empty( + IntWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Int)); + + torch::Tensor float_workspace = torch::empty( + DtypeWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Float)); + + Workspace workspace( + /*options=*/options, + /*dtype_data=*/float_workspace.data_ptr(), + /*dtype_size=*/float_workspace.numel(), + /*int_data=*/int_workspace.data_ptr(), + /*int_size=*/int_workspace.numel()); + + // Only support float, this is mainly to enable easy + // unit-testing + ComputeBetas( + /*workspace=*/workspace, 
+ /*logits=*/logits.data_ptr(), + /*targets=*/targets.data_ptr(), + /*logit_lengths=*/logit_lengths.data_ptr(), + /*target_lengths=*/target_lengths.data_ptr(), + /*costs=*/costs.data_ptr(), + /*betas=*/betas.data_ptr()); + return betas; +} + +TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("rnnt_loss_betas", &compute_betas); +} + +} // namespace cpu +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/cpu/cpu_kernels.h b/torchaudio/csrc/rnnt/cpu/cpu_kernels.h new file mode 100644 index 0000000000000000000000000000000000000000..468cb41887bc0b5a66dfc2cf8bbaaaebdd99e12c --- /dev/null +++ b/torchaudio/csrc/rnnt/cpu/cpu_kernels.h @@ -0,0 +1,498 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace torchaudio { +namespace rnnt { +namespace cpu { + +template +struct LogProbs { + DTYPE skip_; // blank. + DTYPE emit_; // target. + + LogProbs(DTYPE skip, DTYPE emit) : skip_(skip), emit_(emit) {} + + DTYPE& skip() { + return skip_; + } + DTYPE& emit() { + return emit_; + } + + const DTYPE& skip() const { + return skip_; + } + const DTYPE& emit() const { + return emit_; + } +}; + +// TensorView: view a block of allocated memory as a tensor. +template +class TensorView { + public: + TensorView(const std::vector& dims, DTYPE* data) + : dims_(dims), data_(data) { + strides_.resize(dims.size()); + strides_.back() = 1; + for (int i = dims.size() - 2; i >= 0; --i) { + strides_[i] = strides_[i + 1] * dims[i + 1]; + } + } + + DTYPE& operator()(const std::vector& indices) { + CHECK_EQ(indices.size(), dims_.size()); + int index = indices.back(); + for (int i = indices.size() - 2; i >= 0; --i) { + index += indices[i] * strides_[i]; + } + return data_[index]; + } + + void SetZero() { + int size = dims_[0] * strides_[0]; + std::memset(data_, 0, sizeof(DTYPE) * size); + } + + private: + std::vector dims_; + std::vector strides_; + DTYPE* data_; +}; + +template +status_t LogSumExp2D(int N, int D, const DTYPE* logits, CAST_DTYPE* outputs) { + for (int i = 0; i < N * D; i += D) { + CAST_DTYPE max = logits[i]; + for (int j = 1; j < D; ++j) { + max = std::max(max, CAST_DTYPE(logits[i + j])); + } + CAST_DTYPE sum = 0; + for (int j = 0; j < D; ++j) { + sum = sum + std::exp(CAST_DTYPE(logits[i + j]) - max); + } + outputs[i / D] = max + std::log(sum); + } + + return SUCCESS; +} + +template +void ComputeLogProbsOneSequence( + const Options& options, + TensorView& logits, + const int* targets, + int srcLen, + int tgtLen, + TensorView& denom, + TensorView>& logProbs) { + const int& T = srcLen; + const int& U = tgtLen; + const int& blank = options.blank_; + + for (int t = 0; t < T; ++t) { + for (int u = 0; u < U; ++u) { + if (u < U - 1) { + logProbs({t, u}).emit() = + CAST_DTYPE(logits({t, u, targets[u]})) - denom({t, u}); + } + logProbs({t, u}).skip() = + CAST_DTYPE(logits({t, u, blank})) - denom({t, u}); + } + } +} + +template +status_t ComputeLogProbs( + const Options& options, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + const CAST_DTYPE* denominators, + CAST_DTYPE* logProbs) { + std::vector> seqLogits; + std::vector seqTargets; + std::vector> seqDenoms; + std::vector>> seqlogProbs; + + const int& B = options.batchSize_; + const int& maxT = options.maxSrcLen_; + const int& maxU = options.maxTgtLen_; + const int& D = options.numTargets_; + for (int b = 0; b < B; ++b) { + seqLogits.push_back( + TensorView({maxT, maxU, D}, logits + b * maxT * maxU * D)); + seqTargets.push_back(targets + b * (maxU - 1)); + 
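+    // Each sequence has at most (maxU - 1) target symbols; maxU itself counts
+    // the prepended blank (see the `tgtLengths[b] + 1` below), hence the
+    // (maxU - 1) stride for `targets`.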
seqDenoms.push_back(TensorView( + {maxT, maxU}, denominators + b * maxT * maxU)); + seqlogProbs.push_back(TensorView>( + {maxT, maxU}, + reinterpret_cast*>(logProbs) + b * maxT * maxU)); + } + + //#pragma omp parallel for + for (int b = 0; b < B; ++b) { // use max 2 * B threads. + ComputeLogProbsOneSequence( + /*options=*/options, + /*logits=*/seqLogits[b], + /*targets=*/seqTargets[b], + /*srcLen=*/srcLengths[b], + /*tgtLen=*/tgtLengths[b] + 1, // with prepended blank. + /*denom=*/seqDenoms[b], + /*logProbs=*/seqlogProbs[b]); + } + + return SUCCESS; +} + +template +DTYPE ComputeAlphaOneSequence( + const Options& options, + TensorView>& logProbs, + int srcLen, + int tgtLen, + TensorView& alpha) { + const int& T = srcLen; + const int& U = tgtLen; + + alpha({0, 0}) = DTYPE(0); + + for (int t = 1; t < T; ++t) { // u == 0. + alpha({t, 0}) = alpha({t - 1, 0}) + logProbs({t - 1, 0}).skip(); + } + + for (int u = 1; u < U; ++u) { // t == 0. + alpha({0, u}) = alpha({0, u - 1}) + logProbs({0, u - 1}).emit(); + } + + for (int t = 1; t < T; ++t) { + for (int u = 1; u < U; ++u) { + alpha({t, u}) = math::lse( + alpha({t - 1, u}) + logProbs({t - 1, u}).skip(), + alpha({t, u - 1}) + logProbs({t, u - 1}).emit()); + } + } + + DTYPE forward_score = alpha({T - 1, U - 1}) + logProbs({T - 1, U - 1}).skip(); + + return forward_score; +} + +template +DTYPE ComputeBetaOneSequence( + const Options& options, + TensorView>& logProbs, + int srcLen, + int tgtLen, + TensorView& beta) { + const int& T = srcLen; + const int& U = tgtLen; + + beta({T - 1, U - 1}) = logProbs({T - 1, U - 1}).skip(); + + for (int t = T - 2; t >= 0; --t) { // u == U - 1. + beta({t, U - 1}) = beta({t + 1, U - 1}) + logProbs({t, U - 1}).skip(); + } + + for (int u = U - 2; u >= 0; --u) { // t == T - 1. + beta({T - 1, u}) = beta({T - 1, u + 1}) + logProbs({T - 1, u}).emit(); + } + + for (int t = T - 2; t >= 0; --t) { + for (int u = U - 2; u >= 0; --u) { + beta({t, u}) = math::lse( + beta({t + 1, u}) + logProbs({t, u}).skip(), + beta({t, u + 1}) + logProbs({t, u}).emit()); + } + } + + DTYPE backward_score = beta({0, 0}); + + return backward_score; +} + +template +DTYPE ComputeAlphaOrBetaOneSequence( + int thread, + const Options& options, + TensorView>& logProbs, + int srcLen, + int tgtLen, + TensorView& alpha, + TensorView& beta) { + if (thread & 1) { + return ComputeAlphaOneSequence( + /*options=*/options, + /*logProbs=*/logProbs, + /*srcLen=*/srcLen, + /*tgtLen=*/tgtLen, + /*alpha=*/alpha); + } else { + return ComputeBetaOneSequence( + /*options=*/options, + /*logProbs=*/logProbs, + /*srcLen=*/srcLen, + /*tgtLen=*/tgtLen, + /*beta=*/beta); + } +} + +template +void ComputeAlphasBetas( + const Options& options, + const CAST_DTYPE* logProbs, + const int* srcLengths, + const int* tgtLengths, + CAST_DTYPE* alphas, + CAST_DTYPE* betas, + DTYPE* costs) { + std::vector>> seqlogProbs; + std::vector> seq_alphas; + std::vector> seq_betas; + + const int& B = options.batchSize_; + const int& maxT = options.maxSrcLen_; + const int& maxU = options.maxTgtLen_; + + for (int b = 0; b < B; ++b) { + seqlogProbs.push_back(TensorView>( + {maxT, maxU}, + reinterpret_cast*>( + const_cast(logProbs)) + + b * maxT * maxU)); + seq_alphas.push_back( + TensorView({maxT, maxU}, alphas + b * maxT * maxU)); + seq_betas.push_back( + TensorView({maxT, maxU}, betas + b * maxT * maxU)); + } + + std::vector scores(B << 1); + //#pragma omp parallel for + for (int t = 0; t < (B << 1); ++t) { // use max 2 * B threads. 
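+    // Odd t computes alphas and even t computes betas for sequence i = t / 2
+    // (see ComputeAlphaOrBetaOneSequence); the costs below are read from the
+    // even (beta) slots, i.e. costs[b] = -scores[2 * b].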
+ int i = (t >> 1); + scores[t] = ComputeAlphaOrBetaOneSequence( + /*thread=*/t, + /*options=*/options, + /*logProbs=*/seqlogProbs[i], + /*srcLen=*/srcLengths[i], + /*tgtLen=*/tgtLengths[i] + 1, // with prepended blank. + /*alpha=*/seq_alphas[i], + /*beta=*/seq_betas[i]); + } + for (int b = 0; b < B; ++b) { + costs[b] = -scores[b << 1]; + } +} + +template +void ComputeGradientsOneSequence( + const Options& options, + TensorView& logits, + const int* targets, + int srcLen, + int tgtLen, + TensorView& denom, + TensorView& alpha, + TensorView& beta, + TensorView& gradients) { + // don't set gradients to zero to here as gradients might reuse memory from + // logits + + const int& T = srcLen; + const int& U = tgtLen; + const int& D = options.numTargets_; + const int& blank = options.blank_; + const CAST_DTYPE clamp = options.clamp_; + + CAST_DTYPE cost = -beta({0, 0}); + + // Note - below gradient is different from numpy_transducer, since we + // compute log_softmax more efficiently within the loss, to save memory The + // details of the below implementation / equations can be found in Sec 3.2 + // (function merging) in below paper: + // https://www.microsoft.com/en-us/research/uploads/prod/2019/10/RNNT.pdf + + for (int t = 0; t < T; ++t) { + for (int u = 0; u < U; ++u) { + CAST_DTYPE c = alpha({t, u}) + cost - denom({t, u}); + for (int d = 0; d < D; ++d) { + CAST_DTYPE g = CAST_DTYPE(logits({t, u, d})) + c; + if (d == blank && t == T - 1 && u == U - 1) { // last blank transition. + gradients({t, u, d}) = std::exp(g + beta({t, u})) - std::exp(g); + } else if (d == blank && t < T - 1) { + gradients({t, u, d}) = + std::exp(g + beta({t, u})) - std::exp(g + beta({t + 1, u})); + } else if (u < U - 1 && d == targets[u]) { + gradients({t, u, d}) = + std::exp(g + beta({t, u})) - std::exp(g + beta({t, u + 1})); + } else { + gradients({t, u, d}) = std::exp(g + beta({t, u})); + } + + if (clamp > 0) { + gradients({t, u, d}) = + math::min(CAST_DTYPE(gradients({t, u, d})), clamp); + gradients({t, u, d}) = + math::max(CAST_DTYPE(gradients({t, u, d})), -clamp); + } + } + } + } + + // zero out the rest of the gradients, necessary when reusing logits memory + // check the memory location to see if it's necessary + if (&gradients({0, 0, 0}) == &logits({0, 0, 0})) { + const int& maxT = options.maxSrcLen_; + const int& maxU = options.maxTgtLen_; + for (int t = T; t < maxT; ++t) { + for (int u = 0; u < maxU; ++u) { + for (int d = 0; d < D; ++d) { + gradients({t, u, d}) = 0.; + } + } + } + for (int t = 0; t < T; ++t) { + for (int u = U; u < maxU; ++u) { + for (int d = 0; d < D; ++d) { + gradients({t, u, d}) = 0.; + } + } + } + } +} + +template +void ComputeGradients( + const Options& options, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + const CAST_DTYPE* denominators, + const CAST_DTYPE* alphas, + const CAST_DTYPE* betas, + DTYPE* gradients) { + std::vector> seqLogits; + std::vector seqTargets; + std::vector> seqDenoms; + std::vector> seq_alphas; + std::vector> seq_betas; + std::vector> seq_gradients; + + const int& B = options.batchSize_; + const int& maxT = options.maxSrcLen_; + const int& maxU = options.maxTgtLen_; + const int& D = options.numTargets_; + for (int b = 0; b < B; ++b) { + seqLogits.push_back( + TensorView({maxT, maxU, D}, logits + b * maxT * maxU * D)); + seqTargets.push_back(targets + b * (maxU - 1)); + seqDenoms.push_back(TensorView( + {maxT, maxU}, denominators + b * maxT * maxU)); + seq_alphas.push_back( + TensorView({maxT, maxU}, alphas + b * 
maxT * maxU)); + seq_betas.push_back( + TensorView({maxT, maxU}, betas + b * maxT * maxU)); + seq_gradients.push_back( + TensorView({maxT, maxU, D}, gradients + b * maxT * maxU * D)); + } + + //#pragma omp parallel for + for (int b = 0; b < B; ++b) { // use max 2 * B threads. + ComputeGradientsOneSequence( + /*options=*/options, + /*logits=*/seqLogits[b], + /*targets=*/seqTargets[b], + /*srcLen=*/srcLengths[b], + /*tgtLen=*/tgtLengths[b] + 1, // with prepended blank. + /*denom=*/seqDenoms[b], + /*alpha=*/seq_alphas[b], + /*beta=*/seq_betas[b], + /*gradients=*/seq_gradients[b]); + } +} + +template +void ComputeAlphas( + const Options& options, + const CAST_DTYPE* logProbs, + const int* srcLengths, + const int* tgtLengths, + CAST_DTYPE* alphas) { + std::vector>> seqlogProbs; + std::vector> seq_alphas; + + const int& B = options.batchSize_; + const int& maxT = options.maxSrcLen_; + const int& maxU = options.maxTgtLen_; + + for (int b = 0; b < B; ++b) { + seqlogProbs.push_back(TensorView>( + {maxT, maxU}, + reinterpret_cast*>( + const_cast(logProbs)) + + b * maxT * maxU)); + seq_alphas.push_back( + TensorView({maxT, maxU}, alphas + b * maxT * maxU)); + } + + std::vector scores(B << 1); + //#pragma omp parallel for + for (int i = 0; i < B; ++i) { // use max 2 * B threads. + ComputeAlphaOneSequence( + options, + /*logProbs=*/seqlogProbs[i], + /*srcLen=*/srcLengths[i], + /*tgtLen=*/tgtLengths[i] + 1, // with prepended blank. + /*alpha=*/seq_alphas[i]); + } +} + +template +void ComputeBetas( + const Options& options, + const CAST_DTYPE* logProbs, + const int* srcLengths, + const int* tgtLengths, + CAST_DTYPE* costs, + CAST_DTYPE* betas) { + std::vector>> seqlogProbs; + std::vector> seq_betas; + + const int& B = options.batchSize_; + const int& maxT = options.maxSrcLen_; + const int& maxU = options.maxTgtLen_; + + for (int b = 0; b < B; ++b) { + seqlogProbs.push_back(TensorView>( + {maxT, maxU}, + reinterpret_cast*>( + const_cast(logProbs)) + + b * maxT * maxU)); + seq_betas.push_back( + TensorView({maxT, maxU}, betas + b * maxT * maxU)); + } + + std::vector scores(B << 1); + //#pragma omp parallel for + for (int i = 0; i < B; ++i) { // use max 2 * B threads. + ComputeBetaOneSequence( + options, + /*logProbs=*/seqlogProbs[i], + /*srcLen=*/srcLengths[i], + /*tgtLen=*/tgtLengths[i] + 1, // with prepended blank. + /*betas=*/seq_betas[i]); + } +} + +} // namespace cpu +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/cpu/cpu_transducer.h b/torchaudio/csrc/rnnt/cpu/cpu_transducer.h new file mode 100644 index 0000000000000000000000000000000000000000..9d1fc86789607ec4a6d6fc46789e2ee7a2bfb6ff --- /dev/null +++ b/torchaudio/csrc/rnnt/cpu/cpu_transducer.h @@ -0,0 +1,184 @@ +#pragma once + +#include +#include + +namespace torchaudio { +namespace rnnt { +namespace cpu { + +// Inputs: +// workspace: workspace. +// logits: pointer to (B, maxT, maxU, D) logits. +// targets: pointer to (B, maxU - 1) targets in the batch. +// srcLengths: pointer to (B, ) source lengths in the batch. +// tgtLengths: pointer to (B, ) target lengths in the batch. +// +// Outputs: +// costs: pointer to (B, ) costs in the batch. +// gradients: pointer to (B, maxT, maxU, D) gradients in the batch. 
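+//
+// The computation proceeds in up to four passes over the (B, maxT, maxU) grid:
+// per-frame log-sum-exp denominators, blank/emit log-prob pairs, forward
+// (alpha) and backward (beta) scores, and finally the gradients, which are
+// computed only when a non-null `gradients` pointer is supplied.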
+template +status_t Compute( + const Workspace& workspace, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + DTYPE* costs, + DTYPE* gradients = nullptr) { + const Options& options = workspace.GetOptions(); + + CHECK_EQ(options.device_, CPU); + + const int& B = options.batchSize_; + const int& maxT = options.maxSrcLen_; + const int& maxU = options.maxTgtLen_; + const int& D = options.numTargets_; + + { // compute denominators. + LogSumExp2D( + /*N=*/B * maxT * maxU, + /*D=*/D, + /*logits=*/logits, + /*denominators=*/workspace.GetPointerToDenominators()); + } + + { // compute log prob pairs. + ComputeLogProbs( + /*options=*/options, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*denominators=*/workspace.GetPointerToDenominators(), + /*log_probs=*/workspace.GetPointerToLogProbs()); + } + + { // compute alphas and betas. + ComputeAlphasBetas( + /*options=*/options, + /*log_probs=*/workspace.GetPointerToLogProbs(), + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*alphas=*/workspace.GetPointerToAlphas(), + /*betas=*/workspace.GetPointerToBetas(), + /*costs=*/costs); + } + + if (gradients != nullptr) { + ComputeGradients( + /*options=*/options, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*denominators=*/workspace.GetPointerToDenominators(), + /*alphas=*/workspace.GetPointerToAlphas(), + /*betas=*/workspace.GetPointerToBetas(), + /*gradients=*/gradients); + } + + return SUCCESS; +} + +template +status_t ComputeAlphas( + const Workspace& workspace, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + DTYPE* alphas) { + const Options& options = workspace.GetOptions(); + + CHECK_EQ(options.device_, CPU); + + const int& B = options.batchSize_; + const int& maxT = options.maxSrcLen_; + const int& maxU = options.maxTgtLen_; + const int& D = options.numTargets_; + + { // compute denominators. + LogSumExp2D( + /*N=*/B * maxT * maxU, + /*D=*/D, + /*logits=*/logits, + /*denominators=*/workspace.GetPointerToDenominators()); + } + + { // compute log prob pairs. + ComputeLogProbs( + /*options=*/options, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*denominators=*/workspace.GetPointerToDenominators(), + /*log_probs=*/workspace.GetPointerToLogProbs()); + } + + { // compute alphas. + ComputeAlphas( + /*options=*/options, + /*log_probs=*/workspace.GetPointerToLogProbs(), + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*alphas=*/alphas); + } + + return SUCCESS; +} + +template +status_t ComputeBetas( + const Workspace& workspace, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + DTYPE* costs, + DTYPE* betas) { + const Options& options = workspace.GetOptions(); + + CHECK_EQ(options.device_, CPU); + + const int& B = options.batchSize_; + const int& maxT = options.maxSrcLen_; + const int& maxU = options.maxTgtLen_; + const int& D = options.numTargets_; + + { // compute denominators. + LogSumExp2D( + /*N=*/B * maxT * maxU, + /*D=*/D, + /*logits=*/logits, + /*denominators=*/workspace.GetPointerToDenominators()); + } + + { // compute log prob pairs. 
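+    // Each (b, t, u) cell stores two values: the "skip" (blank) term,
+    // logits(b, t, u, blank) - denom(b, t, u), and the "emit" term,
+    // logits(b, t, u, targets[u]) - denom(b, t, u); see LOG_PROBS_SKIP_IDX /
+    // LOG_PROBS_EMIT_IDX in kernel_utils.h.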
+ ComputeLogProbs( + /*options=*/options, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*denominators=*/workspace.GetPointerToDenominators(), + /*log_probs=*/workspace.GetPointerToLogProbs()); + } + + { // compute betas. + ComputeBetas( + /*options=*/options, + /*log_probs=*/workspace.GetPointerToLogProbs(), + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*costs=*/costs, + /*betas=*/betas); + } + + return SUCCESS; +} + +} // namespace cpu +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/cpu/kernel_utils.h b/torchaudio/csrc/rnnt/cpu/kernel_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..5a4b0fb8873deabfc1f4dba672dd8b148f63004a --- /dev/null +++ b/torchaudio/csrc/rnnt/cpu/kernel_utils.h @@ -0,0 +1,66 @@ +#pragma once + +#include + +#include + +namespace torchaudio { +namespace rnnt { + +inline HOST_AND_DEVICE bool in_range( + int start, + int end, // inclusive + int val) { + return start <= val && val <= end; +} + +#define LOG_PROBS_SKIP_IDX 0 +#define LOG_PROBS_EMIT_IDX 1 + +struct Indexer2D { + const int& size2_; + + FORCE_INLINE HOST_AND_DEVICE Indexer2D(const int& size2) : size2_(size2) {} + + FORCE_INLINE HOST_AND_DEVICE int operator()(int index1, int index2) { + return index1 * size2_ + index2; + } +}; + +struct Indexer3D { + const int& size2_; + const int& size3_; + + FORCE_INLINE HOST_AND_DEVICE Indexer3D(const int& size2, const int& size3) + : size2_(size2), size3_(size3) {} + + FORCE_INLINE HOST_AND_DEVICE int operator()( + int index1, + int index2, + int index3) { + return (index1 * size2_ + index2) * size3_ + index3; + } +}; + +struct Indexer4D { + const int& size2_; + const int& size3_; + const int& size4_; + + HOST_AND_DEVICE Indexer4D( + const int& size2, + const int& size3, + const int& size4) + : size2_(size2), size3_(size3), size4_(size4) {} + + HOST_AND_DEVICE int operator()( + int index1, + int index2, + int index3, + int index4) { + return ((index1 * size2_ + index2) * size3_ + index3) * size4_ + index4; + } +}; + +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/cpu/math.h b/torchaudio/csrc/rnnt/cpu/math.h new file mode 100644 index 0000000000000000000000000000000000000000..e630a65cd25e432ec15af537748b7d28c433854a --- /dev/null +++ b/torchaudio/csrc/rnnt/cpu/math.h @@ -0,0 +1,42 @@ +#pragma once + +#include + +namespace torchaudio { +namespace rnnt { + +namespace math { + +template +FORCE_INLINE HOST_AND_DEVICE DTYPE max(DTYPE x, DTYPE y) { + if (x > y) + return x; + else + return y; +} + +template +FORCE_INLINE HOST_AND_DEVICE DTYPE min(DTYPE x, DTYPE y) { + if (x > y) + return y; + else + return x; +} + +// log_sum_exp +template +FORCE_INLINE HOST_AND_DEVICE DTYPE lse(DTYPE x, DTYPE y); + +template <> +FORCE_INLINE HOST_AND_DEVICE float lse(float x, float y) { + if (y > x) { + return y + log1pf(expf(x - y)); + } else { + return x + log1pf(expf(y - x)); + } +} + +} // namespace math + +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/gpu/compute.cu b/torchaudio/csrc/rnnt/gpu/compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..0ccee481ead49d0b510e7e5426d74d15caf10eb1 --- /dev/null +++ b/torchaudio/csrc/rnnt/gpu/compute.cu @@ -0,0 +1,151 @@ +#include +#include +#include + +namespace torchaudio { +namespace rnnt { +namespace gpu { + +// Entry point into RNNT Loss +std::tuple> compute( + torch::Tensor& logits, + const torch::Tensor& targets, + const 
torch::Tensor& logit_lengths, + const torch::Tensor& target_lengths, + int64_t blank, + double clamp) { + TORCH_CHECK( + logits.device().type() == targets.device().type(), + "logits and targets must be on the same device"); + TORCH_CHECK( + logits.device().type() == logit_lengths.device().type(), + "logits and logit_lengths must be on the same device"); + TORCH_CHECK( + logits.device().type() == target_lengths.device().type(), + "logits and target_lengths must be on the same device"); + + TORCH_CHECK( + logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16, + "logits must be float32 or float16 (half) type"); + TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type"); + TORCH_CHECK( + logit_lengths.dtype() == torch::kInt32, + "logit_lengths must be int32 type"); + TORCH_CHECK( + target_lengths.dtype() == torch::kInt32, + "target_lengths must be int32 type"); + + TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous"); + TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); + TORCH_CHECK( + logit_lengths.is_contiguous(), "logit_lengths must be contiguous"); + TORCH_CHECK( + target_lengths.is_contiguous(), "target_lengths must be contiguous"); + + TORCH_CHECK( + logits.dim() == 4, "logits must be 4-D (batch, time, target, class)"); + TORCH_CHECK( + targets.dim() == 2, "targets must be 2-D (batch, max target length)"); + TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D"); + TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D"); + + TORCH_CHECK( + logit_lengths.size(0) == logits.size(0), + "batch dimension mismatch between logits and logit_lengths"); + TORCH_CHECK( + target_lengths.size(0) == logits.size(0), + "batch dimension mismatch between logits and target_lengths"); + TORCH_CHECK( + targets.size(0) == logits.size(0), + "batch dimension mismatch between logits and targets"); + + TORCH_CHECK( + blank >= 0 && blank < logits.size(-1), + "blank must be within [0, logits.shape[-1])"); + + TORCH_CHECK( + logits.size(1) == at::max(logit_lengths).item().toInt(), + "input length mismatch"); + TORCH_CHECK( + logits.size(2) == at::max(target_lengths).item().toInt() + 1, + "output length mismatch"); + TORCH_CHECK( + targets.size(1) == at::max(target_lengths).item().toInt(), + "target length mismatch"); + + Options options; + options.batchSize_ = logit_lengths.size(0); + options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); + options.maxSrcLen_ = logits.size(1); + options.maxTgtLen_ = logits.size(2); + options.numTargets_ = logits.size(3); + options.blank_ = blank; + options.clamp_ = clamp; + + CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA); + options.stream_ = at::cuda::getCurrentCUDAStream(); + cudaSetDevice(logits.get_device()); + options.device_ = GPU; + + torch::Tensor costs = torch::empty( + options.batchSize_ * options.nHypos_, + torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + c10::optional gradients = torch::zeros_like(logits); + + torch::Tensor int_workspace = torch::empty( + IntWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Int)); + + torch::Tensor float_workspace = torch::empty( + DtypeWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Float)); + + Workspace workspace( + /*options=*/options, + /*dtype_data=*/float_workspace.data_ptr(), + /*dtype_size=*/float_workspace.numel(), + 
/*int_data=*/int_workspace.data_ptr(), + /*int_size=*/int_workspace.numel()); + + switch (logits.scalar_type()) { + case torch::ScalarType::Float: { + Compute( + /*workspace=*/workspace, + /*logits=*/logits.data_ptr(), + /*targets=*/targets.data_ptr(), + /*logit_lengths=*/logit_lengths.data_ptr(), + /*target_lengths=*/target_lengths.data_ptr(), + /*costs=*/costs.data_ptr(), + /*gradients=*/gradients->data_ptr()); + break; + } + case torch::ScalarType::Half: { + Compute( + /*workspace=*/workspace, + /*logits=*/logits.data_ptr(), + /*targets=*/targets.data_ptr(), + /*logit_lengths=*/logit_lengths.data_ptr(), + /*target_lengths=*/target_lengths.data_ptr(), + /*costs=*/costs.data_ptr(), + /*gradients=*/gradients->data_ptr()); + break; + } + default: { + break; + } + }; + + return std::make_tuple(costs, gradients); +} + +TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { + m.impl("rnnt_loss", &compute); +} + +} // namespace gpu +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/gpu/compute_alphas.cu b/torchaudio/csrc/rnnt/gpu/compute_alphas.cu new file mode 100644 index 0000000000000000000000000000000000000000..9a59b534712267d74555959fd2eb438a586502df --- /dev/null +++ b/torchaudio/csrc/rnnt/gpu/compute_alphas.cu @@ -0,0 +1,73 @@ +#include +#include +#include + +namespace torchaudio { +namespace rnnt { +namespace gpu { + +torch::Tensor compute_alphas( + const torch::Tensor& logits, + const torch::Tensor& targets, + const torch::Tensor& logit_lengths, + const torch::Tensor& target_lengths, + int64_t blank, + double clamp) { + Options options; + options.batchSize_ = logit_lengths.size(0); + options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); + options.maxSrcLen_ = logits.size(1); + options.maxTgtLen_ = logits.size(2); + options.numTargets_ = logits.size(3); + options.blank_ = blank; + options.clamp_ = clamp; + + CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA); + options.stream_ = at::cuda::getCurrentCUDAStream(); + cudaSetDevice(logits.get_device()); + options.device_ = GPU; + + torch::Tensor alphas = torch::zeros( + {options.batchSize_ * options.nHypos_, + options.maxSrcLen_, + options.maxTgtLen_}, + torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + + torch::Tensor int_workspace = torch::empty( + IntWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Int)); + + torch::Tensor float_workspace = torch::empty( + DtypeWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Float)); + + Workspace workspace( + /*options=*/options, + /*dtype_data=*/float_workspace.data_ptr(), + /*dtype_size=*/float_workspace.numel(), + /*int_data=*/int_workspace.data_ptr(), + /*int_size=*/int_workspace.numel()); + + // Only support float, this is mainly to enable easy + // unit-testing + ComputeAlphas( + /*workspace=*/workspace, + /*logits=*/logits.data_ptr(), + /*targets=*/targets.data_ptr(), + /*logit_lengths=*/logit_lengths.data_ptr(), + /*target_lengths=*/target_lengths.data_ptr(), + /*alphas=*/alphas.data_ptr()); + return alphas; +} + +TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { + m.impl("rnnt_loss_alphas", &compute_alphas); +} + +} // namespace gpu +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/gpu/compute_betas.cu b/torchaudio/csrc/rnnt/gpu/compute_betas.cu new file mode 100644 index 0000000000000000000000000000000000000000..75b8e2a5f71eb2622027a2b4a8b93345661f7128 --- 
/dev/null +++ b/torchaudio/csrc/rnnt/gpu/compute_betas.cu @@ -0,0 +1,78 @@ +#include +#include +#include + +namespace torchaudio { +namespace rnnt { +namespace gpu { + +torch::Tensor compute_betas( + const torch::Tensor& logits, + const torch::Tensor& targets, + const torch::Tensor& logit_lengths, + const torch::Tensor& target_lengths, + int64_t blank, + double clamp) { + Options options; + options.batchSize_ = logit_lengths.size(0); + options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); + options.maxSrcLen_ = logits.size(1); + options.maxTgtLen_ = logits.size(2); + options.numTargets_ = logits.size(3); + options.blank_ = blank; + options.clamp_ = clamp; + + CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA); + options.stream_ = at::cuda::getCurrentCUDAStream(); + cudaSetDevice(logits.get_device()); + options.device_ = GPU; + + torch::Tensor costs = torch::empty( + target_lengths.size(0), + torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + + torch::Tensor betas = torch::zeros( + {options.batchSize_ * options.nHypos_, + options.maxSrcLen_, + options.maxTgtLen_}, + torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); + + torch::Tensor int_workspace = torch::empty( + IntWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Int)); + + torch::Tensor float_workspace = torch::empty( + DtypeWorkspace::ComputeSizeFromOptions(options), + torch::TensorOptions() + .device(logits.device()) + .dtype(torch::ScalarType::Float)); + + Workspace workspace( + /*options=*/options, + /*dtype_data=*/float_workspace.data_ptr(), + /*dtype_size=*/float_workspace.numel(), + /*int_data=*/int_workspace.data_ptr(), + /*int_size=*/int_workspace.numel()); + + // Only support float, this is mainly to enable easy + // unit-testing + ComputeBetas( + /*workspace=*/workspace, + /*logits=*/logits.data_ptr(), + /*targets=*/targets.data_ptr(), + /*logit_lengths=*/logit_lengths.data_ptr(), + /*target_lengths=*/target_lengths.data_ptr(), + /*costs=*/costs.data_ptr(), + /*betas=*/betas.data_ptr()); + return betas; +} + +TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { + m.impl("rnnt_loss_betas", &compute_betas); +} + +} // namespace gpu +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh b/torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e5f1cfc2df3e7a1ab6a73f3a041bdee47d89bd69 --- /dev/null +++ b/torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh @@ -0,0 +1,98 @@ +#pragma once + +#ifdef USE_CUDA + +#include + +namespace torchaudio { +namespace rnnt { + +template +__global__ void ReduceMax2D( + int dim, + const DTYPE* inputs, // [N, dim] + CAST_DTYPE* outputs) { + __shared__ CAST_DTYPE shared[NUM_THREADS]; + + // each thread reduces one matrix row + int offset = blockIdx.x * dim; // [n, 0] + CAST_DTYPE val = inputs[offset]; // default = inputs(n, 0) + for (int d = threadIdx.x; d < dim; d += NUM_THREADS) { + CAST_DTYPE next = inputs[offset + d]; + if (next > val) { + val = next; + } + } + + shared[threadIdx.x] = val; + __syncthreads(); + + for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) { + if (threadIdx.x < stride && threadIdx.x + stride < dim) { + if (shared[threadIdx.x + stride] > shared[threadIdx.x]) { + shared[threadIdx.x] = shared[threadIdx.x + stride]; + val = shared[threadIdx.x]; + } + } + __syncthreads(); + } + + CAST_DTYPE shf; + for (int stride = (WARP_SIZE >> 
1); stride > 0; stride >>= 1) { + shf = __shfl_down_sync(0xFFFFFFFF, val, stride); + if (threadIdx.x < stride && threadIdx.x + stride < dim) { + if (shf > val) { + val = shf; + } + } + } + + if (threadIdx.x == 0) { + outputs[blockIdx.x] = val; + } +} + +template +__global__ void ReduceLogSumExpGivenMax2D( + int dim, + const DTYPE* inputs, // [N, dim] + CAST_DTYPE* outputs) { // in: max -> out: logsum + + __shared__ CAST_DTYPE shared[NUM_THREADS]; + + CAST_DTYPE max = outputs[blockIdx.x]; + CAST_DTYPE val = 0; + + int offset = blockIdx.x * dim; + for (int d = threadIdx.x; d < dim; d += NUM_THREADS) { + val = val + std::exp(CAST_DTYPE(inputs[offset + d]) - max); + } + + shared[threadIdx.x] = val; + __syncthreads(); + + for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) { + if (threadIdx.x < stride && threadIdx.x + stride < dim) { + val = shared[threadIdx.x] + shared[threadIdx.x + stride]; + shared[threadIdx.x] = val; + } + __syncthreads(); + } + + CAST_DTYPE shf; + for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) { + shf = __shfl_down_sync(0xFFFFFFFF, val, stride); + if (threadIdx.x < stride && threadIdx.x + stride < dim) { + val = val + shf; + } + } + + if (threadIdx.x == 0) { + outputs[blockIdx.x] = max + std::log(val); + } +} + +} // namespace rnnt +} // namespace torchaudio + +#endif // USE_CUDA diff --git a/torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh b/torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh new file mode 100644 index 0000000000000000000000000000000000000000..90b5ebfd7e4be8587357074d9051d36da2356941 --- /dev/null +++ b/torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh @@ -0,0 +1,409 @@ +#pragma once + +#ifdef USE_CUDA + +#include + +#include +#include +#include + +namespace torchaudio { +namespace rnnt { + +template +__global__ void ComputeLogProbs( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + const CAST_DTYPE* denominators, + CAST_DTYPE* logProbs, + int H = 1) { + const int& maxT = maxSrcLen; + const int& maxU = maxTgtLen; + const int& D = numTargets; + + const int bTgt = blockIdx.z; // 0 <= b < B + const int bSrc = bTgt / H; + const int T = srcLengths[bSrc]; + const int U = tgtLengths[bTgt] + 1; + + const int t = blockIdx.x * blockDim.x + threadIdx.x; + const int u = blockIdx.y; + + if (t >= T || u >= U) { // out of boundary. + return; + } + + Indexer3D indexer(maxT, maxU); + + int idx = indexer(bTgt, t, u); + + // skip: log_prob(b, t, u).skip() = logits(b, t, u, blank) - denom(b, t, u). + logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] = + CAST_DTYPE(logits[idx * D + blank]) - denominators[idx]; + + if (u < U - 1) { + // emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t, + // u). 
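+    // targets has shape (B, max_U - 1) (no blank prepended), hence the
+    // Indexer2D(maxU - 1) lookup; the emit term shares the same denominator
+    // as the skip term written above.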
+ int target = targets[Indexer2D(maxU - 1)(bTgt, u)]; + logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] = + CAST_DTYPE(logits[idx * D + target]) - denominators[idx]; + } +} + +template +__device__ void ComputeAlphas( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + const CAST_DTYPE* logProbs, + const int* srcLengths, + const int* tgtLengths, + int* alpha_counters, + volatile CAST_DTYPE* alphas, + int H = 1) { + const int& maxT = maxSrcLen; + const int& maxU = maxTgtLen; + + const int bTgt = blockIdx.z; // 0 <= b < B + const int bSrc = bTgt / H; + const int T = srcLengths[bSrc]; + const int U = tgtLengths[bTgt] + 1; + + const int t = blockIdx.x * blockDim.x + threadIdx.x + 1; + const int u = blockIdx.y + 1; + + if (t >= T || u >= U) { // out of boundary. + return; + } + + int* counter = alpha_counters + Indexer2D(maxU)(bTgt, blockIdx.y); + + Indexer3D idxr(maxT, maxU); + + if (t == 1 && u == 1) { + alphas[idxr(bTgt, 0, 0)] = 0; + } + + if (blockIdx.x > 0) { // wait for previous warp (in t-axis) is ready. + while (atomicAdd(counter, 0) < blockIdx.x) { + } + } + + if (blockIdx.y > 0) { // wait for previous warp (in u-axis) is ready. + while (atomicAdd(counter - 1, 0) <= blockIdx.x) { + } + } + + if (t == 1 && u < U) { + // alpha(0, u) = alpha(0, u - 1) + logProbs(0, u - 1).emit(). + alphas[idxr(bTgt, 0, u)] = alphas[idxr(bTgt, 0, u - 1)] + + logProbs[(idxr(bTgt, 0, u - 1) << 1) + LOG_PROBS_EMIT_IDX]; + } + + if (blockIdx.y == 0 && t < T) { + CAST_DTYPE skip_prob = + logProbs[(idxr(bTgt, t - 1, 0) << 1) + LOG_PROBS_SKIP_IDX]; + CAST_DTYPE val; + +#pragma unroll + for (int i = 1; i < warpSize; i <<= 1) { + val = __shfl_up_sync(0xffffffff, skip_prob, i); + if (i <= threadIdx.x) { + skip_prob = skip_prob + val; + } + } + + val = alphas[idxr(bTgt, blockIdx.x * blockDim.x, 0)]; + alphas[idxr(bTgt, t, 0)] = skip_prob + val; + } + + if (t < T && u < U) { + CAST_DTYPE skip_prob = + logProbs[(idxr(bTgt, t - 1, u) << 1) + LOG_PROBS_SKIP_IDX]; + CAST_DTYPE emit_prob = + logProbs[(idxr(bTgt, t, u - 1) << 1) + LOG_PROBS_EMIT_IDX]; + + CAST_DTYPE skip = + alphas[idxr(bTgt, blockIdx.x * blockDim.x, u)] + skip_prob; + CAST_DTYPE emit = alphas[idxr(bTgt, t, u - 1)] + emit_prob; + + CAST_DTYPE val = math::lse(skip, emit); + CAST_DTYPE out = val; + + for (int i = 1; i < warpSize; ++i) { + val = __shfl_up_sync(0xffffffff, val, 1); + if (i == threadIdx.x) { + val = math::lse(val + skip_prob, emit); + out = val; + } + } + + alphas[idxr(bTgt, t, u)] = out; + } + + if (threadIdx.x == 0) { + __threadfence(); + atomicAdd(counter, 1); + } +} + +template +__device__ void ComputeBetasCosts( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + const CAST_DTYPE* logProbs, + const int* srcLengths, + const int* tgtLengths, + int* betaCounters, + volatile CAST_DTYPE* betas, + DTYPE* costs, + int H = 1) { + const int& maxT = maxSrcLen; + const int& maxU = maxTgtLen; + + const int bTgt = blockIdx.z; // 0 <= b < B + const int bSrc = bTgt / H; + const int T = srcLengths[bSrc]; + const int U = tgtLengths[bTgt] + 1; + + const int t = T - 2 - blockIdx.x * blockDim.x - threadIdx.x; + const int u = U - 2 - blockIdx.y; + + if (t < 0 || u < 0) { // out of boundary. + return; + } + + int* counter = betaCounters + Indexer2D(maxU)(bTgt, blockIdx.y); + + Indexer3D idxr(maxT, maxU); + + if (t == T - 2 && u == U - 2) { + betas[idxr(bTgt, T - 1, U - 1)] = + logProbs[(idxr(bTgt, T - 1, U - 1) << 1) + LOG_PROBS_SKIP_IDX]; + } + + if (blockIdx.x > 0) { // wait for previous warp (in t-axis) is ready. 
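+    // Spin on the per-(sequence, u-block) counter: each warp atomically
+    // increments it after finishing its tile of the t-axis (see the
+    // __threadfence() / atomicAdd at the end of this function), so polling it
+    // with atomicAdd(counter, 0) tells us how many tiles to the left have
+    // already completed.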
+ while (atomicAdd(counter, 0) < blockIdx.x) { + } + } + + if (blockIdx.y > 0) { // wait for previous warp (in u-axis) is ready. + while (atomicAdd(counter - 1, 0) <= blockIdx.x) { + } + } + + if (t == T - 2 && u >= 0) { + betas[idxr(bTgt, T - 1, u)] = betas[idxr(bTgt, T - 1, u + 1)] + + logProbs[(idxr(bTgt, T - 1, u) << 1) + LOG_PROBS_EMIT_IDX]; + } + + if (blockIdx.y == 0 && t >= 0) { + CAST_DTYPE skip_prob = + logProbs[(idxr(bTgt, t, U - 1) << 1) + LOG_PROBS_SKIP_IDX]; + CAST_DTYPE val; + +#pragma unroll + for (int i = 1; i < warpSize; i <<= 1) { + val = __shfl_up_sync(0xffffffff, skip_prob, i); + if (i <= threadIdx.x) { + skip_prob = skip_prob + val; + } + } + + betas[idxr(bTgt, t, U - 1)] = + betas[idxr(bTgt, T - 1 - blockIdx.x * blockDim.x, U - 1)] + skip_prob; + } + + if (t >= 0 && u >= 0) { + CAST_DTYPE skip_prob = + logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_SKIP_IDX]; + CAST_DTYPE emit_prob = + logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_EMIT_IDX]; + + CAST_DTYPE skip = betas[idxr(bTgt, t + threadIdx.x + 1, u)] + skip_prob; + CAST_DTYPE emit = betas[idxr(bTgt, t, u + 1)] + emit_prob; + + CAST_DTYPE val = math::lse(skip, emit); + CAST_DTYPE out = val; + + for (int i = 1; i < warpSize; ++i) { + val = __shfl_up_sync(0xffffffff, val, 1); + if (i == threadIdx.x) { + val = math::lse(val + skip_prob, emit); + out = val; + } + } + + betas[idxr(bTgt, t, u)] = out; + + if (t == 0 && u == 0) { // use -beta(0, 0) as cost. + costs[bTgt] = DTYPE(-out); + } + } + + if (threadIdx.x == 0) { + __threadfence(); + atomicAdd(counter, 1); + } +} + +template +__global__ void ComputeAlphasBetasCosts( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + const CAST_DTYPE* logProbs, + const int* srcLengths, + const int* tgtLengths, + int* alpha_counters, + volatile CAST_DTYPE* alphas, + int* betaCounters, + volatile CAST_DTYPE* betas, + DTYPE* costs, + int warpSize = 0, + int numWarps = 0, + int H = 1) { + assert(threadIdx.y == 0 || threadIdx.y == 1); + + if (threadIdx.y == 0) { + ComputeAlphas( + /*maxSrcLen=*/maxSrcLen, + /*maxTgtLen=*/maxTgtLen, + /*numTargets=*/numTargets, + /*blank=*/blank, + /*logProbs=*/logProbs, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*alpha_counters=*/alpha_counters, + /*alphas=*/alphas, + H); + } else { // threadIdx.y == 1 + ComputeBetasCosts( + /*maxSrcLen=*/maxSrcLen, + /*maxTgtLen=*/maxTgtLen, + /*numTargets=*/numTargets, + /*blank=*/blank, + /*logProbs=*/logProbs, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*betaCounters=*/betaCounters, + /*beta=*/betas, + /*costs=*/costs, + H); + } +} + +template +__global__ void ComputeGradients( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + CAST_DTYPE clamp, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + const CAST_DTYPE* denominators, + const CAST_DTYPE* alphas, + const CAST_DTYPE* betas, + DTYPE* gradients, + int H = 1) { + const int bTgt = blockIdx.z; // 0 <= b < B + const int t = blockIdx.x * blockDim.x + threadIdx.x; + const int u = blockIdx.y; + + ComputeGradientsElement( + bTgt, + t, + u, + maxSrcLen, + maxTgtLen, + numTargets, + blank, + clamp, + logits, + targets, + srcLengths, + tgtLengths, + denominators, + alphas, + betas, + gradients, + H); +} + +// This is a __global__ wrapper around ComputeAlphas +// device kernel to enable unit testing +template +__global__ void ComputeAlphasWrapper( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + const CAST_DTYPE* logProbs, + const int* 
srcLengths, + const int* tgtLengths, + int* alpha_counters, + volatile CAST_DTYPE* alphas, + int H = 1) { + ComputeAlphas( + maxSrcLen, + maxTgtLen, + numTargets, + blank, + logProbs, + srcLengths, + tgtLengths, + alpha_counters, + alphas, + H); +} + +// This is a __global__ wrapper around ComputeBetas +// device kernel to enable unit testing +template +__global__ void ComputeBetasWrapper( + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + const CAST_DTYPE* logProbs, + const int* srcLengths, + const int* tgtLengths, + int* betaCounters, + volatile CAST_DTYPE* betas, + DTYPE* costs, + int H = 1) { + ComputeBetasCosts( + maxSrcLen, + maxTgtLen, + numTargets, + blank, + logProbs, + srcLengths, + tgtLengths, + betaCounters, + betas, + costs, + H); +} + +// #undef LOG_PROBS_SKIP_IDX +// #undef LOG_PROBS_EMIT_IDX + +} // namespace rnnt +} // namespace torchaudio + +#endif // USE_CUDA diff --git a/torchaudio/csrc/rnnt/gpu/gpu_transducer.h b/torchaudio/csrc/rnnt/gpu/gpu_transducer.h new file mode 100644 index 0000000000000000000000000000000000000000..54d16b9f215190289f766c72f93dca4075d5856a --- /dev/null +++ b/torchaudio/csrc/rnnt/gpu/gpu_transducer.h @@ -0,0 +1,391 @@ +#pragma once + +#ifdef USE_CUDA + +#include +#include +#include + +namespace torchaudio { +namespace rnnt { +namespace gpu { + +#define gpuErrchk(ans) \ + { gpuAssert((ans), __FILE__, __LINE__); } + +inline void gpuAssert( + cudaError_t code, + const char* file, + int line, + bool abort = true) { + if (code != cudaSuccess) { + fprintf( + stderr, + "\nGPUassert: %s %s %d\n", + cudaGetErrorString(code), + file, + line); + if (abort) + exit(code); + } +} + +template +status_t LogSumExp2D( + cudaStream_t stream, + int N, + int D, + const DTYPE* logits, // [N, D] + CAST_DTYPE* outputs) { + { // compute max among D. + dim3 block_dims(N); + dim3 thread_dims(REDUCE_THREADS); + + ReduceMax2D + <<>>( + /*dim=*/D, + /*inputs=*/logits, + /*outputs=*/outputs); + + // BUGBUG: These error codes are only accurate when launching with + // blocking. Otherwise they usually reflect earlier errors. + if (cudaGetLastError() != cudaSuccess) { + return COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED; + } + } + + { // compute log(sum(exp(d_i - max))) + dim3 block_dims(N); + dim3 thread_dims(REDUCE_THREADS); + + ReduceLogSumExpGivenMax2D + <<>>( + /*dim=*/D, + /*inputs=*/logits, + /*outputs=*/outputs); + + if (cudaGetLastError() != cudaSuccess) { + return COMPUTE_DENOMINATOR_REDUCE_SUM_FAILED; + } + } + + return SUCCESS; +} + +// Inputs: +// workspace: workspace. +// logits: pointer to (B, max_T, max_U, D) logits. +// targets: pointer to (B, max_U - 1) targets in the batch. +// srcLengths: pointer to (B, ) source lengths in the batch. +// tgtLengths: pointer to (B, ) target lengths in the batch. +// +// Outputs: +// costs: pointer to (B, ) costs in the batch. +// gradients: pointer to (B, max_T, max_U, D) gradients in the batch. +template +status_t Compute( + const Workspace& workspace, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + DTYPE* costs, + DTYPE* gradients = nullptr) { + const Options& options = workspace.GetOptions(); + + const cudaStream_t& stream = options.stream_; + const int& B = options.batchSize_; + const int& H = options.nHypos_; + const int& max_T = options.maxSrcLen_; + const int& max_U = options.maxTgtLen_; + const int& D = options.numTargets_; + const int& blank = options.blank_; + const CAST_DTYPE clamp = options.clamp_; + + { // compute denominators. 
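+    // denominators(b, t, u) = log(sum_d exp(logits(b, t, u, d))), computed in
+    // two passes on the GPU: ReduceMax2D finds the per-row maximum, then
+    // ReduceLogSumExpGivenMax2D accumulates exp(x - max) and adds the max
+    // back (see gpu_kernel_utils.cuh).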
+ status_t status = LogSumExp2D( + /*stream=*/stream, + /*N=*/B * H * max_T * max_U, + /*D=*/D, + /*logits=*/logits, + /*denominators=*/workspace.GetPointerToDenominators()); + + if (status != SUCCESS) { + return status; + } + } + + { // compute log probability pairs (blank and target). + int num_segments = + (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK; + dim3 block_dims(num_segments, max_U, B * H); + dim3 thread_dims(MAX_THREADS_PER_BLOCK); + + ComputeLogProbs<<>>( + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*denominators=*/workspace.GetPointerToDenominators(), + /*log_probs=*/workspace.GetPointerToLogProbs(), + H); + + if (cudaGetLastError() != cudaSuccess) { + return COMPUTE_LOG_PROBS_FAILED; + } + } + + { // compute alphas, betas and costs. + // warp is usually a group of threads (32) + int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE; + + // each block is identified by 3 d tuple. + // we are using num_warp * max_U * B * H blocks + // where num_warp is division among Time axis + dim3 block_dims(num_warps, max_U, B * H); + + // each thread is identified by a 2 d tuple + // 2nd dim is 2. 1 for alpha, 1 for beta + dim3 thread_dims(WARP_SIZE, 2); + + ComputeAlphasBetasCosts + <<>>( + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*log_probs=*/workspace.GetPointerToLogProbs(), + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*alpha_counters=*/workspace.GetPointerToAlphaCounters(), + /*alphas=*/workspace.GetPointerToAlphas(), + /*beta_counters=*/workspace.GetPointerToBetaCounters(), + /*betas=*/workspace.GetPointerToBetas(), + /*costs=*/costs, + /*warp_size=*/WARP_SIZE, + /*num_warps=*/num_warps, + H); + if (cudaGetLastError() != cudaSuccess) { + return COMPUTE_ALPHAS_BETAS_COSTS_FAILED; + } + } + + if (gradients != nullptr) { // compute gradients. + // don't set gradients to zero to here as gradients might reuse memory from + // logits + + int num_blocks = + (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK; + dim3 block_dims(num_blocks, max_U, B * H); + dim3 thread_dims(MAX_THREADS_PER_BLOCK); + + ComputeGradients<<>>( + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*clamp=*/clamp, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*denominators=*/workspace.GetPointerToDenominators(), + /*alphas=*/workspace.GetPointerToAlphas(), + /*betas=*/workspace.GetPointerToBetas(), + /*gradients=*/gradients, + H); + if (cudaGetLastError() != cudaSuccess) { + return COMPUTE_GRADIENTS_FAILED; + } + } + + return SUCCESS; +} + +template +status_t ComputeAlphas( + const Workspace& workspace, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + DTYPE* alphas) { + const Options& options = workspace.GetOptions(); + + const cudaStream_t& stream = options.stream_; + const int& B = options.batchSize_; + const int& H = options.nHypos_; + const int& max_T = options.maxSrcLen_; + const int& max_U = options.maxTgtLen_; + const int& D = options.numTargets_; + const int& blank = options.blank_; + + { // compute denominators. 
+ status_t status = LogSumExp2D( + /*stream=*/stream, + /*N=*/B * H * max_T * max_U, + /*D=*/D, + /*logits=*/logits, + /*denominators=*/workspace.GetPointerToDenominators()); + + if (status != SUCCESS) { + return status; + } + } + + { // compute log probability pairs (blank and target). + int num_segments = + (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK; + dim3 block_dims(num_segments, max_U, B * H); + dim3 thread_dims(MAX_THREADS_PER_BLOCK); + + ComputeLogProbs<<>>( + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*denominators=*/workspace.GetPointerToDenominators(), + /*log_probs=*/workspace.GetPointerToLogProbs(), + H); + + if (cudaGetLastError() != cudaSuccess) { + return COMPUTE_LOG_PROBS_FAILED; + } + } + { // compute alphas + // warp is usually a group of threads (32) + int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE; + + // each block is identified by 3 d tuple. + // we are using num_warp * max_U * B blocks + // where num_warp is division among Time axis + dim3 block_dims(num_warps, max_U, B * H); + + // each thread is identified by a 2 d tuple + // 2nd dim is 1 for alpha only + dim3 thread_dims(WARP_SIZE, 1); + + ComputeAlphasWrapper + <<>>( + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*log_probs=*/workspace.GetPointerToLogProbs(), + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*alpha_counters=*/workspace.GetPointerToAlphaCounters(), + /*alphas=*/(volatile DTYPE*)alphas, + H); + + if (cudaGetLastError() != cudaSuccess) { + return COMPUTE_ALPHAS_BETAS_COSTS_FAILED; + } + } + + return SUCCESS; +} + +template +status_t ComputeBetas( + const Workspace& workspace, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + DTYPE* costs, + DTYPE* betas) { + const Options& options = workspace.GetOptions(); + + const cudaStream_t& stream = options.stream_; + const int& B = options.batchSize_; + const int& H = options.nHypos_; + const int& max_T = options.maxSrcLen_; + const int& max_U = options.maxTgtLen_; + const int& D = options.numTargets_; + const int& blank = options.blank_; + + { // compute denominators. + status_t status = LogSumExp2D( + /*stream=*/stream, + /*N=*/B * H * max_T * max_U, + /*D=*/D, + /*logits=*/logits, + /*denominators=*/workspace.GetPointerToDenominators()); + + if (status != SUCCESS) { + return status; + } + } + + { // compute log probability pairs (blank and target). + int num_segments = + (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK; + dim3 block_dims(num_segments, max_U, B * H); + dim3 thread_dims(MAX_THREADS_PER_BLOCK); + + ComputeLogProbs<<>>( + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*logits=*/logits, + /*targets=*/targets, + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*denominators=*/workspace.GetPointerToDenominators(), + /*log_probs=*/workspace.GetPointerToLogProbs(), + H); + + if (cudaGetLastError() != cudaSuccess) { + return COMPUTE_LOG_PROBS_FAILED; + } + } + { // compute betas + // warp is usually a group of threads (32) + int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE; + + // each block is identified by 3 d tuple. 
+ // we are using num_warp * max_U * B blocks + // where num_warp is division among Time axis + dim3 block_dims(num_warps, max_U, B * H); + + // each thread is identified by a 2 d tuple + // 2nd dim is 1 for betas only + dim3 thread_dims(WARP_SIZE, 1); + + ComputeBetasWrapper + <<>>( + /*max_src_len=*/max_T, + /*max_tgt_len=*/max_U, + /*num_targets=*/D, + /*blank=*/blank, + /*log_probs=*/workspace.GetPointerToLogProbs(), + /*srcLengths=*/srcLengths, + /*tgtLengths=*/tgtLengths, + /*alpha_counters=*/workspace.GetPointerToBetaCounters(), + /*alphas=*/(volatile DTYPE*)betas, + costs, + H); + + if (cudaGetLastError() != cudaSuccess) { + return COMPUTE_ALPHAS_BETAS_COSTS_FAILED; + } + } + + return SUCCESS; +} + +} // namespace gpu +} // namespace rnnt +} // namespace torchaudio + +#endif // USE_CUDA diff --git a/torchaudio/csrc/rnnt/gpu/half.cuh b/torchaudio/csrc/rnnt/gpu/half.cuh new file mode 100644 index 0000000000000000000000000000000000000000..72a2f37e04efe2129efa9d070ff8d24bcd24457d --- /dev/null +++ b/torchaudio/csrc/rnnt/gpu/half.cuh @@ -0,0 +1,38 @@ +#pragma once + +#ifdef USE_C10_HALF +#include "c10/util/Half.h" +#endif // USE_C10_HALF + +#include + +namespace torchaudio { +namespace rnnt { + +struct alignas(sizeof(__half)) Half { + __half x; + + HOST_AND_DEVICE Half() = default; + + FORCE_INLINE HOST_AND_DEVICE Half(float f) { + x = __float2half_rn(f); + if (isinf(__half2float(x))) { + x = __float2half_rz(f); // round toward 0. + } + } + + FORCE_INLINE HOST_AND_DEVICE operator float() const { + return __half2float(x); + } + + FORCE_INLINE HOST_AND_DEVICE Half(__half f) { + x = f; + } + + FORCE_INLINE HOST_AND_DEVICE operator __half() const { + return x; + } +}; + +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/gpu/kernel_utils.h b/torchaudio/csrc/rnnt/gpu/kernel_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..3b2989b07378f39475795522e888fc9a642e0ab0 --- /dev/null +++ b/torchaudio/csrc/rnnt/gpu/kernel_utils.h @@ -0,0 +1,66 @@ +#pragma once + +#include + +#include + +namespace torchaudio { +namespace rnnt { + +inline HOST_AND_DEVICE bool in_range( + int start, + int end, // inclusive + int val) { + return start <= val && val <= end; +} + +#define LOG_PROBS_SKIP_IDX 0 +#define LOG_PROBS_EMIT_IDX 1 + +struct Indexer2D { + const int& size2_; + + FORCE_INLINE HOST_AND_DEVICE Indexer2D(const int& size2) : size2_(size2) {} + + FORCE_INLINE HOST_AND_DEVICE int operator()(int index1, int index2) { + return index1 * size2_ + index2; + } +}; + +struct Indexer3D { + const int& size2_; + const int& size3_; + + FORCE_INLINE HOST_AND_DEVICE Indexer3D(const int& size2, const int& size3) + : size2_(size2), size3_(size3) {} + + FORCE_INLINE HOST_AND_DEVICE int operator()( + int index1, + int index2, + int index3) { + return (index1 * size2_ + index2) * size3_ + index3; + } +}; + +struct Indexer4D { + const int& size2_; + const int& size3_; + const int& size4_; + + HOST_AND_DEVICE Indexer4D( + const int& size2, + const int& size3, + const int& size4) + : size2_(size2), size3_(size3), size4_(size4) {} + + HOST_AND_DEVICE int operator()( + int index1, + int index2, + int index3, + int index4) { + return ((index1 * size2_ + index2) * size3_ + index3) * size4_ + index4; + } +}; + +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/gpu/kernels.h b/torchaudio/csrc/rnnt/gpu/kernels.h new file mode 100644 index 0000000000000000000000000000000000000000..b0627c2181c8a67df29fc0c542c671ee0bc7f8f5 --- /dev/null +++ 
b/torchaudio/csrc/rnnt/gpu/kernels.h @@ -0,0 +1,108 @@ +#pragma once + +#include + +#include +#include + +namespace torchaudio { +namespace rnnt { + +template +HOST_AND_DEVICE void ComputeGradientsElement( + int bTgt, + int t, + int u, + int maxSrcLen, + int maxTgtLen, + int numTargets, + int blank, + CAST_DTYPE clamp, + const DTYPE* logits, + const int* targets, + const int* srcLengths, + const int* tgtLengths, + const CAST_DTYPE* denominators, + const CAST_DTYPE* alphas, + const CAST_DTYPE* betas, + DTYPE* gradients, + int H = 1) { + const int& maxT = maxSrcLen; + const int& maxU = maxTgtLen; + const int& D = numTargets; + + const int bSrc = bTgt / H; + const int T = srcLengths[bSrc]; + const int U = tgtLengths[bTgt] + 1; + + if (t >= T || u >= U) { // out of boundary. + if (gradients == logits && t < maxT && u < maxU) { + // gradients and logits are pointing to the same memory location + Indexer3D idxr3(maxT, maxU); + int idx_b_t_u_zero = idxr3(bTgt, t, u); + if (idx_b_t_u_zero != -1) { + int start = idx_b_t_u_zero * D; + for (int b_t_u_d = start; b_t_u_d < start + D; ++b_t_u_d) { + gradients[b_t_u_d] = 0; + } + } + } + return; + } + + int costIdx = bTgt * maxT * maxU; + CAST_DTYPE cost = -(betas[costIdx]); + + Indexer2D idxr2(maxU - 1); + + int idx_b_t_u, idx_b_t_up1, idx_b_tp1_u; + Indexer3D idxr3(maxT, maxU); + idx_b_t_u = idxr3(bTgt, t, u); + idx_b_t_up1 = idxr3(bTgt, t, u + 1); + idx_b_tp1_u = idxr3(bTgt, t + 1, u); + + if (idx_b_t_u == -1) { + return; + } + + if (isinf(cost) || isnan(cost)) { + for (int d = 0; d < D; ++d) { + int b_t_u_d = idx_b_t_u * D + d; + gradients[b_t_u_d] = 0; + } + return; + } + + CAST_DTYPE c = alphas[idx_b_t_u] + cost - denominators[idx_b_t_u]; + for (int d = 0; d < D; ++d) { + int b_t_u_d = idx_b_t_u * D + d; + CAST_DTYPE g = CAST_DTYPE(logits[b_t_u_d]) + c; + + if (d == blank && t == T - 1 && u == U - 1) { // last blank transition. 
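+      // Here g = logit + alpha(t, u) - log P - denom(t, u). The terminal
+      // state reached by the last blank has beta = 0 (probability 1), so the
+      // subtracted "transition actually taken" term exp(g + beta_next)
+      // reduces to exp(g).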
+ gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g); + } else if (t < T - 1 && d == blank) { + gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]); + if (idx_b_tp1_u != -1) { + gradients[b_t_u_d] = + gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]); + } + } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) { + gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]); + if (idx_b_t_up1 != -1) { + gradients[b_t_u_d] = + gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]); + } + } else { + gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]); + } + + if (clamp > 0) { + auto g = CAST_DTYPE(gradients[b_t_u_d]); + gradients[b_t_u_d] = math::min(g, clamp); + gradients[b_t_u_d] = math::max(g, -clamp); + } + } +} + +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/gpu/math.cuh b/torchaudio/csrc/rnnt/gpu/math.cuh new file mode 100644 index 0000000000000000000000000000000000000000..643fa98300c1b58fd817c8aac92f8022005618c4 --- /dev/null +++ b/torchaudio/csrc/rnnt/gpu/math.cuh @@ -0,0 +1,48 @@ +#pragma once + +#ifdef USE_CUDA + +#include + +#endif // USE_CUDA + +#include + +namespace torchaudio { +namespace rnnt { + +namespace math { + +template +FORCE_INLINE HOST_AND_DEVICE DTYPE max(DTYPE x, DTYPE y) { + if (x > y) + return x; + else + return y; +} + +template +FORCE_INLINE HOST_AND_DEVICE DTYPE min(DTYPE x, DTYPE y) { + if (x > y) + return y; + else + return x; +} + +// log_sum_exp +template +FORCE_INLINE HOST_AND_DEVICE DTYPE lse(DTYPE x, DTYPE y); + +template <> +FORCE_INLINE HOST_AND_DEVICE float lse(float x, float y) { + if (y > x) { + return y + log1pf(expf(x - y)); + } else { + return x + log1pf(expf(y - x)); + } +} + +} // namespace math + +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/macros.cpp b/torchaudio/csrc/rnnt/macros.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d2ea30a69be14db2fa47745aed98ddf617650b11 --- /dev/null +++ b/torchaudio/csrc/rnnt/macros.cpp @@ -0,0 +1,16 @@ +#include + +const char* ToString(level_t level) { + switch (level) { + case INFO: + return "INFO"; + case WARNING: + return "WARNING"; + case ERROR: + return "ERROR"; + case FATAL: + return "FATAL"; + default: + return "UNKNOWN"; + } +} diff --git a/torchaudio/csrc/rnnt/macros.h b/torchaudio/csrc/rnnt/macros.h new file mode 100644 index 0000000000000000000000000000000000000000..abcbc3996645c21b1e3415dedd4cd6ba129554ee --- /dev/null +++ b/torchaudio/csrc/rnnt/macros.h @@ -0,0 +1,21 @@ +#pragma once + +#ifdef USE_CUDA +#define WARP_SIZE 32 +#define MAX_THREADS_PER_BLOCK 1024 +#define REDUCE_THREADS 256 +#define HOST_AND_DEVICE __host__ __device__ +#define FORCE_INLINE __forceinline__ +#include +#include +#else +#define HOST_AND_DEVICE +#define FORCE_INLINE inline +#endif // USE_CUDA + +#include +#include + +typedef enum { INFO = 0, WARNING = 1, ERROR = 2, FATAL = 3 } level_t; + +const char* ToString(level_t level); diff --git a/torchaudio/csrc/rnnt/options.h b/torchaudio/csrc/rnnt/options.h new file mode 100644 index 0000000000000000000000000000000000000000..79109950fdebcd326dddfaccad7bcd9d5d129f98 --- /dev/null +++ b/torchaudio/csrc/rnnt/options.h @@ -0,0 +1,77 @@ +#pragma once + +//#include + +#ifdef USE_CUDA +#include +#endif // USE_CUDA + +#include +#include + +namespace torchaudio { +namespace rnnt { + +typedef struct Options { + // the device to compute transducer loss. + device_t device_; +#ifdef USE_CUDA + // the stream to launch kernels in when using GPU. 
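+  // (compute.cu sets this from at::cuda::getCurrentCUDAStream().)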
+ cudaStream_t stream_; +#endif + // The maximum number of threads that can be used. + int numThreads_; + + // the index for "blank". + int blank_; + // whether to backtrack the best path. + bool backtrack_; + // gradient clamp value. + float clamp_; + + // batch size = B. + int batchSize_; + + // Number of hypos per sample = H + int nHypos_; + + // the maximum length of src encodings = max_T. + int maxSrcLen_; + // the maximum length of tgt encodings = max_U. + int maxTgtLen_; + // num_targets = D. + int numTargets_; + + Options() + : device_(UNDEFINED), + numThreads_(0), + blank_(-1), + backtrack_(false), + clamp_(-1), // negative for disabling clamping by default. + batchSize_(0), + nHypos_(1), + maxSrcLen_(0), + maxTgtLen_(0), + numTargets_(0) {} + + int BU() const { + return batchSize_ * maxTgtLen_ * nHypos_; + } + + int BTU() const { + return batchSize_ * maxSrcLen_ * maxTgtLen_ * nHypos_; + } + + friend std::ostream& operator<<(std::ostream& os, const Options& options) { + os << "Options(" + << "batchSize_=" << options.batchSize_ << ", " + << "maxSrcLen_=" << options.maxSrcLen_ << ", " + << "maxTgtLen_=" << options.maxTgtLen_ << ", " + << "numTargets_=" << options.numTargets_ << ")"; + + return os; + } +} Options; + +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/types.cpp b/torchaudio/csrc/rnnt/types.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3c08a7eeb3f6af4e2b25ff9b947544837e7cb755 --- /dev/null +++ b/torchaudio/csrc/rnnt/types.cpp @@ -0,0 +1,41 @@ +#include + +namespace torchaudio { +namespace rnnt { + +const char* toString(status_t status) { + switch (status) { + case SUCCESS: + return "success"; + case FAILURE: + return "failure"; + case COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED: + return "compute_denominator_reduce_max_failed"; + case COMPUTE_DENOMINATOR_REDUCE_SUM_FAILED: + return "compute_denominator_reduce_sum_failed"; + case COMPUTE_LOG_PROBS_FAILED: + return "compute_log_probs_failed"; + case COMPUTE_ALPHAS_BETAS_COSTS_FAILED: + return "compute_alphas_betas_costs_failed"; + case COMPUTE_GRADIENTS_FAILED: + return "compute_gradients_failed"; + default: + return "unknown"; + } +} + +const char* toString(device_t device) { + switch (device) { + case UNDEFINED: + return "undefined"; + case CPU: + return "cpu"; + case GPU: + return "gpu"; + default: + return "unknown"; + } +} + +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/types.h b/torchaudio/csrc/rnnt/types.h new file mode 100644 index 0000000000000000000000000000000000000000..34d2998cffdcd4527aec1635d8c1fdf43892b8d1 --- /dev/null +++ b/torchaudio/csrc/rnnt/types.h @@ -0,0 +1,23 @@ +#pragma once + +namespace torchaudio { +namespace rnnt { + +typedef enum { + SUCCESS = 0, + FAILURE = 1, + COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED = 2, + COMPUTE_DENOMINATOR_REDUCE_SUM_FAILED = 3, + COMPUTE_LOG_PROBS_FAILED = 4, + COMPUTE_ALPHAS_BETAS_COSTS_FAILED = 5, + COMPUTE_GRADIENTS_FAILED = 6 +} status_t; + +typedef enum { UNDEFINED = 0, CPU = 1, GPU = 2 } device_t; + +const char* toString(status_t status); + +const char* toString(device_t device); + +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/workspace.h b/torchaudio/csrc/rnnt/workspace.h new file mode 100644 index 0000000000000000000000000000000000000000..31b57647af10424051e28e3779692f1e8b68f08f --- /dev/null +++ b/torchaudio/csrc/rnnt/workspace.h @@ -0,0 +1,223 @@ +#pragma once + +#include +#include + +#include + +namespace torchaudio { +namespace rnnt { + 
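+// The workspace views below do not own memory: the caller allocates two flat
+// tensors (one for DTYPE data, one for int data), sizes them with
+// ComputeSizeFromOptions(), and passes in the raw pointers (see the
+// float_workspace / int_workspace tensors created in compute.cu).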
+// Since CUDA has strict memory alignment, it's better to keep allocated memory +// blocks separate for different data types. + +// DtypeWorkspace holds a "view" of workspace for: +// 1. softmax denominators (in log form), size = B * max_T * max_U +// 2. log probibility pairs for blank and target, size = B * max_T * max_U +// 3. alphas, size = B * max_T * max_U +// 4. betas, size = B * max_T * max_U +template +class DtypeWorkspace { + public: + DtypeWorkspace() : options_(), size_(0), data_(nullptr) {} + DtypeWorkspace(const Options& options, DTYPE* data, int size) + : DtypeWorkspace() { + Reset(options, data, size); + } + ~DtypeWorkspace() {} + + static int ComputeSizeFromOptions(const Options& options) { + CHECK_NE(options.device_, UNDEFINED); + return ComputeSizeForDenominators(options) + + ComputeSizeForLogProbs(options) + ComputeSizeForAlphas(options) + + ComputeSizeForBetas(options); + } + + void Free(); + void Reset(const Options& options, DTYPE* data, int size) { + int needed_size = ComputeSizeFromOptions(options); + CHECK_LE(needed_size, size); + options_ = options; + data_ = data; + size_ = size; + } + int Size() const { + return size_; + } + + DTYPE* GetPointerToDenominators() const { + return data_; + } + DTYPE* GetPointerToLogProbs() const { + return GetPointerToDenominators() + ComputeSizeForDenominators(options_); + } + DTYPE* GetPointerToAlphas() const { + return GetPointerToLogProbs() + ComputeSizeForLogProbs(options_); + } + DTYPE* GetPointerToBetas() const { + return GetPointerToAlphas() + ComputeSizeForAlphas(options_); + } + + private: + static int ComputeSizeForDenominators(const Options& options) { // B * T * U + return options.BTU(); + } + + static int ComputeSizeForLogProbs(const Options& options) { // B * T * U * 2 + return options.BTU() * 2; + } + + static int ComputeSizeForAlphas(const Options& options) { // B * T * U + return options.BTU(); + } + + static int ComputeSizeForBetas(const Options& options) { // B * T * U + return options.BTU(); + } + + Options options_; + int size_; // number of elements in allocated memory. + DTYPE* data_; // pointer to the allocated memory. +}; + +// IntWorkspace holds a "view" of workspace for: +// 1. alpha counters, size = B * max_U +// 2. 
beta counters, size = B * max_U +class IntWorkspace { + public: + IntWorkspace() : options_(), size_(0), data_(nullptr) {} + IntWorkspace(const Options& options, int* data, int size) : IntWorkspace() { + Reset(options, data, size); + } + ~IntWorkspace() {} + + static int ComputeSizeFromOptions(const Options& options) { + return ComputeSizeForAlphaCounters(options) + + ComputeSizeForBetaCounters(options); + } + + void Reset(const Options& options, int* data, int size) { + int needed_size = ComputeSizeFromOptions(options); + CHECK_LE(needed_size, size); + options_ = options; + data_ = data; + size_ = size; + ResetAlphaBetaCounters(); + } + int Size() const { + return size_; + } + + int* GetPointerToAlphaCounters() const { + CHECK_EQ(options_.device_, GPU); + return data_; + } + int* GetPointerToBetaCounters() const { + CHECK_EQ(options_.device_, GPU); + return GetPointerToAlphaCounters() + ComputeSizeForAlphaCounters(options_); + } + + private: + inline void ResetAlphaBetaCounters() { +#ifdef USE_CUDA + if (data_ != nullptr && options_.device_ == GPU) { + cudaMemset( + GetPointerToAlphaCounters(), + 0, + ComputeSizeForAlphaCounters(options_) * sizeof(int)); + cudaMemset( + GetPointerToBetaCounters(), + 0, + ComputeSizeForBetaCounters(options_) * sizeof(int)); + } +#endif // USE_CUDA + } + + static int ComputeSizeForAlphaCounters(const Options& options) { // B * U +#ifdef USE_CUDA + if (options.device_ == GPU) { + return options.BU(); + } else { + return 0; + } +#else + return 0; +#endif // USE_CUDA + } + static int ComputeSizeForBetaCounters(const Options& options) { // B * U +#ifdef USE_CUDA + if (options.device_ == GPU) { + return options.BU(); + } else { + return 0; + } +#else + return 0; +#endif // USE_CUDA + } + + Options options_; + int size_; // number of elements in allocated memory. + int* data_; // pointer to the allocated memory. +}; + +// Workspace holds: +// 1. DtypeWorkspace +// 2. 
IntWorkspace +template +class Workspace { + public: + Workspace() : options_(), dtype_workspace_(), int_workspace_() {} + Workspace( + const Options& options, + DTYPE* dtype_data, + int dtype_size, + int* int_data, + int int_size) + : Workspace() { + Reset(options, dtype_data, dtype_size, int_data, int_size); + } + ~Workspace() {} + + void Reset( + const Options& options, + DTYPE* dtype_data, + int dtype_size, + int* int_data, + int int_size) { + options_ = options; + dtype_workspace_.Reset(options_, dtype_data, dtype_size); + int_workspace_.Reset(options_, int_data, int_size); + } + + const Options& GetOptions() const { + return options_; + } + + DTYPE* GetPointerToDenominators() const { + return dtype_workspace_.GetPointerToDenominators(); + } + DTYPE* GetPointerToLogProbs() const { + return dtype_workspace_.GetPointerToLogProbs(); + } + DTYPE* GetPointerToAlphas() const { + return dtype_workspace_.GetPointerToAlphas(); + } + DTYPE* GetPointerToBetas() const { + return dtype_workspace_.GetPointerToBetas(); + } + int* GetPointerToAlphaCounters() const { + return int_workspace_.GetPointerToAlphaCounters(); + } + int* GetPointerToBetaCounters() const { + return int_workspace_.GetPointerToBetaCounters(); + } + + private: + Options options_; + DtypeWorkspace dtype_workspace_; + IntWorkspace int_workspace_; +}; + +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/sox/effects.cpp b/torchaudio/csrc/sox/effects.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aaa61f9298ea5de991f5596bf2f709d8c6a55174 --- /dev/null +++ b/torchaudio/csrc/sox/effects.cpp @@ -0,0 +1,155 @@ +#include +#include +#include +#include + +using namespace torchaudio::sox_utils; + +namespace torchaudio::sox_effects { + +namespace { + +enum SoxEffectsResourceState { NotInitialized, Initialized, ShutDown }; +SoxEffectsResourceState SOX_RESOURCE_STATE = NotInitialized; +std::mutex SOX_RESOUCE_STATE_MUTEX; + +} // namespace + +void initialize_sox_effects() { + const std::lock_guard lock(SOX_RESOUCE_STATE_MUTEX); + + switch (SOX_RESOURCE_STATE) { + case NotInitialized: + if (sox_init() != SOX_SUCCESS) { + throw std::runtime_error("Failed to initialize sox effects."); + }; + SOX_RESOURCE_STATE = Initialized; + break; + case Initialized: + break; + case ShutDown: + throw std::runtime_error( + "SoX Effects has been shut down. Cannot initialize again."); + } +}; + +void shutdown_sox_effects() { + const std::lock_guard lock(SOX_RESOUCE_STATE_MUTEX); + + switch (SOX_RESOURCE_STATE) { + case NotInitialized: + throw std::runtime_error( + "SoX Effects is not initialized. 
Cannot shutdown."); + case Initialized: + if (sox_quit() != SOX_SUCCESS) { + throw std::runtime_error("Failed to initialize sox effects."); + }; + SOX_RESOURCE_STATE = ShutDown; + break; + case ShutDown: + break; + } +} + +auto apply_effects_tensor( + torch::Tensor waveform, + int64_t sample_rate, + const std::vector>& effects, + bool channels_first) -> std::tuple { + validate_input_tensor(waveform); + + // Create SoxEffectsChain + const auto dtype = waveform.dtype(); + torchaudio::sox_effects_chain::SoxEffectsChain chain( + /*input_encoding=*/get_tensor_encodinginfo(dtype), + /*output_encoding=*/get_tensor_encodinginfo(dtype)); + + // Prepare output buffer + std::vector out_buffer; + out_buffer.reserve(waveform.numel()); + + // Build and run effects chain + chain.addInputTensor(&waveform, sample_rate, channels_first); + for (const auto& effect : effects) { + chain.addEffect(effect); + } + chain.addOutputBuffer(&out_buffer); + chain.run(); + + // Create tensor from buffer + auto out_tensor = convert_to_tensor( + /*buffer=*/out_buffer.data(), + /*num_samples=*/out_buffer.size(), + /*num_channels=*/chain.getOutputNumChannels(), + dtype, + /*normalize=*/false, + channels_first); + + return std::tuple( + out_tensor, chain.getOutputSampleRate()); +} + +auto apply_effects_file( + const std::string& path, + const std::vector>& effects, + c10::optional normalize, + c10::optional channels_first, + const c10::optional& format) + -> std::tuple { + // Open input file + SoxFormat sf(sox_open_read( + path.c_str(), + /*signal=*/nullptr, + /*encoding=*/nullptr, + /*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); + + validate_input_file(sf, path); + + const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision); + + // Prepare output + std::vector out_buffer; + out_buffer.reserve(sf->signal.length); + + // Create and run SoxEffectsChain + torchaudio::sox_effects_chain::SoxEffectsChain chain( + /*input_encoding=*/sf->encoding, + /*output_encoding=*/get_tensor_encodinginfo(dtype)); + + chain.addInputFile(sf); + for (const auto& effect : effects) { + chain.addEffect(effect); + } + chain.addOutputBuffer(&out_buffer); + chain.run(); + + // Create tensor from buffer + bool channels_first_ = channels_first.value_or(true); + auto tensor = convert_to_tensor( + /*buffer=*/out_buffer.data(), + /*num_samples=*/out_buffer.size(), + /*num_channels=*/chain.getOutputNumChannels(), + dtype, + normalize.value_or(true), + channels_first_); + + return std::tuple( + tensor, chain.getOutputSampleRate()); +} + +TORCH_LIBRARY_FRAGMENT(torchaudio, m) { + m.def( + "torchaudio::sox_effects_initialize_sox_effects", + &torchaudio::sox_effects::initialize_sox_effects); + m.def( + "torchaudio::sox_effects_shutdown_sox_effects", + &torchaudio::sox_effects::shutdown_sox_effects); + m.def( + "torchaudio::sox_effects_apply_effects_tensor", + &torchaudio::sox_effects::apply_effects_tensor); + m.def( + "torchaudio::sox_effects_apply_effects_file", + &torchaudio::sox_effects::apply_effects_file); +} + +} // namespace torchaudio::sox_effects diff --git a/torchaudio/csrc/sox/effects.h b/torchaudio/csrc/sox/effects.h new file mode 100644 index 0000000000000000000000000000000000000000..71c0c7787c974a50fa87d60271036d33e88c6620 --- /dev/null +++ b/torchaudio/csrc/sox/effects.h @@ -0,0 +1,29 @@ +#ifndef TORCHAUDIO_SOX_EFFECTS_H +#define TORCHAUDIO_SOX_EFFECTS_H + +#include +#include + +namespace torchaudio::sox_effects { + +void initialize_sox_effects(); + +void shutdown_sox_effects(); + +auto apply_effects_tensor( + 
torch::Tensor waveform, + int64_t sample_rate, + const std::vector>& effects, + bool channels_first) -> std::tuple; + +auto apply_effects_file( + const std::string& path, + const std::vector>& effects, + c10::optional normalize, + c10::optional channels_first, + const c10::optional& format) + -> std::tuple; + +} // namespace torchaudio::sox_effects + +#endif diff --git a/torchaudio/csrc/sox/effects_chain.cpp b/torchaudio/csrc/sox/effects_chain.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b141e8c8d960755d7594b8dc14dc59b52024a89b --- /dev/null +++ b/torchaudio/csrc/sox/effects_chain.cpp @@ -0,0 +1,323 @@ +#include +#include + +using namespace torch::indexing; +using namespace torchaudio::sox_utils; + +namespace torchaudio { +namespace sox_effects_chain { + +namespace { + +/// helper classes for passing the location of input tensor and output buffer +/// +/// drain/flow callback functions require plaing C style function signature and +/// the way to pass extra data is to attach data to sox_effect_t::priv pointer. +/// The following structs will be assigned to sox_effect_t::priv pointer which +/// gives sox_effect_t an access to input Tensor and output buffer object. +struct TensorInputPriv { + size_t index; + torch::Tensor* waveform; + int64_t sample_rate; + bool channels_first; +}; +struct TensorOutputPriv { + std::vector* buffer; +}; +struct FileOutputPriv { + sox_format_t* sf; +}; + +/// Callback function to feed Tensor data to SoxEffectChain. +int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { + // Retrieve the input Tensor and current index + auto priv = static_cast(effp->priv); + auto index = priv->index; + auto tensor = *(priv->waveform); + auto num_channels = effp->out_signal.channels; + + // Adjust the number of samples to read + const size_t num_samples = tensor.numel(); + if (index + *osamp > num_samples) { + *osamp = num_samples - index; + } + // Ensure that it's a multiple of the number of channels + *osamp -= *osamp % num_channels; + + // Slice the input Tensor + auto chunk = [&]() { + auto i_frame = index / num_channels; + auto num_frames = *osamp / num_channels; + auto t = (priv->channels_first) + ? tensor.index({Slice(), Slice(i_frame, i_frame + num_frames)}).t() + : tensor.index({Slice(i_frame, i_frame + num_frames), Slice()}); + return t.reshape({-1}); + }(); + + // Convert to sox_sample_t (int32_t) + switch (chunk.dtype().toScalarType()) { + case c10::ScalarType::Float: { + // Need to convert to 64-bit precision so that + // values around INT32_MIN/MAX are handled correctly. + chunk = chunk.to(c10::ScalarType::Double); + chunk *= 2147483648.; + chunk.clamp_(INT32_MIN, INT32_MAX); + chunk = chunk.to(c10::ScalarType::Int); + break; + } + case c10::ScalarType::Int: { + break; + } + case c10::ScalarType::Short: { + chunk = chunk.to(c10::ScalarType::Int); + chunk *= 65536; + break; + } + case c10::ScalarType::Byte: { + chunk = chunk.to(c10::ScalarType::Int); + chunk -= 128; + chunk *= 16777216; + break; + } + default: + throw std::runtime_error("Unexpected dtype."); + } + // Write to buffer + chunk = chunk.contiguous(); + memcpy(obuf, chunk.data_ptr(), *osamp * 4); + priv->index += *osamp; + return (priv->index == num_samples) ? SOX_EOF : SOX_SUCCESS; +} + +/// Callback function to fetch data from SoxEffectChain. 
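For reference, the dtype branches in tensor_input_drain above scale each supported tensor dtype into the 32-bit sox_sample_t range before handing the chunk to sox. A sketch of the same arithmetic in Python/torch, illustrative only and not the binding itself:

import torch

def to_sox_samples(chunk: torch.Tensor) -> torch.Tensor:
    # Sketch of the scaling performed by tensor_input_drain.
    if chunk.dtype == torch.float32:
        # float32 in [-1.0, 1.0] -> full int32 range; go through float64 so values
        # near INT32_MIN/MAX survive the multiplication before clamping.
        scaled = (chunk.to(torch.float64) * 2147483648.0).clamp(-2147483648, 2147483647)
        return scaled.to(torch.int32)
    if chunk.dtype == torch.int32:
        return chunk                                      # already in sox_sample_t range
    if chunk.dtype == torch.int16:
        return chunk.to(torch.int32) * 65536              # shift 16-bit samples into the high bits
    if chunk.dtype == torch.uint8:
        return (chunk.to(torch.int32) - 128) * 16777216   # re-center unsigned 8-bit, then shift up
    raise RuntimeError("Unexpected dtype.")

# Example: a float32 chunk scaled to sox's 32-bit sample range.
print(to_sox_samples(torch.tensor([0.0, 0.5, -1.0])).tolist())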
+int tensor_output_flow( + sox_effect_t* effp, + sox_sample_t const* ibuf, + sox_sample_t* obuf LSX_UNUSED, + size_t* isamp, + size_t* osamp) { + *osamp = 0; + // Get output buffer + auto out_buffer = static_cast(effp->priv)->buffer; + // Append at the end + out_buffer->insert(out_buffer->end(), ibuf, ibuf + *isamp); + return SOX_SUCCESS; +} + +int file_output_flow( + sox_effect_t* effp, + sox_sample_t const* ibuf, + sox_sample_t* obuf LSX_UNUSED, + size_t* isamp, + size_t* osamp) { + *osamp = 0; + if (*isamp) { + auto sf = static_cast(effp->priv)->sf; + if (sox_write(sf, ibuf, *isamp) != *isamp) { + if (sf->sox_errno) { + std::ostringstream stream; + stream << sf->sox_errstr << " " << sox_strerror(sf->sox_errno) << " " + << sf->filename; + throw std::runtime_error(stream.str()); + } + return SOX_EOF; + } + } + return SOX_SUCCESS; +} + +sox_effect_handler_t* get_tensor_input_handler() { + static sox_effect_handler_t handler{ + /*name=*/"input_tensor", + /*usage=*/NULL, + /*flags=*/SOX_EFF_MCHAN, + /*getopts=*/NULL, + /*start=*/NULL, + /*flow=*/NULL, + /*drain=*/tensor_input_drain, + /*stop=*/NULL, + /*kill=*/NULL, + /*priv_size=*/sizeof(TensorInputPriv)}; + return &handler; +} + +sox_effect_handler_t* get_tensor_output_handler() { + static sox_effect_handler_t handler{ + /*name=*/"output_tensor", + /*usage=*/NULL, + /*flags=*/SOX_EFF_MCHAN, + /*getopts=*/NULL, + /*start=*/NULL, + /*flow=*/tensor_output_flow, + /*drain=*/NULL, + /*stop=*/NULL, + /*kill=*/NULL, + /*priv_size=*/sizeof(TensorOutputPriv)}; + return &handler; +} + +sox_effect_handler_t* get_file_output_handler() { + static sox_effect_handler_t handler{ + /*name=*/"output_file", + /*usage=*/NULL, + /*flags=*/SOX_EFF_MCHAN, + /*getopts=*/NULL, + /*start=*/NULL, + /*flow=*/file_output_flow, + /*drain=*/NULL, + /*stop=*/NULL, + /*kill=*/NULL, + /*priv_size=*/sizeof(FileOutputPriv)}; + return &handler; +} + +} // namespace + +SoxEffect::SoxEffect(sox_effect_t* se) noexcept : se_(se) {} + +SoxEffect::~SoxEffect() { + if (se_ != nullptr) { + free(se_); + } +} + +SoxEffect::operator sox_effect_t*() const { + return se_; +} + +auto SoxEffect::operator->() noexcept -> sox_effect_t* { + return se_; +} + +SoxEffectsChain::SoxEffectsChain( + sox_encodinginfo_t input_encoding, + sox_encodinginfo_t output_encoding) + : in_enc_(input_encoding), + out_enc_(output_encoding), + in_sig_(), + interm_sig_(), + out_sig_(), + sec_(sox_create_effects_chain(&in_enc_, &out_enc_)) { + if (!sec_) { + throw std::runtime_error("Failed to create effect chain."); + } +} + +SoxEffectsChain::~SoxEffectsChain() { + if (sec_ != nullptr) { + sox_delete_effects_chain(sec_); + } +} + +void SoxEffectsChain::run() { + sox_flow_effects(sec_, NULL, NULL); +} + +void SoxEffectsChain::addInputTensor( + torch::Tensor* waveform, + int64_t sample_rate, + bool channels_first) { + in_sig_ = get_signalinfo(waveform, sample_rate, "wav", channels_first); + interm_sig_ = in_sig_; + SoxEffect e(sox_create_effect(get_tensor_input_handler())); + auto priv = static_cast(e->priv); + priv->index = 0; + priv->waveform = waveform; + priv->sample_rate = sample_rate; + priv->channels_first = channels_first; + if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { + throw std::runtime_error( + "Internal Error: Failed to add effect: input_tensor"); + } +} + +void SoxEffectsChain::addOutputBuffer( + std::vector* output_buffer) { + SoxEffect e(sox_create_effect(get_tensor_output_handler())); + static_cast(e->priv)->buffer = output_buffer; + if (sox_add_effect(sec_, e, &interm_sig_, 
&in_sig_) != SOX_SUCCESS) { + throw std::runtime_error( + "Internal Error: Failed to add effect: output_tensor"); + } +} + +void SoxEffectsChain::addInputFile(sox_format_t* sf) { + in_sig_ = sf->signal; + interm_sig_ = in_sig_; + SoxEffect e(sox_create_effect(sox_find_effect("input"))); + char* opts[] = {(char*)sf}; + sox_effect_options(e, 1, opts); + if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { + std::ostringstream stream; + stream << "Internal Error: Failed to add effect: input " << sf->filename; + throw std::runtime_error(stream.str()); + } +} + +void SoxEffectsChain::addOutputFile(sox_format_t* sf) { + out_sig_ = sf->signal; + SoxEffect e(sox_create_effect(get_file_output_handler())); + static_cast(e->priv)->sf = sf; + if (sox_add_effect(sec_, e, &interm_sig_, &out_sig_) != SOX_SUCCESS) { + std::ostringstream stream; + stream << "Internal Error: Failed to add effect: output " << sf->filename; + throw std::runtime_error(stream.str()); + } +} + +void SoxEffectsChain::addEffect(const std::vector effect) { + const auto num_args = effect.size(); + if (num_args == 0) { + throw std::runtime_error("Invalid argument: empty effect."); + } + const auto name = effect[0]; + if (UNSUPPORTED_EFFECTS.find(name) != UNSUPPORTED_EFFECTS.end()) { + std::ostringstream stream; + stream << "Unsupported effect: " << name; + throw std::runtime_error(stream.str()); + } + + auto returned_effect = sox_find_effect(name.c_str()); + if (!returned_effect) { + std::ostringstream stream; + stream << "Unsupported effect: " << name; + throw std::runtime_error(stream.str()); + } + SoxEffect e(sox_create_effect(returned_effect)); + const auto num_options = num_args - 1; + + std::vector opts; + for (size_t i = 1; i < num_args; ++i) { + opts.push_back((char*)effect[i].c_str()); + } + if (sox_effect_options(e, num_options, num_options ? 
opts.data() : nullptr) != + SOX_SUCCESS) { + std::ostringstream stream; + stream << "Invalid effect option:"; + for (const auto& v : effect) { + stream << " " << v; + } + throw std::runtime_error(stream.str()); + } + + if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { + std::ostringstream stream; + stream << "Internal Error: Failed to add effect: \"" << name; + for (size_t i = 1; i < num_args; ++i) { + stream << " " << effect[i]; + } + stream << "\""; + throw std::runtime_error(stream.str()); + } +} + +int64_t SoxEffectsChain::getOutputNumChannels() { + return interm_sig_.channels; +} + +int64_t SoxEffectsChain::getOutputSampleRate() { + return interm_sig_.rate; +} + +} // namespace sox_effects_chain +} // namespace torchaudio diff --git a/torchaudio/csrc/sox/effects_chain.h b/torchaudio/csrc/sox/effects_chain.h new file mode 100644 index 0000000000000000000000000000000000000000..c456276ef06881f836fb9cf4c0eacf1236c08e0b --- /dev/null +++ b/torchaudio/csrc/sox/effects_chain.h @@ -0,0 +1,63 @@ +#ifndef TORCHAUDIO_SOX_EFFECTS_CHAIN_H +#define TORCHAUDIO_SOX_EFFECTS_CHAIN_H + +#include +#include + +namespace torchaudio { +namespace sox_effects_chain { + +// Helper struct to safely close sox_effect_t* pointer returned by +// sox_create_effect + +struct SoxEffect { + explicit SoxEffect(sox_effect_t* se) noexcept; + SoxEffect(const SoxEffect& other) = delete; + SoxEffect(const SoxEffect&& other) = delete; + auto operator=(const SoxEffect& other) -> SoxEffect& = delete; + auto operator=(SoxEffect&& other) -> SoxEffect& = delete; + ~SoxEffect(); + operator sox_effect_t*() const; + auto operator->() noexcept -> sox_effect_t*; + + private: + sox_effect_t* se_; +}; + +// Helper struct to safely close sox_effects_chain_t with handy methods +class SoxEffectsChain { + const sox_encodinginfo_t in_enc_; + const sox_encodinginfo_t out_enc_; + + protected: + sox_signalinfo_t in_sig_; + sox_signalinfo_t interm_sig_; + sox_signalinfo_t out_sig_; + sox_effects_chain_t* sec_; + + public: + explicit SoxEffectsChain( + sox_encodinginfo_t input_encoding, + sox_encodinginfo_t output_encoding); + SoxEffectsChain(const SoxEffectsChain& other) = delete; + SoxEffectsChain(const SoxEffectsChain&& other) = delete; + SoxEffectsChain& operator=(const SoxEffectsChain& other) = delete; + SoxEffectsChain& operator=(SoxEffectsChain&& other) = delete; + ~SoxEffectsChain(); + void run(); + void addInputTensor( + torch::Tensor* waveform, + int64_t sample_rate, + bool channels_first); + void addInputFile(sox_format_t* sf); + void addOutputBuffer(std::vector* output_buffer); + void addOutputFile(sox_format_t* sf); + void addEffect(const std::vector effect); + int64_t getOutputNumChannels(); + int64_t getOutputSampleRate(); +}; + +} // namespace sox_effects_chain +} // namespace torchaudio + +#endif diff --git a/torchaudio/csrc/sox/io.cpp b/torchaudio/csrc/sox/io.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f86f121ad344448c47348dfe2c499357c9165839 --- /dev/null +++ b/torchaudio/csrc/sox/io.cpp @@ -0,0 +1,143 @@ +#include +#include +#include +#include +#include + +using namespace torch::indexing; +using namespace torchaudio::sox_utils; + +namespace torchaudio { +namespace sox_io { + +std::tuple get_info_file( + const std::string& path, + const c10::optional& format) { + SoxFormat sf(sox_open_read( + path.c_str(), + /*signal=*/nullptr, + /*encoding=*/nullptr, + /*filetype=*/format.has_value() ? 
format.value().c_str() : nullptr)); + + validate_input_file(sf, path); + + return std::make_tuple( + static_cast(sf->signal.rate), + static_cast(sf->signal.length / sf->signal.channels), + static_cast(sf->signal.channels), + static_cast(sf->encoding.bits_per_sample), + get_encoding(sf->encoding.encoding)); +} + +std::vector> get_effects( + const c10::optional& frame_offset, + const c10::optional& num_frames) { + const auto offset = frame_offset.value_or(0); + if (offset < 0) { + throw std::runtime_error( + "Invalid argument: frame_offset must be non-negative."); + } + const auto frames = num_frames.value_or(-1); + if (frames == 0 || frames < -1) { + throw std::runtime_error( + "Invalid argument: num_frames must be -1 or greater than 0."); + } + + std::vector> effects; + if (frames != -1) { + std::ostringstream os_offset, os_frames; + os_offset << offset << "s"; + os_frames << "+" << frames << "s"; + effects.emplace_back( + std::vector{"trim", os_offset.str(), os_frames.str()}); + } else if (offset != 0) { + std::ostringstream os_offset; + os_offset << offset << "s"; + effects.emplace_back(std::vector{"trim", os_offset.str()}); + } + return effects; +} + +std::tuple load_audio_file( + const std::string& path, + const c10::optional& frame_offset, + const c10::optional& num_frames, + c10::optional normalize, + c10::optional channels_first, + const c10::optional& format) { + auto effects = get_effects(frame_offset, num_frames); + return torchaudio::sox_effects::apply_effects_file( + path, effects, normalize, channels_first, format); +} + +void save_audio_file( + const std::string& path, + torch::Tensor tensor, + int64_t sample_rate, + bool channels_first, + c10::optional compression, + c10::optional format, + c10::optional encoding, + c10::optional bits_per_sample) { + validate_input_tensor(tensor); + + const auto filetype = [&]() { + if (format.has_value()) + return format.value(); + return get_filetype(path); + }(); + + if (filetype == "amr-nb") { + const auto num_channels = tensor.size(channels_first ? 0 : 1); + TORCH_CHECK( + num_channels == 1, "amr-nb format only supports single channel audio."); + } else if (filetype == "htk") { + const auto num_channels = tensor.size(channels_first ? 0 : 1); + TORCH_CHECK( + num_channels == 1, "htk format only supports single channel audio."); + } else if (filetype == "gsm") { + const auto num_channels = tensor.size(channels_first ? 
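get_effects above turns frame_offset/num_frames into a sox "trim" effect expressed in samples (the trailing "s" suffix). A Python sketch of the same mapping, and how it surfaces through torchaudio.load; the file name is a placeholder:

def get_trim_effects(frame_offset=0, num_frames=-1):
    # Mirrors get_effects above: express offset/length as sox "trim" arguments in samples.
    if frame_offset < 0:
        raise ValueError("frame_offset must be non-negative.")
    if num_frames == 0 or num_frames < -1:
        raise ValueError("num_frames must be -1 or greater than 0.")
    if num_frames != -1:
        return [["trim", f"{frame_offset}s", f"+{num_frames}s"]]
    if frame_offset != 0:
        return [["trim", f"{frame_offset}s"]]
    return []

print(get_trim_effects(8000, 16000))   # [['trim', '8000s', '+16000s']]

# On the Python side these are the frame_offset/num_frames arguments of torchaudio.load, e.g.
# waveform, sr = torchaudio.load("speech.wav", frame_offset=8000, num_frames=16000)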
0 : 1); + TORCH_CHECK( + num_channels == 1, "gsm format only supports single channel audio."); + TORCH_CHECK( + sample_rate == 8000, + "gsm format only supports a sampling rate of 8kHz."); + } + const auto signal_info = + get_signalinfo(&tensor, sample_rate, filetype, channels_first); + const auto encoding_info = get_encodinginfo_for_save( + filetype, tensor.dtype(), compression, encoding, bits_per_sample); + + SoxFormat sf(sox_open_write( + path.c_str(), + &signal_info, + &encoding_info, + /*filetype=*/filetype.c_str(), + /*oob=*/nullptr, + /*overwrite_permitted=*/nullptr)); + + if (static_cast(sf) == nullptr) { + throw std::runtime_error( + "Error saving audio file: failed to open file " + path); + } + + torchaudio::sox_effects_chain::SoxEffectsChain chain( + /*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()), + /*output_encoding=*/sf->encoding); + chain.addInputTensor(&tensor, sample_rate, channels_first); + chain.addOutputFile(sf); + chain.run(); +} + +TORCH_LIBRARY_FRAGMENT(torchaudio, m) { + m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info_file); + m.def( + "torchaudio::sox_io_load_audio_file", + &torchaudio::sox_io::load_audio_file); + m.def( + "torchaudio::sox_io_save_audio_file", + &torchaudio::sox_io::save_audio_file); +} + +} // namespace sox_io +} // namespace torchaudio diff --git a/torchaudio/csrc/sox/io.h b/torchaudio/csrc/sox/io.h new file mode 100644 index 0000000000000000000000000000000000000000..e6c8cffba5fceca555f7212efdeb8d853e5c87aa --- /dev/null +++ b/torchaudio/csrc/sox/io.h @@ -0,0 +1,40 @@ +#ifndef TORCHAUDIO_SOX_IO_H +#define TORCHAUDIO_SOX_IO_H + +#include +#include + +namespace torchaudio { +namespace sox_io { + +auto get_effects( + const c10::optional& frame_offset, + const c10::optional& num_frames) + -> std::vector>; + +std::tuple get_info_file( + const std::string& path, + const c10::optional& format); + +std::tuple load_audio_file( + const std::string& path, + const c10::optional& frame_offset, + const c10::optional& num_frames, + c10::optional normalize, + c10::optional channels_first, + const c10::optional& format); + +void save_audio_file( + const std::string& path, + torch::Tensor tensor, + int64_t sample_rate, + bool channels_first, + c10::optional compression, + c10::optional format, + c10::optional encoding, + c10::optional bits_per_sample); + +} // namespace sox_io +} // namespace torchaudio + +#endif diff --git a/torchaudio/csrc/sox/types.cpp b/torchaudio/csrc/sox/types.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b4bbf710f6e1df32fa858a29885a694f424709e0 --- /dev/null +++ b/torchaudio/csrc/sox/types.cpp @@ -0,0 +1,139 @@ +#include + +namespace torchaudio { +namespace sox_utils { + +Format get_format_from_string(const std::string& format) { + if (format == "wav") + return Format::WAV; + if (format == "mp3") + return Format::MP3; + if (format == "flac") + return Format::FLAC; + if (format == "ogg" || format == "vorbis") + return Format::VORBIS; + if (format == "amr-nb") + return Format::AMR_NB; + if (format == "amr-wb") + return Format::AMR_WB; + if (format == "amb") + return Format::AMB; + if (format == "sph") + return Format::SPHERE; + if (format == "htk") + return Format::HTK; + if (format == "gsm") + return Format::GSM; + std::ostringstream stream; + stream << "Internal Error: unexpected format value: " << format; + throw std::runtime_error(stream.str()); +} + +std::string to_string(Encoding v) { + switch (v) { + case Encoding::UNKNOWN: + return "UNKNOWN"; + case Encoding::PCM_SIGNED: + 
return "PCM_S"; + case Encoding::PCM_UNSIGNED: + return "PCM_U"; + case Encoding::PCM_FLOAT: + return "PCM_F"; + case Encoding::FLAC: + return "FLAC"; + case Encoding::ULAW: + return "ULAW"; + case Encoding::ALAW: + return "ALAW"; + case Encoding::MP3: + return "MP3"; + case Encoding::VORBIS: + return "VORBIS"; + case Encoding::AMR_WB: + return "AMR_WB"; + case Encoding::AMR_NB: + return "AMR_NB"; + case Encoding::OPUS: + return "OPUS"; + default: + throw std::runtime_error("Internal Error: unexpected encoding."); + } +} + +Encoding get_encoding_from_option(const c10::optional encoding) { + if (!encoding.has_value()) + return Encoding::NOT_PROVIDED; + std::string v = encoding.value(); + if (v == "PCM_S") + return Encoding::PCM_SIGNED; + if (v == "PCM_U") + return Encoding::PCM_UNSIGNED; + if (v == "PCM_F") + return Encoding::PCM_FLOAT; + if (v == "ULAW") + return Encoding::ULAW; + if (v == "ALAW") + return Encoding::ALAW; + std::ostringstream stream; + stream << "Internal Error: unexpected encoding value: " << v; + throw std::runtime_error(stream.str()); +} + +BitDepth get_bit_depth_from_option(const c10::optional bit_depth) { + if (!bit_depth.has_value()) + return BitDepth::NOT_PROVIDED; + int64_t v = bit_depth.value(); + switch (v) { + case 8: + return BitDepth::B8; + case 16: + return BitDepth::B16; + case 24: + return BitDepth::B24; + case 32: + return BitDepth::B32; + case 64: + return BitDepth::B64; + default: { + std::ostringstream s; + s << "Internal Error: unexpected bit depth value: " << v; + throw std::runtime_error(s.str()); + } + } +} + +std::string get_encoding(sox_encoding_t encoding) { + switch (encoding) { + case SOX_ENCODING_UNKNOWN: + return "UNKNOWN"; + case SOX_ENCODING_SIGN2: + return "PCM_S"; + case SOX_ENCODING_UNSIGNED: + return "PCM_U"; + case SOX_ENCODING_FLOAT: + return "PCM_F"; + case SOX_ENCODING_FLAC: + return "FLAC"; + case SOX_ENCODING_ULAW: + return "ULAW"; + case SOX_ENCODING_ALAW: + return "ALAW"; + case SOX_ENCODING_MP3: + return "MP3"; + case SOX_ENCODING_VORBIS: + return "VORBIS"; + case SOX_ENCODING_AMR_WB: + return "AMR_WB"; + case SOX_ENCODING_AMR_NB: + return "AMR_NB"; + case SOX_ENCODING_OPUS: + return "OPUS"; + case SOX_ENCODING_GSM: + return "GSM"; + default: + return "UNKNOWN"; + } +} + +} // namespace sox_utils +} // namespace torchaudio diff --git a/torchaudio/csrc/sox/types.h b/torchaudio/csrc/sox/types.h new file mode 100644 index 0000000000000000000000000000000000000000..afd84791a69bb10b3c900ecfea18e6700cac4e66 --- /dev/null +++ b/torchaudio/csrc/sox/types.h @@ -0,0 +1,60 @@ +#ifndef TORCHAUDIO_SOX_TYPES_H +#define TORCHAUDIO_SOX_TYPES_H + +#include +#include + +namespace torchaudio { +namespace sox_utils { + +enum class Format { + WAV, + MP3, + FLAC, + VORBIS, + AMR_NB, + AMR_WB, + AMB, + SPHERE, + GSM, + HTK, +}; + +Format get_format_from_string(const std::string& format); + +enum class Encoding { + NOT_PROVIDED, + UNKNOWN, + PCM_SIGNED, + PCM_UNSIGNED, + PCM_FLOAT, + FLAC, + ULAW, + ALAW, + MP3, + VORBIS, + AMR_WB, + AMR_NB, + OPUS, +}; + +std::string to_string(Encoding v); +Encoding get_encoding_from_option(const c10::optional encoding); + +enum class BitDepth : unsigned { + NOT_PROVIDED = 0, + B8 = 8, + B16 = 16, + B24 = 24, + B32 = 32, + B64 = 64, +}; + +BitDepth get_bit_depth_from_option(const c10::optional bit_depth); + +std::string get_encoding(sox_encoding_t encoding); + +} // namespace sox_utils +} // namespace torchaudio + +#endif diff --git a/torchaudio/csrc/sox/utils.cpp b/torchaudio/csrc/sox/utils.cpp new file mode 100644 
index 0000000000000000000000000000000000000000..2208ed2074ca9fe36d9f84633bfe3482be5fa30b --- /dev/null +++ b/torchaudio/csrc/sox/utils.cpp @@ -0,0 +1,522 @@ +#include +#include +#include +#include + +namespace torchaudio { +namespace sox_utils { + +void set_seed(const int64_t seed) { + sox_get_globals()->ranqd1 = static_cast(seed); +} + +void set_verbosity(const int64_t verbosity) { + sox_get_globals()->verbosity = static_cast(verbosity); +} + +void set_use_threads(const bool use_threads) { + sox_get_globals()->use_threads = static_cast(use_threads); +} + +void set_buffer_size(const int64_t buffer_size) { + sox_get_globals()->bufsiz = static_cast(buffer_size); +} + +int64_t get_buffer_size() { + return sox_get_globals()->bufsiz; +} + +std::vector> list_effects() { + std::vector> effects; + for (const sox_effect_fn_t* fns = sox_get_effect_fns(); *fns; ++fns) { + const sox_effect_handler_t* handler = (*fns)(); + if (handler && handler->name) { + if (UNSUPPORTED_EFFECTS.find(handler->name) == + UNSUPPORTED_EFFECTS.end()) { + effects.emplace_back(std::vector{ + handler->name, + handler->usage ? std::string(handler->usage) : std::string("")}); + } + } + } + return effects; +} + +std::vector list_write_formats() { + std::vector formats; + for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) { + const sox_format_handler_t* handler = fns->fn(); + for (const char* const* names = handler->names; *names; ++names) { + if (!strchr(*names, '/') && handler->write) + formats.emplace_back(*names); + } + } + return formats; +} + +std::vector list_read_formats() { + std::vector formats; + for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) { + const sox_format_handler_t* handler = fns->fn(); + for (const char* const* names = handler->names; *names; ++names) { + if (!strchr(*names, '/') && handler->read) + formats.emplace_back(*names); + } + } + return formats; +} + +SoxFormat::SoxFormat(sox_format_t* fd) noexcept : fd_(fd) {} +SoxFormat::~SoxFormat() { + close(); +} + +sox_format_t* SoxFormat::operator->() const noexcept { + return fd_; +} +SoxFormat::operator sox_format_t*() const noexcept { + return fd_; +} + +void SoxFormat::close() { + if (fd_ != nullptr) { + sox_close(fd_); + fd_ = nullptr; + } +} + +void validate_input_file(const SoxFormat& sf, const std::string& path) { + if (static_cast(sf) == nullptr) { + throw std::runtime_error( + "Error loading audio file: failed to open file " + path); + } + if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) { + throw std::runtime_error("Error loading audio file: unknown encoding."); + } +} + +void validate_input_memfile(const SoxFormat& sf) { + return validate_input_file(sf, ""); +} + +void validate_input_tensor(const torch::Tensor tensor) { + if (!tensor.device().is_cpu()) { + throw std::runtime_error("Input tensor has to be on CPU."); + } + + if (tensor.ndimension() != 2) { + throw std::runtime_error("Input tensor has to be 2D."); + } + + switch (tensor.dtype().toScalarType()) { + case c10::ScalarType::Byte: + case c10::ScalarType::Short: + case c10::ScalarType::Int: + case c10::ScalarType::Float: + break; + default: + throw std::runtime_error( + "Input tensor has to be one of float32, int32, int16 or uint8 type."); + } +} + +caffe2::TypeMeta get_dtype( + const sox_encoding_t encoding, + const unsigned precision) { + const auto dtype = [&]() { + switch (encoding) { + case SOX_ENCODING_UNSIGNED: // 8-bit PCM WAV + return torch::kUInt8; + case SOX_ENCODING_SIGN2: // 16-bit, 24-bit, or 32-bit PCM WAV + switch (precision) 
{ + case 16: + return torch::kInt16; + case 24: // Cast 24-bit to 32-bit. + case 32: + return torch::kInt32; + default: + throw std::runtime_error( + "Only 16, 24, and 32 bits are supported for signed PCM."); + } + default: + // default to float32 for the other formats, including + // 32-bit flaoting-point WAV, + // MP3, + // FLAC, + // VORBIS etc... + return torch::kFloat32; + } + }(); + return c10::scalarTypeToTypeMeta(dtype); +} + +torch::Tensor convert_to_tensor( + sox_sample_t* buffer, + const int32_t num_samples, + const int32_t num_channels, + const caffe2::TypeMeta dtype, + const bool normalize, + const bool channels_first) { + torch::Tensor t; + uint64_t dummy = 0; + SOX_SAMPLE_LOCALS; + if (normalize || dtype == torch::kFloat32) { + t = torch::empty( + {num_samples / num_channels, num_channels}, torch::kFloat32); + auto ptr = t.data_ptr(); + for (int32_t i = 0; i < num_samples; ++i) { + ptr[i] = SOX_SAMPLE_TO_FLOAT_32BIT(buffer[i], dummy); + } + } else if (dtype == torch::kInt32) { + t = torch::from_blob( + buffer, {num_samples / num_channels, num_channels}, torch::kInt32) + .clone(); + } else if (dtype == torch::kInt16) { + t = torch::empty({num_samples / num_channels, num_channels}, torch::kInt16); + auto ptr = t.data_ptr(); + for (int32_t i = 0; i < num_samples; ++i) { + ptr[i] = SOX_SAMPLE_TO_SIGNED_16BIT(buffer[i], dummy); + } + } else if (dtype == torch::kUInt8) { + t = torch::empty({num_samples / num_channels, num_channels}, torch::kUInt8); + auto ptr = t.data_ptr(); + for (int32_t i = 0; i < num_samples; ++i) { + ptr[i] = SOX_SAMPLE_TO_UNSIGNED_8BIT(buffer[i], dummy); + } + } else { + throw std::runtime_error("Unsupported dtype."); + } + if (channels_first) { + t = t.transpose(1, 0); + } + return t.contiguous(); +} + +const std::string get_filetype(const std::string path) { + std::string ext = path.substr(path.find_last_of(".") + 1); + std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); + return ext; +} + +namespace { + +std::tuple get_save_encoding_for_wav( + const std::string format, + caffe2::TypeMeta dtype, + const Encoding& encoding, + const BitDepth& bits_per_sample) { + switch (encoding) { + case Encoding::NOT_PROVIDED: + switch (bits_per_sample) { + case BitDepth::NOT_PROVIDED: + switch (dtype.toScalarType()) { + case c10::ScalarType::Float: + return std::make_tuple<>(SOX_ENCODING_FLOAT, 32); + case c10::ScalarType::Int: + return std::make_tuple<>(SOX_ENCODING_SIGN2, 32); + case c10::ScalarType::Short: + return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); + case c10::ScalarType::Byte: + return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); + default: + throw std::runtime_error("Internal Error: Unexpected dtype."); + } + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); + default: + return std::make_tuple<>( + SOX_ENCODING_SIGN2, static_cast(bits_per_sample)); + } + case Encoding::PCM_SIGNED: + switch (bits_per_sample) { + case BitDepth::NOT_PROVIDED: + return std::make_tuple<>(SOX_ENCODING_SIGN2, 32); + case BitDepth::B8: + throw std::runtime_error( + format + " does not support 8-bit signed PCM encoding."); + default: + return std::make_tuple<>( + SOX_ENCODING_SIGN2, static_cast(bits_per_sample)); + } + case Encoding::PCM_UNSIGNED: + switch (bits_per_sample) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); + default: + throw std::runtime_error( + format + " only supports 8-bit for unsigned PCM encoding."); + } + case Encoding::PCM_FLOAT: + switch (bits_per_sample) { + case 
BitDepth::NOT_PROVIDED: + case BitDepth::B32: + return std::make_tuple<>(SOX_ENCODING_FLOAT, 32); + case BitDepth::B64: + return std::make_tuple<>(SOX_ENCODING_FLOAT, 64); + default: + throw std::runtime_error( + format + + " only supports 32-bit or 64-bit for floating-point PCM encoding."); + } + case Encoding::ULAW: + switch (bits_per_sample) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_ULAW, 8); + default: + throw std::runtime_error( + format + " only supports 8-bit for mu-law encoding."); + } + case Encoding::ALAW: + switch (bits_per_sample) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_ALAW, 8); + default: + throw std::runtime_error( + format + " only supports 8-bit for a-law encoding."); + } + default: + throw std::runtime_error( + format + " does not support encoding: " + to_string(encoding)); + } +} + +std::tuple get_save_encoding( + const std::string& format, + const caffe2::TypeMeta dtype, + const c10::optional encoding, + const c10::optional bits_per_sample) { + const Format fmt = get_format_from_string(format); + const Encoding enc = get_encoding_from_option(encoding); + const BitDepth bps = get_bit_depth_from_option(bits_per_sample); + + switch (fmt) { + case Format::WAV: + case Format::AMB: + return get_save_encoding_for_wav(format, dtype, enc, bps); + case Format::MP3: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("mp3 does not support `encoding` option."); + if (bps != BitDepth::NOT_PROVIDED) + throw std::runtime_error( + "mp3 does not support `bits_per_sample` option."); + return std::make_tuple<>(SOX_ENCODING_MP3, 16); + case Format::HTK: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("htk does not support `encoding` option."); + if (bps != BitDepth::NOT_PROVIDED) + throw std::runtime_error( + "htk does not support `bits_per_sample` option."); + return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); + case Format::VORBIS: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("vorbis does not support `encoding` option."); + if (bps != BitDepth::NOT_PROVIDED) + throw std::runtime_error( + "vorbis does not support `bits_per_sample` option."); + return std::make_tuple<>(SOX_ENCODING_VORBIS, 16); + case Format::AMR_NB: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("amr-nb does not support `encoding` option."); + if (bps != BitDepth::NOT_PROVIDED) + throw std::runtime_error( + "amr-nb does not support `bits_per_sample` option."); + return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16); + case Format::FLAC: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("flac does not support `encoding` option."); + switch (bps) { + case BitDepth::B32: + case BitDepth::B64: + throw std::runtime_error( + "flac does not support `bits_per_sample` larger than 24."); + default: + return std::make_tuple<>( + SOX_ENCODING_FLAC, static_cast(bps)); + } + case Format::SPHERE: + switch (enc) { + case Encoding::NOT_PROVIDED: + case Encoding::PCM_SIGNED: + switch (bps) { + case BitDepth::NOT_PROVIDED: + return std::make_tuple<>(SOX_ENCODING_SIGN2, 32); + default: + return std::make_tuple<>( + SOX_ENCODING_SIGN2, static_cast(bps)); + } + case Encoding::PCM_UNSIGNED: + throw std::runtime_error( + "sph does not support unsigned integer PCM."); + case Encoding::PCM_FLOAT: + throw std::runtime_error("sph does not support floating point PCM."); + case Encoding::ULAW: + switch (bps) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B8: + return 
std::make_tuple<>(SOX_ENCODING_ULAW, 8); + default: + throw std::runtime_error( + "sph only supports 8-bit for mu-law encoding."); + } + case Encoding::ALAW: + switch (bps) { + case BitDepth::NOT_PROVIDED: + case BitDepth::B8: + return std::make_tuple<>(SOX_ENCODING_ALAW, 8); + default: + return std::make_tuple<>( + SOX_ENCODING_ALAW, static_cast(bps)); + } + default: + throw std::runtime_error( + "sph does not support encoding: " + encoding.value()); + } + case Format::GSM: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("gsm does not support `encoding` option."); + if (bps != BitDepth::NOT_PROVIDED) + throw std::runtime_error( + "gsm does not support `bits_per_sample` option."); + return std::make_tuple<>(SOX_ENCODING_GSM, 16); + + default: + throw std::runtime_error("Unsupported format: " + format); + } +} + +unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) { + if (filetype == "mp3") + return SOX_UNSPEC; + if (filetype == "flac") + return 24; + if (filetype == "ogg" || filetype == "vorbis") + return SOX_UNSPEC; + if (filetype == "wav" || filetype == "amb") { + switch (dtype.toScalarType()) { + case c10::ScalarType::Byte: + return 8; + case c10::ScalarType::Short: + return 16; + case c10::ScalarType::Int: + return 32; + case c10::ScalarType::Float: + return 32; + default: + throw std::runtime_error("Unsupported dtype."); + } + } + if (filetype == "sph") + return 32; + if (filetype == "amr-nb") { + return 16; + } + if (filetype == "gsm") { + return 16; + } + if (filetype == "htk") { + return 16; + } + throw std::runtime_error("Unsupported file type: " + filetype); +} + +} // namespace + +sox_signalinfo_t get_signalinfo( + const torch::Tensor* waveform, + const int64_t sample_rate, + const std::string filetype, + const bool channels_first) { + return sox_signalinfo_t{ + /*rate=*/static_cast(sample_rate), + /*channels=*/ + static_cast(waveform->size(channels_first ? 
0 : 1)), + /*precision=*/get_precision(filetype, waveform->dtype()), + /*length=*/static_cast(waveform->numel())}; +} + +sox_encodinginfo_t get_tensor_encodinginfo(caffe2::TypeMeta dtype) { + sox_encoding_t encoding = [&]() { + switch (dtype.toScalarType()) { + case c10::ScalarType::Byte: + return SOX_ENCODING_UNSIGNED; + case c10::ScalarType::Short: + return SOX_ENCODING_SIGN2; + case c10::ScalarType::Int: + return SOX_ENCODING_SIGN2; + case c10::ScalarType::Float: + return SOX_ENCODING_FLOAT; + default: + throw std::runtime_error("Unsupported dtype."); + } + }(); + unsigned bits_per_sample = [&]() { + switch (dtype.toScalarType()) { + case c10::ScalarType::Byte: + return 8; + case c10::ScalarType::Short: + return 16; + case c10::ScalarType::Int: + return 32; + case c10::ScalarType::Float: + return 32; + default: + throw std::runtime_error("Unsupported dtype."); + } + }(); + return sox_encodinginfo_t{ + /*encoding=*/encoding, + /*bits_per_sample=*/bits_per_sample, + /*compression=*/HUGE_VAL, + /*reverse_bytes=*/sox_option_default, + /*reverse_nibbles=*/sox_option_default, + /*reverse_bits=*/sox_option_default, + /*opposite_endian=*/sox_false}; +} + +sox_encodinginfo_t get_encodinginfo_for_save( + const std::string& format, + const caffe2::TypeMeta dtype, + const c10::optional compression, + const c10::optional encoding, + const c10::optional bits_per_sample) { + auto enc = get_save_encoding(format, dtype, encoding, bits_per_sample); + return sox_encodinginfo_t{ + /*encoding=*/std::get<0>(enc), + /*bits_per_sample=*/std::get<1>(enc), + /*compression=*/compression.value_or(HUGE_VAL), + /*reverse_bytes=*/sox_option_default, + /*reverse_nibbles=*/sox_option_default, + /*reverse_bits=*/sox_option_default, + /*opposite_endian=*/sox_false}; +} + +TORCH_LIBRARY_FRAGMENT(torchaudio, m) { + m.def("torchaudio::sox_utils_set_seed", &torchaudio::sox_utils::set_seed); + m.def( + "torchaudio::sox_utils_set_verbosity", + &torchaudio::sox_utils::set_verbosity); + m.def( + "torchaudio::sox_utils_set_use_threads", + &torchaudio::sox_utils::set_use_threads); + m.def( + "torchaudio::sox_utils_set_buffer_size", + &torchaudio::sox_utils::set_buffer_size); + m.def( + "torchaudio::sox_utils_list_effects", + &torchaudio::sox_utils::list_effects); + m.def( + "torchaudio::sox_utils_list_read_formats", + &torchaudio::sox_utils::list_read_formats); + m.def( + "torchaudio::sox_utils_list_write_formats", + &torchaudio::sox_utils::list_write_formats); + m.def( + "torchaudio::sox_utils_get_buffer_size", + &torchaudio::sox_utils::get_buffer_size); +} + +} // namespace sox_utils +} // namespace torchaudio diff --git a/torchaudio/csrc/sox/utils.h b/torchaudio/csrc/sox/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..73b76e71b92224c58d271dd4f7e73d187cf42475 --- /dev/null +++ b/torchaudio/csrc/sox/utils.h @@ -0,0 +1,118 @@ +#ifndef TORCHAUDIO_SOX_UTILS_H +#define TORCHAUDIO_SOX_UTILS_H + +#include +#include + +namespace torchaudio { +namespace sox_utils { + +//////////////////////////////////////////////////////////////////////////////// +// APIs for Python interaction +//////////////////////////////////////////////////////////////////////////////// + +/// Set sox global options +void set_seed(const int64_t seed); + +void set_verbosity(const int64_t verbosity); + +void set_use_threads(const bool use_threads); + +void set_buffer_size(const int64_t buffer_size); + +int64_t get_buffer_size(); + +std::vector> list_effects(); + +std::vector list_read_formats(); + +std::vector list_write_formats(); + 
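The sox_utils ops registered above are reachable from Python; assuming the wrappers in torchaudio.utils.sox_utils, a short sketch:

import torchaudio

# Query what the linked libsox build supports.
effects = torchaudio.utils.sox_utils.list_effects()           # dict: effect name -> usage string
read_formats = torchaudio.utils.sox_utils.list_read_formats()
write_formats = torchaudio.utils.sox_utils.list_write_formats()
print("vol" in effects, "wav" in write_formats)

# Adjust global sox state: dither seed, log verbosity, and I/O buffer size.
torchaudio.utils.sox_utils.set_seed(0)
torchaudio.utils.sox_utils.set_verbosity(1)
torchaudio.utils.sox_utils.set_buffer_size(8192)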
+//////////////////////////////////////////////////////////////////////////////// +// Utilities for sox_io / sox_effects implementations +//////////////////////////////////////////////////////////////////////////////// + +const std::unordered_set UNSUPPORTED_EFFECTS = + {"input", "output", "spectrogram", "noiseprof", "noisered", "splice"}; + +/// helper class to automatically close sox_format_t* +struct SoxFormat { + explicit SoxFormat(sox_format_t* fd) noexcept; + SoxFormat(const SoxFormat& other) = delete; + SoxFormat(SoxFormat&& other) = delete; + SoxFormat& operator=(const SoxFormat& other) = delete; + SoxFormat& operator=(SoxFormat&& other) = delete; + ~SoxFormat(); + sox_format_t* operator->() const noexcept; + operator sox_format_t*() const noexcept; + + void close(); + + private: + sox_format_t* fd_; +}; + +/// +/// Verify that input file is found, has known encoding, and not empty +void validate_input_file(const SoxFormat& sf, const std::string& path); + +/// Verify that input memory buffer has known encoding, and not empty +void validate_input_memfile(const SoxFormat& sf); + +/// +/// Verify that input Tensor is 2D, CPU and either uin8, int16, int32 or float32 +void validate_input_tensor(const torch::Tensor); + +/// +/// Get target dtype for the given encoding and precision. +caffe2::TypeMeta get_dtype( + const sox_encoding_t encoding, + const unsigned precision); + +/// +/// Convert sox_sample_t buffer to uint8/int16/int32/float32 Tensor +/// NOTE: This function might modify the values in the input buffer to +/// reduce the number of memory copy. +/// @param buffer Pointer to buffer that contains audio data. +/// @param num_samples The number of samples to read. +/// @param num_channels The number of channels. Used to reshape the resulting +/// Tensor. +/// @param dtype Target dtype. Determines the output dtype and value range in +/// conjunction with normalization. +/// @param noramlize Perform normalization. Only effective when dtype is not +/// kFloat32. When effective, the output tensor is kFloat32 type and value range +/// is [-1.0, 1.0] +/// @param channels_first When True, output Tensor has shape of [num_channels, +/// num_frames]. +torch::Tensor convert_to_tensor( + sox_sample_t* buffer, + const int32_t num_samples, + const int32_t num_channels, + const caffe2::TypeMeta dtype, + const bool normalize, + const bool channels_first); + +/// Extract extension from file path +const std::string get_filetype(const std::string path); + +/// Get sox_signalinfo_t for passing a torch::Tensor object. 
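The normalize flag documented for convert_to_tensor above is the same flag exposed by torchaudio.load: with normalization the result is always float32 in [-1.0, 1.0]; without it, PCM files keep their native integer dtype. A sketch, where the file name is a placeholder for a 16-bit PCM WAV:

import torchaudio

wav_f32, sr = torchaudio.load("speech_int16.wav")                   # default normalize=True
wav_i16, sr = torchaudio.load("speech_int16.wav", normalize=False)  # keep native dtype
print(wav_f32.dtype, wav_i16.dtype)  # torch.float32 torch.int16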
+sox_signalinfo_t get_signalinfo( + const torch::Tensor* waveform, + const int64_t sample_rate, + const std::string filetype, + const bool channels_first); + +/// Get sox_encodinginfo_t for Tensor I/O +sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype); + +/// Get sox_encodinginfo_t for saving to file/file object +sox_encodinginfo_t get_encodinginfo_for_save( + const std::string& format, + const caffe2::TypeMeta dtype, + const c10::optional compression, + const c10::optional encoding, + const c10::optional bits_per_sample); + +} // namespace sox_utils +} // namespace torchaudio +#endif diff --git a/torchaudio/csrc/utils.cpp b/torchaudio/csrc/utils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e8050bb69a349534871336fb767f01bf4498e16c --- /dev/null +++ b/torchaudio/csrc/utils.cpp @@ -0,0 +1,30 @@ +#include + +namespace torchaudio { + +namespace { + +bool is_sox_available() { +#ifdef INCLUDE_SOX + return true; +#else + return false; +#endif +} + +bool is_kaldi_available() { +#ifdef INCLUDE_KALDI + return true; +#else + return false; +#endif +} + +} // namespace + +TORCH_LIBRARY_FRAGMENT(torchaudio, m) { + m.def("torchaudio::is_sox_available", &is_sox_available); + m.def("torchaudio::is_kaldi_available", &is_kaldi_available); +} + +} // namespace torchaudio diff --git a/torchaudio/datasets/__init__.py b/torchaudio/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2065444c365c137ab20483ba462b1b972301f81f --- /dev/null +++ b/torchaudio/datasets/__init__.py @@ -0,0 +1,32 @@ +from .commonvoice import COMMONVOICE +from .librispeech import LIBRISPEECH +from .speechcommands import SPEECHCOMMANDS +from .utils import bg_iterator, diskcache_iterator +from .vctk import VCTK, VCTK_092 +from .gtzan import GTZAN +from .yesno import YESNO +from .ljspeech import LJSPEECH +from .cmuarctic import CMUARCTIC +from .cmudict import CMUDict +from .librimix import LibriMix +from .libritts import LIBRITTS +from .tedlium import TEDLIUM + + +__all__ = [ + "COMMONVOICE", + "LIBRISPEECH", + "SPEECHCOMMANDS", + "VCTK", + "VCTK_092", + "YESNO", + "LJSPEECH", + "GTZAN", + "CMUARCTIC", + "CMUDict", + "LibriMix", + "LIBRITTS", + "diskcache_iterator", + "bg_iterator", + "TEDLIUM", +] diff --git a/torchaudio/datasets/cmuarctic.py b/torchaudio/datasets/cmuarctic.py new file mode 100644 index 0000000000000000000000000000000000000000..a01399198b6cb5570686ee6b641e313312cf44a7 --- /dev/null +++ b/torchaudio/datasets/cmuarctic.py @@ -0,0 +1,173 @@ +import os +import csv +from pathlib import Path +from typing import Tuple, Union + +import torchaudio +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio.datasets.utils import ( + download_url, + extract_archive, +) + +URL = "aew" +FOLDER_IN_ARCHIVE = "ARCTIC" +_CHECKSUMS = { + "http://festvox.org/cmu_arctic/packed/cmu_us_aew_arctic.tar.bz2": + "4382b116efcc8339c37e01253cb56295", + "http://festvox.org/cmu_arctic/packed/cmu_us_ahw_arctic.tar.bz2": + "b072d6e961e3f36a2473042d097d6da9", + "http://festvox.org/cmu_arctic/packed/cmu_us_aup_arctic.tar.bz2": + "5301c7aee8919d2abd632e2667adfa7f", + "http://festvox.org/cmu_arctic/packed/cmu_us_awb_arctic.tar.bz2": + "280fdff1e9857119d9a2c57b50e12db7", + "http://festvox.org/cmu_arctic/packed/cmu_us_axb_arctic.tar.bz2": + "5e21cb26c6529c533df1d02ccde5a186", + "http://festvox.org/cmu_arctic/packed/cmu_us_bdl_arctic.tar.bz2": + "b2c3e558f656af2e0a65da0ac0c3377a", + "http://festvox.org/cmu_arctic/packed/cmu_us_clb_arctic.tar.bz2": + 
"3957c503748e3ce17a3b73c1b9861fb0", + "http://festvox.org/cmu_arctic/packed/cmu_us_eey_arctic.tar.bz2": + "59708e932d27664f9eda3e8e6859969b", + "http://festvox.org/cmu_arctic/packed/cmu_us_fem_arctic.tar.bz2": + "dba4f992ff023347c07c304bf72f4c73", + "http://festvox.org/cmu_arctic/packed/cmu_us_gka_arctic.tar.bz2": + "24a876ea7335c1b0ff21460e1241340f", + "http://festvox.org/cmu_arctic/packed/cmu_us_jmk_arctic.tar.bz2": + "afb69d95f02350537e8a28df5ab6004b", + "http://festvox.org/cmu_arctic/packed/cmu_us_ksp_arctic.tar.bz2": + "4ce5b3b91a0a54b6b685b1b05aa0b3be", + "http://festvox.org/cmu_arctic/packed/cmu_us_ljm_arctic.tar.bz2": + "6f45a3b2c86a4ed0465b353be291f77d", + "http://festvox.org/cmu_arctic/packed/cmu_us_lnh_arctic.tar.bz2": + "c6a15abad5c14d27f4ee856502f0232f", + "http://festvox.org/cmu_arctic/packed/cmu_us_rms_arctic.tar.bz2": + "71072c983df1e590d9e9519e2a621f6e", + "http://festvox.org/cmu_arctic/packed/cmu_us_rxr_arctic.tar.bz2": + "3771ff03a2f5b5c3b53aa0a68b9ad0d5", + "http://festvox.org/cmu_arctic/packed/cmu_us_slp_arctic.tar.bz2": + "9cbf984a832ea01b5058ba9a96862850", + "http://festvox.org/cmu_arctic/packed/cmu_us_slt_arctic.tar.bz2": + "959eecb2cbbc4ac304c6b92269380c81", +} + + +def load_cmuarctic_item(line: str, + path: str, + folder_audio: str, + ext_audio: str) -> Tuple[Tensor, int, str, str]: + + utterance_id, transcript = line[0].strip().split(" ", 2)[1:] + + # Remove space, double quote, and single parenthesis from transcript + transcript = transcript[1:-3] + + file_audio = os.path.join(path, folder_audio, utterance_id + ext_audio) + + # Load audio + waveform, sample_rate = torchaudio.load(file_audio) + + return ( + waveform, + sample_rate, + transcript, + utterance_id.split("_")[1] + ) + + +class CMUARCTIC(Dataset): + """Create a Dataset for CMU_ARCTIC. + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + url (str, optional): + The URL to download the dataset from or the type of the dataset to dowload. + (default: ``"aew"``) + Allowed type values are ``"aew"``, ``"ahw"``, ``"aup"``, ``"awb"``, ``"axb"``, ``"bdl"``, + ``"clb"``, ``"eey"``, ``"fem"``, ``"gka"``, ``"jmk"``, ``"ksp"``, ``"ljm"``, ``"lnh"``, + ``"rms"``, ``"rxr"``, ``"slp"`` or ``"slt"``. + folder_in_archive (str, optional): + The top-level directory of the dataset. (default: ``"ARCTIC"``) + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). 
+ """ + + _file_text = "txt.done.data" + _folder_text = "etc" + _ext_audio = ".wav" + _folder_audio = "wav" + + def __init__(self, + root: Union[str, Path], + url: str = URL, + folder_in_archive: str = FOLDER_IN_ARCHIVE, + download: bool = False) -> None: + + if url in [ + "aew", + "ahw", + "aup", + "awb", + "axb", + "bdl", + "clb", + "eey", + "fem", + "gka", + "jmk", + "ksp", + "ljm", + "lnh", + "rms", + "rxr", + "slp", + "slt" + ]: + + url = "cmu_us_" + url + "_arctic" + ext_archive = ".tar.bz2" + base_url = "http://www.festvox.org/cmu_arctic/packed/" + + url = os.path.join(base_url, url + ext_archive) + + # Get string representation of 'root' in case Path object is passed + root = os.fspath(root) + + basename = os.path.basename(url) + root = os.path.join(root, folder_in_archive) + if not os.path.isdir(root): + os.mkdir(root) + archive = os.path.join(root, basename) + + basename = basename.split(".")[0] + + self._path = os.path.join(root, basename) + + if download: + if not os.path.isdir(self._path): + if not os.path.isfile(archive): + checksum = _CHECKSUMS.get(url, None) + download_url(url, root, hash_value=checksum, hash_type="md5") + extract_archive(archive) + + self._text = os.path.join(self._path, self._folder_text, self._file_text) + + with open(self._text, "r") as text: + walker = csv.reader(text, delimiter="\n") + self._walker = list(walker) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str]: + """Load the n-th sample from the dataset. + + Args: + n (int): The index of the sample to be loaded + + Returns: + (Tensor, int, str, str): ``(waveform, sample_rate, transcript, utterance_id)`` + """ + line = self._walker[n] + return load_cmuarctic_item(line, self._path, self._folder_audio, self._ext_audio) + + def __len__(self) -> int: + return len(self._walker) diff --git a/torchaudio/datasets/cmudict.py b/torchaudio/datasets/cmudict.py new file mode 100644 index 0000000000000000000000000000000000000000..2c6acccf210651cc66f8a638d4912846a072dcd4 --- /dev/null +++ b/torchaudio/datasets/cmudict.py @@ -0,0 +1,182 @@ +import os +import re +from pathlib import Path +from typing import Iterable, Tuple, Union, List + +from torch.utils.data import Dataset +from torchaudio.datasets.utils import download_url + +_CHECKSUMS = { + "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b": + "825f4ebd9183f2417df9f067a9cabe86", + "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols": + "385e490aabc71b48e772118e3d02923e", +} +_PUNCTUATIONS = set([ + "!EXCLAMATION-POINT", + "\"CLOSE-QUOTE", + "\"DOUBLE-QUOTE", + "\"END-OF-QUOTE", + "\"END-QUOTE", + "\"IN-QUOTES", + "\"QUOTE", + "\"UNQUOTE", + "#HASH-MARK", + "#POUND-SIGN", + "#SHARP-SIGN", + "%PERCENT", + "&ERSAND", + "'END-INNER-QUOTE", + "'END-QUOTE", + "'INNER-QUOTE", + "'QUOTE", + "'SINGLE-QUOTE", + "(BEGIN-PARENS", + "(IN-PARENTHESES", + "(LEFT-PAREN", + "(OPEN-PARENTHESES", + "(PAREN", + "(PARENS", + "(PARENTHESES", + ")CLOSE-PAREN", + ")CLOSE-PARENTHESES", + ")END-PAREN", + ")END-PARENS", + ")END-PARENTHESES", + ")END-THE-PAREN", + ")PAREN", + ")PARENS", + ")RIGHT-PAREN", + ")UN-PARENTHESES", + "+PLUS", + ",COMMA", + "--DASH", + "-DASH", + "-HYPHEN", + "...ELLIPSIS", + ".DECIMAL", + ".DOT", + ".FULL-STOP", + ".PERIOD", + ".POINT", + "/SLASH", + ":COLON", + ";SEMI-COLON", + ";SEMI-COLON(1)", + "?QUESTION-MARK", + "{BRACE", + "{LEFT-BRACE", + "{OPEN-BRACE", + "}CLOSE-BRACE", + "}RIGHT-BRACE", +]) + + +def _parse_dictionary(lines: Iterable[str], exclude_punctuations: bool) -> List[str]: + _alt_re = 
re.compile(r'\([0-9]+\)') + cmudict: List[Tuple[str, List[str]]] = list() + for line in lines: + if not line or line.startswith(';;;'): # ignore comments + continue + + word, phones = line.strip().split(' ') + if word in _PUNCTUATIONS: + if exclude_punctuations: + continue + # !EXCLAMATION-POINT -> ! + # --DASH -> -- + # ...ELLIPSIS -> ... + if word.startswith("..."): + word = "..." + elif word.startswith("--"): + word = "--" + else: + word = word[0] + + # if a word have multiple pronunciations, there will be (number) appended to it + # for example, DATAPOINTS and DATAPOINTS(1), + # the regular expression `_alt_re` removes the '(1)' and change the word DATAPOINTS(1) to DATAPOINTS + word = re.sub(_alt_re, '', word) + phones = phones.split(" ") + cmudict.append((word, phones)) + + return cmudict + + +class CMUDict(Dataset): + """Create a Dataset for CMU Pronouncing Dictionary (CMUDict). + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + exclude_punctuations (bool, optional): + When enabled, exclude the pronounciation of punctuations, such as + `!EXCLAMATION-POINT` and `#HASH-MARK`. + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). + url (str, optional): + The URL to download the dictionary from. + (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b"``) + url_symbols (str, optional): + The URL to download the list of symbols from. + (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols"``) + """ + + def __init__(self, + root: Union[str, Path], + exclude_punctuations: bool = True, + *, + download: bool = False, + url: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b", + url_symbols: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols", + ) -> None: + + self.exclude_punctuations = exclude_punctuations + + self._root_path = Path(root) + if not os.path.isdir(self._root_path): + raise RuntimeError(f'The root directory does not exist; {root}') + + dict_file = self._root_path / os.path.basename(url) + symbol_file = self._root_path / os.path.basename(url_symbols) + if not os.path.exists(dict_file): + if not download: + raise RuntimeError( + 'The dictionary file is not found in the following location. ' + f'Set `download=True` to download it. {dict_file}') + checksum = _CHECKSUMS.get(url, None) + download_url(url, root, hash_value=checksum, hash_type="md5") + if not os.path.exists(symbol_file): + if not download: + raise RuntimeError( + 'The symbol file is not found in the following location. ' + f'Set `download=True` to download it. {symbol_file}') + checksum = _CHECKSUMS.get(url_symbols, None) + download_url(url_symbols, root, hash_value=checksum, hash_type="md5") + + with open(symbol_file, "r") as text: + self._symbols = [line.strip() for line in text.readlines()] + + with open(dict_file, "r", encoding='latin-1') as text: + self._dictionary = _parse_dictionary( + text.readlines(), exclude_punctuations=self.exclude_punctuations) + + def __getitem__(self, n: int) -> Tuple[str, List[str]]: + """Load the n-th sample from the dataset. + + Args: + n (int): The index of the sample to be loaded. + + Returns: + (str, List[str]): The corresponding word and phonemes ``(word, [phonemes])``. 
+ + """ + return self._dictionary[n] + + def __len__(self) -> int: + return len(self._dictionary) + + @property + def symbols(self) -> List[str]: + """list[str]: A list of phonemes symbols, such as `AA`, `AE`, `AH`. + """ + return self._symbols.copy() diff --git a/torchaudio/datasets/commonvoice.py b/torchaudio/datasets/commonvoice.py new file mode 100644 index 0000000000000000000000000000000000000000..20f9234f89efbb6aeb47e426854871f88b741233 --- /dev/null +++ b/torchaudio/datasets/commonvoice.py @@ -0,0 +1,76 @@ +import csv +import os +from pathlib import Path +from typing import List, Dict, Tuple, Union + +from torch import Tensor +from torch.utils.data import Dataset + +import torchaudio + + +def load_commonvoice_item(line: List[str], + header: List[str], + path: str, + folder_audio: str, + ext_audio: str) -> Tuple[Tensor, int, Dict[str, str]]: + # Each line as the following data: + # client_id, path, sentence, up_votes, down_votes, age, gender, accent + + assert header[1] == "path" + fileid = line[1] + filename = os.path.join(path, folder_audio, fileid) + if not filename.endswith(ext_audio): + filename += ext_audio + waveform, sample_rate = torchaudio.load(filename) + + dic = dict(zip(header, line)) + + return waveform, sample_rate, dic + + +class COMMONVOICE(Dataset): + """Create a Dataset for CommonVoice. + + Args: + root (str or Path): Path to the directory where the dataset is located. + (Where the ``tsv`` file is present.) + tsv (str, optional): + The name of the tsv file used to construct the metadata, such as + ``"train.tsv"``, ``"test.tsv"``, ``"dev.tsv"``, ``"invalidated.tsv"``, + ``"validated.tsv"`` and ``"other.tsv"``. (default: ``"train.tsv"``) + """ + + _ext_txt = ".txt" + _ext_audio = ".mp3" + _folder_audio = "clips" + + def __init__(self, + root: Union[str, Path], + tsv: str = "train.tsv") -> None: + + # Get string representation of 'root' in case Path object is passed + self._path = os.fspath(root) + self._tsv = os.path.join(self._path, tsv) + + with open(self._tsv, "r") as tsv_: + walker = csv.reader(tsv_, delimiter="\t") + self._header = next(walker) + self._walker = list(walker) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, Dict[str, str]]: + """Load the n-th sample from the dataset. + + Args: + n (int): The index of the sample to be loaded + + Returns: + (Tensor, int, Dict[str, str]): ``(waveform, sample_rate, dictionary)``, where dictionary + is built from the TSV file with the following keys: ``client_id``, ``path``, ``sentence``, + ``up_votes``, ``down_votes``, ``age``, ``gender`` and ``accent``. + """ + line = self._walker[n] + return load_commonvoice_item(line, self._header, self._path, self._folder_audio, self._ext_audio) + + def __len__(self) -> int: + return len(self._walker) diff --git a/torchaudio/datasets/gtzan.py b/torchaudio/datasets/gtzan.py new file mode 100644 index 0000000000000000000000000000000000000000..b78104cacbdc842c77e857b6a09920cfd0045526 --- /dev/null +++ b/torchaudio/datasets/gtzan.py @@ -0,0 +1,1113 @@ +import os +from pathlib import Path +from typing import Tuple, Optional, Union + +import torchaudio +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio.datasets.utils import ( + download_url, + extract_archive, +) + +# The following lists prefixed with `filtered_` provide a filtered split +# that: +# +# a. Mitigate a known issue with GTZAN (duplication) +# +# b. Provide a standard split for testing it against other +# methods (e.g. the one in jordipons/sklearn-audio-transfer-learning). 
+# +# Those are used when GTZAN is initialised with the `filtered` keyword. +# The split was taken from (github) jordipons/sklearn-audio-transfer-learning. + +gtzan_genres = [ + "blues", + "classical", + "country", + "disco", + "hiphop", + "jazz", + "metal", + "pop", + "reggae", + "rock", +] + +filtered_test = [ + "blues.00012", + "blues.00013", + "blues.00014", + "blues.00015", + "blues.00016", + "blues.00017", + "blues.00018", + "blues.00019", + "blues.00020", + "blues.00021", + "blues.00022", + "blues.00023", + "blues.00024", + "blues.00025", + "blues.00026", + "blues.00027", + "blues.00028", + "blues.00061", + "blues.00062", + "blues.00063", + "blues.00064", + "blues.00065", + "blues.00066", + "blues.00067", + "blues.00068", + "blues.00069", + "blues.00070", + "blues.00071", + "blues.00072", + "blues.00098", + "blues.00099", + "classical.00011", + "classical.00012", + "classical.00013", + "classical.00014", + "classical.00015", + "classical.00016", + "classical.00017", + "classical.00018", + "classical.00019", + "classical.00020", + "classical.00021", + "classical.00022", + "classical.00023", + "classical.00024", + "classical.00025", + "classical.00026", + "classical.00027", + "classical.00028", + "classical.00029", + "classical.00034", + "classical.00035", + "classical.00036", + "classical.00037", + "classical.00038", + "classical.00039", + "classical.00040", + "classical.00041", + "classical.00049", + "classical.00077", + "classical.00078", + "classical.00079", + "country.00030", + "country.00031", + "country.00032", + "country.00033", + "country.00034", + "country.00035", + "country.00036", + "country.00037", + "country.00038", + "country.00039", + "country.00040", + "country.00043", + "country.00044", + "country.00046", + "country.00047", + "country.00048", + "country.00050", + "country.00051", + "country.00053", + "country.00054", + "country.00055", + "country.00056", + "country.00057", + "country.00058", + "country.00059", + "country.00060", + "country.00061", + "country.00062", + "country.00063", + "country.00064", + "disco.00001", + "disco.00021", + "disco.00058", + "disco.00062", + "disco.00063", + "disco.00064", + "disco.00065", + "disco.00066", + "disco.00069", + "disco.00076", + "disco.00077", + "disco.00078", + "disco.00079", + "disco.00080", + "disco.00081", + "disco.00082", + "disco.00083", + "disco.00084", + "disco.00085", + "disco.00086", + "disco.00087", + "disco.00088", + "disco.00091", + "disco.00092", + "disco.00093", + "disco.00094", + "disco.00096", + "disco.00097", + "disco.00099", + "hiphop.00000", + "hiphop.00026", + "hiphop.00027", + "hiphop.00030", + "hiphop.00040", + "hiphop.00043", + "hiphop.00044", + "hiphop.00045", + "hiphop.00051", + "hiphop.00052", + "hiphop.00053", + "hiphop.00054", + "hiphop.00062", + "hiphop.00063", + "hiphop.00064", + "hiphop.00065", + "hiphop.00066", + "hiphop.00067", + "hiphop.00068", + "hiphop.00069", + "hiphop.00070", + "hiphop.00071", + "hiphop.00072", + "hiphop.00073", + "hiphop.00074", + "hiphop.00075", + "hiphop.00099", + "jazz.00073", + "jazz.00074", + "jazz.00075", + "jazz.00076", + "jazz.00077", + "jazz.00078", + "jazz.00079", + "jazz.00080", + "jazz.00081", + "jazz.00082", + "jazz.00083", + "jazz.00084", + "jazz.00085", + "jazz.00086", + "jazz.00087", + "jazz.00088", + "jazz.00089", + "jazz.00090", + "jazz.00091", + "jazz.00092", + "jazz.00093", + "jazz.00094", + "jazz.00095", + "jazz.00096", + "jazz.00097", + "jazz.00098", + "jazz.00099", + "metal.00012", + "metal.00013", + "metal.00014", + "metal.00015", + 
"metal.00022", + "metal.00023", + "metal.00025", + "metal.00026", + "metal.00027", + "metal.00028", + "metal.00029", + "metal.00030", + "metal.00031", + "metal.00032", + "metal.00033", + "metal.00038", + "metal.00039", + "metal.00067", + "metal.00070", + "metal.00073", + "metal.00074", + "metal.00075", + "metal.00078", + "metal.00083", + "metal.00085", + "metal.00087", + "metal.00088", + "pop.00000", + "pop.00001", + "pop.00013", + "pop.00014", + "pop.00043", + "pop.00063", + "pop.00064", + "pop.00065", + "pop.00066", + "pop.00069", + "pop.00070", + "pop.00071", + "pop.00072", + "pop.00073", + "pop.00074", + "pop.00075", + "pop.00076", + "pop.00077", + "pop.00078", + "pop.00079", + "pop.00082", + "pop.00088", + "pop.00089", + "pop.00090", + "pop.00091", + "pop.00092", + "pop.00093", + "pop.00094", + "pop.00095", + "pop.00096", + "reggae.00034", + "reggae.00035", + "reggae.00036", + "reggae.00037", + "reggae.00038", + "reggae.00039", + "reggae.00040", + "reggae.00046", + "reggae.00047", + "reggae.00048", + "reggae.00052", + "reggae.00053", + "reggae.00064", + "reggae.00065", + "reggae.00066", + "reggae.00067", + "reggae.00068", + "reggae.00071", + "reggae.00079", + "reggae.00082", + "reggae.00083", + "reggae.00084", + "reggae.00087", + "reggae.00088", + "reggae.00089", + "reggae.00090", + "rock.00010", + "rock.00011", + "rock.00012", + "rock.00013", + "rock.00014", + "rock.00015", + "rock.00027", + "rock.00028", + "rock.00029", + "rock.00030", + "rock.00031", + "rock.00032", + "rock.00033", + "rock.00034", + "rock.00035", + "rock.00036", + "rock.00037", + "rock.00039", + "rock.00040", + "rock.00041", + "rock.00042", + "rock.00043", + "rock.00044", + "rock.00045", + "rock.00046", + "rock.00047", + "rock.00048", + "rock.00086", + "rock.00087", + "rock.00088", + "rock.00089", + "rock.00090", +] + +filtered_train = [ + "blues.00029", + "blues.00030", + "blues.00031", + "blues.00032", + "blues.00033", + "blues.00034", + "blues.00035", + "blues.00036", + "blues.00037", + "blues.00038", + "blues.00039", + "blues.00040", + "blues.00041", + "blues.00042", + "blues.00043", + "blues.00044", + "blues.00045", + "blues.00046", + "blues.00047", + "blues.00048", + "blues.00049", + "blues.00073", + "blues.00074", + "blues.00075", + "blues.00076", + "blues.00077", + "blues.00078", + "blues.00079", + "blues.00080", + "blues.00081", + "blues.00082", + "blues.00083", + "blues.00084", + "blues.00085", + "blues.00086", + "blues.00087", + "blues.00088", + "blues.00089", + "blues.00090", + "blues.00091", + "blues.00092", + "blues.00093", + "blues.00094", + "blues.00095", + "blues.00096", + "blues.00097", + "classical.00030", + "classical.00031", + "classical.00032", + "classical.00033", + "classical.00043", + "classical.00044", + "classical.00045", + "classical.00046", + "classical.00047", + "classical.00048", + "classical.00050", + "classical.00051", + "classical.00052", + "classical.00053", + "classical.00054", + "classical.00055", + "classical.00056", + "classical.00057", + "classical.00058", + "classical.00059", + "classical.00060", + "classical.00061", + "classical.00062", + "classical.00063", + "classical.00064", + "classical.00065", + "classical.00066", + "classical.00067", + "classical.00080", + "classical.00081", + "classical.00082", + "classical.00083", + "classical.00084", + "classical.00085", + "classical.00086", + "classical.00087", + "classical.00088", + "classical.00089", + "classical.00090", + "classical.00091", + "classical.00092", + "classical.00093", + "classical.00094", + "classical.00095", + 
"classical.00096", + "classical.00097", + "classical.00098", + "classical.00099", + "country.00019", + "country.00020", + "country.00021", + "country.00022", + "country.00023", + "country.00024", + "country.00025", + "country.00026", + "country.00028", + "country.00029", + "country.00065", + "country.00066", + "country.00067", + "country.00068", + "country.00069", + "country.00070", + "country.00071", + "country.00072", + "country.00073", + "country.00074", + "country.00075", + "country.00076", + "country.00077", + "country.00078", + "country.00079", + "country.00080", + "country.00081", + "country.00082", + "country.00083", + "country.00084", + "country.00085", + "country.00086", + "country.00087", + "country.00088", + "country.00089", + "country.00090", + "country.00091", + "country.00092", + "country.00093", + "country.00094", + "country.00095", + "country.00096", + "country.00097", + "country.00098", + "country.00099", + "disco.00005", + "disco.00015", + "disco.00016", + "disco.00017", + "disco.00018", + "disco.00019", + "disco.00020", + "disco.00022", + "disco.00023", + "disco.00024", + "disco.00025", + "disco.00026", + "disco.00027", + "disco.00028", + "disco.00029", + "disco.00030", + "disco.00031", + "disco.00032", + "disco.00033", + "disco.00034", + "disco.00035", + "disco.00036", + "disco.00037", + "disco.00039", + "disco.00040", + "disco.00041", + "disco.00042", + "disco.00043", + "disco.00044", + "disco.00045", + "disco.00047", + "disco.00049", + "disco.00053", + "disco.00054", + "disco.00056", + "disco.00057", + "disco.00059", + "disco.00061", + "disco.00070", + "disco.00073", + "disco.00074", + "disco.00089", + "hiphop.00002", + "hiphop.00003", + "hiphop.00004", + "hiphop.00005", + "hiphop.00006", + "hiphop.00007", + "hiphop.00008", + "hiphop.00009", + "hiphop.00010", + "hiphop.00011", + "hiphop.00012", + "hiphop.00013", + "hiphop.00014", + "hiphop.00015", + "hiphop.00016", + "hiphop.00017", + "hiphop.00018", + "hiphop.00019", + "hiphop.00020", + "hiphop.00021", + "hiphop.00022", + "hiphop.00023", + "hiphop.00024", + "hiphop.00025", + "hiphop.00028", + "hiphop.00029", + "hiphop.00031", + "hiphop.00032", + "hiphop.00033", + "hiphop.00034", + "hiphop.00035", + "hiphop.00036", + "hiphop.00037", + "hiphop.00038", + "hiphop.00041", + "hiphop.00042", + "hiphop.00055", + "hiphop.00056", + "hiphop.00057", + "hiphop.00058", + "hiphop.00059", + "hiphop.00060", + "hiphop.00061", + "hiphop.00077", + "hiphop.00078", + "hiphop.00079", + "hiphop.00080", + "jazz.00000", + "jazz.00001", + "jazz.00011", + "jazz.00012", + "jazz.00013", + "jazz.00014", + "jazz.00015", + "jazz.00016", + "jazz.00017", + "jazz.00018", + "jazz.00019", + "jazz.00020", + "jazz.00021", + "jazz.00022", + "jazz.00023", + "jazz.00024", + "jazz.00041", + "jazz.00047", + "jazz.00048", + "jazz.00049", + "jazz.00050", + "jazz.00051", + "jazz.00052", + "jazz.00053", + "jazz.00054", + "jazz.00055", + "jazz.00056", + "jazz.00057", + "jazz.00058", + "jazz.00059", + "jazz.00060", + "jazz.00061", + "jazz.00062", + "jazz.00063", + "jazz.00064", + "jazz.00065", + "jazz.00066", + "jazz.00067", + "jazz.00068", + "jazz.00069", + "jazz.00070", + "jazz.00071", + "jazz.00072", + "metal.00002", + "metal.00003", + "metal.00005", + "metal.00021", + "metal.00024", + "metal.00035", + "metal.00046", + "metal.00047", + "metal.00048", + "metal.00049", + "metal.00050", + "metal.00051", + "metal.00052", + "metal.00053", + "metal.00054", + "metal.00055", + "metal.00056", + "metal.00057", + "metal.00059", + "metal.00060", + "metal.00061", + 
"metal.00062", + "metal.00063", + "metal.00064", + "metal.00065", + "metal.00066", + "metal.00069", + "metal.00071", + "metal.00072", + "metal.00079", + "metal.00080", + "metal.00084", + "metal.00086", + "metal.00089", + "metal.00090", + "metal.00091", + "metal.00092", + "metal.00093", + "metal.00094", + "metal.00095", + "metal.00096", + "metal.00097", + "metal.00098", + "metal.00099", + "pop.00002", + "pop.00003", + "pop.00004", + "pop.00005", + "pop.00006", + "pop.00007", + "pop.00008", + "pop.00009", + "pop.00011", + "pop.00012", + "pop.00016", + "pop.00017", + "pop.00018", + "pop.00019", + "pop.00020", + "pop.00023", + "pop.00024", + "pop.00025", + "pop.00026", + "pop.00027", + "pop.00028", + "pop.00029", + "pop.00031", + "pop.00032", + "pop.00033", + "pop.00034", + "pop.00035", + "pop.00036", + "pop.00038", + "pop.00039", + "pop.00040", + "pop.00041", + "pop.00042", + "pop.00044", + "pop.00046", + "pop.00049", + "pop.00050", + "pop.00080", + "pop.00097", + "pop.00098", + "pop.00099", + "reggae.00000", + "reggae.00001", + "reggae.00002", + "reggae.00004", + "reggae.00006", + "reggae.00009", + "reggae.00011", + "reggae.00012", + "reggae.00014", + "reggae.00015", + "reggae.00016", + "reggae.00017", + "reggae.00018", + "reggae.00019", + "reggae.00020", + "reggae.00021", + "reggae.00022", + "reggae.00023", + "reggae.00024", + "reggae.00025", + "reggae.00026", + "reggae.00027", + "reggae.00028", + "reggae.00029", + "reggae.00030", + "reggae.00031", + "reggae.00032", + "reggae.00042", + "reggae.00043", + "reggae.00044", + "reggae.00045", + "reggae.00049", + "reggae.00050", + "reggae.00051", + "reggae.00054", + "reggae.00055", + "reggae.00056", + "reggae.00057", + "reggae.00058", + "reggae.00059", + "reggae.00060", + "reggae.00063", + "reggae.00069", + "rock.00000", + "rock.00001", + "rock.00002", + "rock.00003", + "rock.00004", + "rock.00005", + "rock.00006", + "rock.00007", + "rock.00008", + "rock.00009", + "rock.00016", + "rock.00017", + "rock.00018", + "rock.00019", + "rock.00020", + "rock.00021", + "rock.00022", + "rock.00023", + "rock.00024", + "rock.00025", + "rock.00026", + "rock.00057", + "rock.00058", + "rock.00059", + "rock.00060", + "rock.00061", + "rock.00062", + "rock.00063", + "rock.00064", + "rock.00065", + "rock.00066", + "rock.00067", + "rock.00068", + "rock.00069", + "rock.00070", + "rock.00091", + "rock.00092", + "rock.00093", + "rock.00094", + "rock.00095", + "rock.00096", + "rock.00097", + "rock.00098", + "rock.00099", +] + +filtered_valid = [ + "blues.00000", + "blues.00001", + "blues.00002", + "blues.00003", + "blues.00004", + "blues.00005", + "blues.00006", + "blues.00007", + "blues.00008", + "blues.00009", + "blues.00010", + "blues.00011", + "blues.00050", + "blues.00051", + "blues.00052", + "blues.00053", + "blues.00054", + "blues.00055", + "blues.00056", + "blues.00057", + "blues.00058", + "blues.00059", + "blues.00060", + "classical.00000", + "classical.00001", + "classical.00002", + "classical.00003", + "classical.00004", + "classical.00005", + "classical.00006", + "classical.00007", + "classical.00008", + "classical.00009", + "classical.00010", + "classical.00068", + "classical.00069", + "classical.00070", + "classical.00071", + "classical.00072", + "classical.00073", + "classical.00074", + "classical.00075", + "classical.00076", + "country.00000", + "country.00001", + "country.00002", + "country.00003", + "country.00004", + "country.00005", + "country.00006", + "country.00007", + "country.00009", + "country.00010", + "country.00011", + "country.00012", + 
"country.00013", + "country.00014", + "country.00015", + "country.00016", + "country.00017", + "country.00018", + "country.00027", + "country.00041", + "country.00042", + "country.00045", + "country.00049", + "disco.00000", + "disco.00002", + "disco.00003", + "disco.00004", + "disco.00006", + "disco.00007", + "disco.00008", + "disco.00009", + "disco.00010", + "disco.00011", + "disco.00012", + "disco.00013", + "disco.00014", + "disco.00046", + "disco.00048", + "disco.00052", + "disco.00067", + "disco.00068", + "disco.00072", + "disco.00075", + "disco.00090", + "disco.00095", + "hiphop.00081", + "hiphop.00082", + "hiphop.00083", + "hiphop.00084", + "hiphop.00085", + "hiphop.00086", + "hiphop.00087", + "hiphop.00088", + "hiphop.00089", + "hiphop.00090", + "hiphop.00091", + "hiphop.00092", + "hiphop.00093", + "hiphop.00094", + "hiphop.00095", + "hiphop.00096", + "hiphop.00097", + "hiphop.00098", + "jazz.00002", + "jazz.00003", + "jazz.00004", + "jazz.00005", + "jazz.00006", + "jazz.00007", + "jazz.00008", + "jazz.00009", + "jazz.00010", + "jazz.00025", + "jazz.00026", + "jazz.00027", + "jazz.00028", + "jazz.00029", + "jazz.00030", + "jazz.00031", + "jazz.00032", + "metal.00000", + "metal.00001", + "metal.00006", + "metal.00007", + "metal.00008", + "metal.00009", + "metal.00010", + "metal.00011", + "metal.00016", + "metal.00017", + "metal.00018", + "metal.00019", + "metal.00020", + "metal.00036", + "metal.00037", + "metal.00068", + "metal.00076", + "metal.00077", + "metal.00081", + "metal.00082", + "pop.00010", + "pop.00053", + "pop.00055", + "pop.00058", + "pop.00059", + "pop.00060", + "pop.00061", + "pop.00062", + "pop.00081", + "pop.00083", + "pop.00084", + "pop.00085", + "pop.00086", + "reggae.00061", + "reggae.00062", + "reggae.00070", + "reggae.00072", + "reggae.00074", + "reggae.00076", + "reggae.00077", + "reggae.00078", + "reggae.00085", + "reggae.00092", + "reggae.00093", + "reggae.00094", + "reggae.00095", + "reggae.00096", + "reggae.00097", + "reggae.00098", + "reggae.00099", + "rock.00038", + "rock.00049", + "rock.00050", + "rock.00051", + "rock.00052", + "rock.00053", + "rock.00054", + "rock.00055", + "rock.00056", + "rock.00071", + "rock.00072", + "rock.00073", + "rock.00074", + "rock.00075", + "rock.00076", + "rock.00077", + "rock.00078", + "rock.00079", + "rock.00080", + "rock.00081", + "rock.00082", + "rock.00083", + "rock.00084", + "rock.00085", +] + + +URL = "http://opihi.cs.uvic.ca/sound/genres.tar.gz" +FOLDER_IN_ARCHIVE = "genres" +_CHECKSUMS = { + "http://opihi.cs.uvic.ca/sound/genres.tar.gz": "5b3d6dddb579ab49814ab86dba69e7c7" +} + + +def load_gtzan_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, str]: + """ + Loads a file from the dataset and returns the raw waveform + as a Torch Tensor, its sample rate as an integer, and its + genre as a string. + """ + # Filenames are of the form label.id, e.g. blues.00078 + label, _ = fileid.split(".") + + # Read wav + file_audio = os.path.join(path, label, fileid + ext_audio) + waveform, sample_rate = torchaudio.load(file_audio) + + return waveform, sample_rate, label + + +class GTZAN(Dataset): + """Create a Dataset for GTZAN. + + Note: + Please see http://marsyas.info/downloads/datasets.html if you are planning to use + this dataset to publish results. + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + url (str, optional): The URL to download the dataset from. 
+ (default: ``"http://opihi.cs.uvic.ca/sound/genres.tar.gz"``) + folder_in_archive (str, optional): The top-level directory of the dataset. + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). + subset (str or None, optional): Which subset of the dataset to use. + One of ``"training"``, ``"validation"``, ``"testing"`` or ``None``. + If ``None``, the entire dataset is used. (default: ``None``). + """ + + _ext_audio = ".wav" + + def __init__( + self, + root: Union[str, Path], + url: str = URL, + folder_in_archive: str = FOLDER_IN_ARCHIVE, + download: bool = False, + subset: Optional[str] = None, + ) -> None: + + # super(GTZAN, self).__init__() + + # Get string representation of 'root' in case Path object is passed + root = os.fspath(root) + + self.root = root + self.url = url + self.folder_in_archive = folder_in_archive + self.download = download + self.subset = subset + + assert subset is None or subset in ["training", "validation", "testing"], ( + "When `subset` not None, it must take a value from " + + "{'training', 'validation', 'testing'}." + ) + + archive = os.path.basename(url) + archive = os.path.join(root, archive) + self._path = os.path.join(root, folder_in_archive) + + if download: + if not os.path.isdir(self._path): + if not os.path.isfile(archive): + checksum = _CHECKSUMS.get(url, None) + download_url(url, root, hash_value=checksum, hash_type="md5") + extract_archive(archive) + + if not os.path.isdir(self._path): + raise RuntimeError( + "Dataset not found. Please use `download=True` to download it." + ) + + if self.subset is None: + # Check every subdirectory under dataset root + # which has the same name as the genres in + # GTZAN (e.g. `root_dir'/blues/, `root_dir'/rock, etc.) + # This lets users remove or move around song files, + # useful when e.g. they want to use only some of the files + # in a genre or want to label other files with a different + # genre. + self._walker = [] + + root = os.path.expanduser(self._path) + + for directory in gtzan_genres: + fulldir = os.path.join(root, directory) + + if not os.path.exists(fulldir): + continue + + songs_in_genre = os.listdir(fulldir) + songs_in_genre.sort() + for fname in songs_in_genre: + name, ext = os.path.splitext(fname) + if ext.lower() == ".wav" and "." in name: + # Check whether the file is of the form + # `gtzan_genre`.`5 digit number`.wav + genre, num = name.split(".") + if genre in gtzan_genres and len(num) == 5 and num.isdigit(): + self._walker.append(name) + else: + if self.subset == "training": + self._walker = filtered_train + elif self.subset == "validation": + self._walker = filtered_valid + elif self.subset == "testing": + self._walker = filtered_test + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str]: + """Load the n-th sample from the dataset. 
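+
+        Example (a minimal sketch; assumes the archive can be downloaded into an existing ``./data`` directory):
+            >>> dataset = GTZAN("./data", download=True, subset="training")
+            >>> waveform, sample_rate, label = dataset[0]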
+ + Args: + n (int): The index of the sample to be loaded + + Returns: + (Tensor, int, str): ``(waveform, sample_rate, label)`` + """ + fileid = self._walker[n] + item = load_gtzan_item(fileid, self._path, self._ext_audio) + waveform, sample_rate, label = item + return waveform, sample_rate, label + + def __len__(self) -> int: + return len(self._walker) diff --git a/torchaudio/datasets/librimix.py b/torchaudio/datasets/librimix.py new file mode 100644 index 0000000000000000000000000000000000000000..ab3b3877dccaeba5b9092fca2b58ad919fd6247f --- /dev/null +++ b/torchaudio/datasets/librimix.py @@ -0,0 +1,89 @@ +from pathlib import Path +from typing import Union, Tuple, List + +import torch +from torch.utils.data import Dataset + +import torchaudio + +SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]] + + +class LibriMix(Dataset): + r"""Create the LibriMix dataset. + + Args: + root (str or Path): The path to the directory where the directory ``Libri2Mix`` or + ``Libri3Mix`` is stored. + subset (str, optional): The subset to use. Options: [``train-360`, ``train-100``, + ``dev``, and ``test``] (Default: ``train-360``). + num_speakers (int, optional): The number of speakers, which determines the directories + to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect + N source audios. (Default: 2) + sample_rate (int, optional): sample rate of audio files. The ``sample_rate`` determines + which subdirectory the audio are fetched. If any of the audio has a different sample + rate, raises ``ValueError``. Options: [8000, 16000] (Default: 8000) + task (str, optional): the task of LibriMix. + Options: [``enh_single``, ``enh_both``, ``sep_clean``, ``sep_noisy``] + (Default: ``sep_clean``) + + Note: + The LibriMix dataset needs to be manually generated. Please check https://github.com/JorisCos/LibriMix + """ + def __init__( + self, + root: Union[str, Path], + subset: str = "train-360", + num_speakers: int = 2, + sample_rate: int = 8000, + task: str = "sep_clean", + ): + self.root = Path(root) / f"Libri{num_speakers}Mix" + if sample_rate == 8000: + self.root = self.root / "wav8k/min" / subset + elif sample_rate == 16000: + self.root = self.root / "wav16k/min" / subset + else: + raise ValueError( + f"Unsupported sample rate. Found {sample_rate}." + ) + self.sample_rate = sample_rate + self.task = task + self.mix_dir = (self.root / f"mix_{task.split('_')[1]}").resolve() + self.src_dirs = [(self.root / f"s{i+1}").resolve() for i in range(num_speakers)] + + self.files = [p.name for p in self.mix_dir.glob("*wav")] + self.files.sort() + + def _load_audio(self, path) -> torch.Tensor: + waveform, sample_rate = torchaudio.load(path) + if sample_rate != self.sample_rate: + raise ValueError( + f"The dataset contains audio file of sample rate {sample_rate}, " + f"but the requested sample rate is {self.sample_rate}." + ) + return waveform + + def _load_sample(self, filename) -> SampleType: + mixed = self._load_audio(str(self.mix_dir / filename)) + srcs = [] + for i, dir_ in enumerate(self.src_dirs): + src = self._load_audio(str(dir_ / filename)) + if mixed.shape != src.shape: + raise ValueError( + f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}" + ) + srcs.append(src) + return self.sample_rate, mixed, srcs + + def __len__(self) -> int: + return len(self.files) + + def __getitem__(self, key: int) -> SampleType: + """Load the n-th sample from the dataset. 
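+
+        Example (a minimal sketch; assumes ``root`` contains a manually generated ``Libri2Mix`` directory):
+            >>> dataset = LibriMix("./librimix_root", subset="dev", num_speakers=2, sample_rate=8000)
+            >>> sample_rate, mixture, sources = dataset[0]
+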
+ Args: + key (int): The index of the sample to be loaded + Returns: + (int, Tensor, List[Tensor]): ``(sample_rate, mix_waveform, list_of_source_waveforms)`` + """ + return self._load_sample(self.files[key]) diff --git a/torchaudio/datasets/librispeech.py b/torchaudio/datasets/librispeech.py new file mode 100644 index 0000000000000000000000000000000000000000..ad8a26493cdff823577f4f1fe8e6ba590303bd25 --- /dev/null +++ b/torchaudio/datasets/librispeech.py @@ -0,0 +1,143 @@ +import os +from typing import Tuple, Union +from pathlib import Path + +import torchaudio +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio.datasets.utils import ( + download_url, + extract_archive, +) + +URL = "train-clean-100" +FOLDER_IN_ARCHIVE = "LibriSpeech" +_CHECKSUMS = { + "http://www.openslr.org/resources/12/dev-clean.tar.gz": + "76f87d090650617fca0cac8f88b9416e0ebf80350acb97b343a85fa903728ab3", + "http://www.openslr.org/resources/12/dev-other.tar.gz": + "12661c48e8c3fe1de2c1caa4c3e135193bfb1811584f11f569dd12645aa84365", + "http://www.openslr.org/resources/12/test-clean.tar.gz": + "39fde525e59672dc6d1551919b1478f724438a95aa55f874b576be21967e6c23", + "http://www.openslr.org/resources/12/test-other.tar.gz": + "d09c181bba5cf717b3dee7d4d592af11a3ee3a09e08ae025c5506f6ebe961c29", + "http://www.openslr.org/resources/12/train-clean-100.tar.gz": + "d4ddd1d5a6ab303066f14971d768ee43278a5f2a0aa43dc716b0e64ecbbbf6e2", + "http://www.openslr.org/resources/12/train-clean-360.tar.gz": + "146a56496217e96c14334a160df97fffedd6e0a04e66b9c5af0d40be3c792ecf", + "http://www.openslr.org/resources/12/train-other-500.tar.gz": + "ddb22f27f96ec163645d53215559df6aa36515f26e01dd70798188350adcb6d2" +} + + +def load_librispeech_item(fileid: str, + path: str, + ext_audio: str, + ext_txt: str) -> Tuple[Tensor, int, str, int, int, int]: + speaker_id, chapter_id, utterance_id = fileid.split("-") + + file_text = speaker_id + "-" + chapter_id + ext_txt + file_text = os.path.join(path, speaker_id, chapter_id, file_text) + + fileid_audio = speaker_id + "-" + chapter_id + "-" + utterance_id + file_audio = fileid_audio + ext_audio + file_audio = os.path.join(path, speaker_id, chapter_id, file_audio) + + # Load audio + waveform, sample_rate = torchaudio.load(file_audio) + + # Load text + with open(file_text) as ft: + for line in ft: + fileid_text, transcript = line.strip().split(" ", 1) + if fileid_audio == fileid_text: + break + else: + # Translation not found + raise FileNotFoundError("Translation not found for " + fileid_audio) + + return ( + waveform, + sample_rate, + transcript, + int(speaker_id), + int(chapter_id), + int(utterance_id), + ) + + +class LIBRISPEECH(Dataset): + """Create a Dataset for LibriSpeech. + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + url (str, optional): The URL to download the dataset from, + or the type of the dataset to dowload. + Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``, + ``"test-other"``, ``"train-clean-100"``, ``"train-clean-360"`` and + ``"train-other-500"``. (default: ``"train-clean-100"``) + folder_in_archive (str, optional): + The top-level directory of the dataset. (default: ``"LibriSpeech"``) + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). 
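+
+    Example (illustrative; fetches ``test-clean`` into an existing ``./data`` directory on first use):
+        >>> dataset = LIBRISPEECH("./data", url="test-clean", download=True)
+        >>> waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id = dataset[0]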
+ """ + + _ext_txt = ".trans.txt" + _ext_audio = ".flac" + + def __init__(self, + root: Union[str, Path], + url: str = URL, + folder_in_archive: str = FOLDER_IN_ARCHIVE, + download: bool = False) -> None: + + if url in [ + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", + ]: + + ext_archive = ".tar.gz" + base_url = "http://www.openslr.org/resources/12/" + + url = os.path.join(base_url, url + ext_archive) + + # Get string representation of 'root' in case Path object is passed + root = os.fspath(root) + + basename = os.path.basename(url) + archive = os.path.join(root, basename) + + basename = basename.split(".")[0] + folder_in_archive = os.path.join(folder_in_archive, basename) + + self._path = os.path.join(root, folder_in_archive) + + if download: + if not os.path.isdir(self._path): + if not os.path.isfile(archive): + checksum = _CHECKSUMS.get(url, None) + download_url(url, root, hash_value=checksum) + extract_archive(archive) + + self._walker = sorted(str(p.stem) for p in Path(self._path).glob('*/*/*' + self._ext_audio)) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]: + """Load the n-th sample from the dataset. + + Args: + n (int): The index of the sample to be loaded + + Returns: + (Tensor, int, str, int, int, int): + ``(waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)`` + """ + fileid = self._walker[n] + return load_librispeech_item(fileid, self._path, self._ext_audio, self._ext_txt) + + def __len__(self) -> int: + return len(self._walker) diff --git a/torchaudio/datasets/libritts.py b/torchaudio/datasets/libritts.py new file mode 100644 index 0000000000000000000000000000000000000000..2c978c426ee15546c012b1ad09d19b3e93f47190 --- /dev/null +++ b/torchaudio/datasets/libritts.py @@ -0,0 +1,150 @@ +import os +from typing import Tuple, Union +from pathlib import Path + +import torchaudio +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio.datasets.utils import ( + download_url, + extract_archive, +) + +URL = "train-clean-100" +FOLDER_IN_ARCHIVE = "LibriTTS" +_CHECKSUMS = { + "http://www.openslr.org/60/dev-clean.tar.gz": "0c3076c1e5245bb3f0af7d82087ee207", + "http://www.openslr.org/60/dev-other.tar.gz": "815555d8d75995782ac3ccd7f047213d", + "http://www.openslr.org/60/test-clean.tar.gz": "7bed3bdb047c4c197f1ad3bc412db59f", + "http://www.openslr.org/60/test-other.tar.gz": "ae3258249472a13b5abef2a816f733e4", + "http://www.openslr.org/60/train-clean-100.tar.gz": "4a8c202b78fe1bc0c47916a98f3a2ea8", + "http://www.openslr.org/60/train-clean-360.tar.gz": "a84ef10ddade5fd25df69596a2767b2d", + "http://www.openslr.org/60/train-other-500.tar.gz": "7b181dd5ace343a5f38427999684aa6f", +} + + +def load_libritts_item( + fileid: str, + path: str, + ext_audio: str, + ext_original_txt: str, + ext_normalized_txt: str, +) -> Tuple[Tensor, int, str, str, int, int, str]: + speaker_id, chapter_id, segment_id, utterance_id = fileid.split("_") + utterance_id = fileid + + normalized_text = utterance_id + ext_normalized_txt + normalized_text = os.path.join(path, speaker_id, chapter_id, normalized_text) + + original_text = utterance_id + ext_original_txt + original_text = os.path.join(path, speaker_id, chapter_id, original_text) + + file_audio = utterance_id + ext_audio + file_audio = os.path.join(path, speaker_id, chapter_id, file_audio) + + # Load audio + waveform, sample_rate = torchaudio.load(file_audio) + + # Load original text + with open(original_text) as ft: + 
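+        # readline() returns the first line, which is expected to hold the raw transcript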
original_text = ft.readline() + + # Load normalized text + with open(normalized_text, "r") as ft: + normalized_text = ft.readline() + + return ( + waveform, + sample_rate, + original_text, + normalized_text, + int(speaker_id), + int(chapter_id), + utterance_id, + ) + + +class LIBRITTS(Dataset): + """Create a Dataset for LibriTTS. + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + url (str, optional): The URL to download the dataset from, + or the type of the dataset to dowload. + Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``, + ``"test-other"``, ``"train-clean-100"``, ``"train-clean-360"`` and + ``"train-other-500"``. (default: ``"train-clean-100"``) + folder_in_archive (str, optional): + The top-level directory of the dataset. (default: ``"LibriTTS"``) + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). + """ + + _ext_original_txt = ".original.txt" + _ext_normalized_txt = ".normalized.txt" + _ext_audio = ".wav" + + def __init__( + self, + root: Union[str, Path], + url: str = URL, + folder_in_archive: str = FOLDER_IN_ARCHIVE, + download: bool = False, + ) -> None: + + if url in [ + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", + ]: + + ext_archive = ".tar.gz" + base_url = "http://www.openslr.org/resources/60/" + + url = os.path.join(base_url, url + ext_archive) + + # Get string representation of 'root' in case Path object is passed + root = os.fspath(root) + + basename = os.path.basename(url) + archive = os.path.join(root, basename) + + basename = basename.split(".")[0] + folder_in_archive = os.path.join(folder_in_archive, basename) + + self._path = os.path.join(root, folder_in_archive) + + if download: + if not os.path.isdir(self._path): + if not os.path.isfile(archive): + checksum = _CHECKSUMS.get(url, None) + download_url(url, root, hash_value=checksum) + extract_archive(archive) + + self._walker = sorted(str(p.stem) for p in Path(self._path).glob('*/*/*' + self._ext_audio)) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int, int, str]: + """Load the n-th sample from the dataset. + + Args: + n (int): The index of the sample to be loaded + + Returns: + (Tensor, int, str, str, str, int, int, str): + ``(waveform, sample_rate, original_text, normalized_text, speaker_id, chapter_id, utterance_id)`` + """ + fileid = self._walker[n] + return load_libritts_item( + fileid, + self._path, + self._ext_audio, + self._ext_original_txt, + self._ext_normalized_txt, + ) + + def __len__(self) -> int: + return len(self._walker) diff --git a/torchaudio/datasets/ljspeech.py b/torchaudio/datasets/ljspeech.py new file mode 100644 index 0000000000000000000000000000000000000000..a0abcbb9ba674ff17cb9b90d007e45731c338d82 --- /dev/null +++ b/torchaudio/datasets/ljspeech.py @@ -0,0 +1,89 @@ +import os +import csv +from typing import Tuple, Union +from pathlib import Path + +import torchaudio +from torchaudio.datasets.utils import download_url, extract_archive +from torch import Tensor +from torch.utils.data import Dataset + +_RELEASE_CONFIGS = { + "release1": { + "folder_in_archive": "wavs", + "url": "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2", + "checksum": "be1a30453f28eb8dd26af4101ae40cbf2c50413b1bb21936cbcdc6fae3de8aa5", + } +} + + +class LJSPEECH(Dataset): + """Create a Dataset for LJSpeech-1.1. 
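+
+    Example (a sketch; fetches the LJSpeech archive into an existing ``./data`` directory on first use):
+        >>> dataset = LJSPEECH("./data", download=True)
+        >>> waveform, sample_rate, transcript, normalized_transcript = dataset[0]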
+ + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + url (str, optional): The URL to download the dataset from. + (default: ``"https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"``) + folder_in_archive (str, optional): + The top-level directory of the dataset. (default: ``"wavs"``) + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). + """ + + def __init__(self, + root: Union[str, Path], + url: str = _RELEASE_CONFIGS["release1"]["url"], + folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"], + download: bool = False) -> None: + + self._parse_filesystem(root, url, folder_in_archive, download) + + def _parse_filesystem(self, root: str, url: str, folder_in_archive: str, download: bool) -> None: + root = Path(root) + + basename = os.path.basename(url) + archive = root / basename + + basename = Path(basename.split(".tar.bz2")[0]) + folder_in_archive = basename / folder_in_archive + + self._path = root / folder_in_archive + self._metadata_path = root / basename / 'metadata.csv' + + if download: + if not os.path.isdir(self._path): + if not os.path.isfile(archive): + checksum = _RELEASE_CONFIGS["release1"]["checksum"] + download_url(url, root, hash_value=checksum) + extract_archive(archive) + + with open(self._metadata_path, "r", newline='') as metadata: + flist = csv.reader(metadata, delimiter="|", quoting=csv.QUOTE_NONE) + self._flist = list(flist) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str]: + """Load the n-th sample from the dataset. + + Args: + n (int): The index of the sample to be loaded + + Returns: + (Tensor, int, str, str): + ``(waveform, sample_rate, transcript, normalized_transcript)`` + """ + line = self._flist[n] + fileid, transcript, normalized_transcript = line + fileid_audio = self._path / (fileid + ".wav") + + # Load audio + waveform, sample_rate = torchaudio.load(fileid_audio) + + return ( + waveform, + sample_rate, + transcript, + normalized_transcript, + ) + + def __len__(self) -> int: + return len(self._flist) diff --git a/torchaudio/datasets/speechcommands.py b/torchaudio/datasets/speechcommands.py new file mode 100644 index 0000000000000000000000000000000000000000..d92d6d44dff1c58433d7c9783a8228b6b15fb928 --- /dev/null +++ b/torchaudio/datasets/speechcommands.py @@ -0,0 +1,148 @@ +import os +from typing import Tuple, Optional, Union +from pathlib import Path + +import torchaudio +from torch.utils.data import Dataset +from torch import Tensor +from torchaudio.datasets.utils import ( + download_url, + extract_archive, +) + +FOLDER_IN_ARCHIVE = "SpeechCommands" +URL = "speech_commands_v0.02" +HASH_DIVIDER = "_nohash_" +EXCEPT_FOLDER = "_background_noise_" +_CHECKSUMS = { + "https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz": + "3cd23799cb2bbdec517f1cc028f8d43c", + "https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz": + "6b74f3901214cb2c2934e98196829835", +} + + +def _load_list(root, *filenames): + output = [] + for filename in filenames: + filepath = os.path.join(root, filename) + with open(filepath) as fileobj: + output += [os.path.normpath(os.path.join(root, line.strip())) for line in fileobj] + return output + + +def load_speechcommands_item(filepath: str, path: str) -> Tuple[Tensor, int, str, str, int]: + relpath = os.path.relpath(filepath, path) + label, filename = os.path.split(relpath) + # Besides the officially supported split method 
for datasets defined by "validation_list.txt" + # and "testing_list.txt" over "speech_commands_v0.0x.tar.gz" archives, an alternative split + # method referred to in paragraph 2-3 of Section 7.1, references 13 and 14 of the original + # paper, and the checksums file from the tensorflow_datasets package [1] is also supported. + # Some filenames in those "speech_commands_test_set_v0.0x.tar.gz" archives have the form + # "xxx.wav.wav", so file extensions twice needs to be stripped twice. + # [1] https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/url_checksums/speech_commands.txt + speaker, _ = os.path.splitext(filename) + speaker, _ = os.path.splitext(speaker) + + speaker_id, utterance_number = speaker.split(HASH_DIVIDER) + utterance_number = int(utterance_number) + + # Load audio + waveform, sample_rate = torchaudio.load(filepath) + return waveform, sample_rate, label, speaker_id, utterance_number + + +class SPEECHCOMMANDS(Dataset): + """Create a Dataset for Speech Commands. + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + url (str, optional): The URL to download the dataset from, + or the type of the dataset to dowload. + Allowed type values are ``"speech_commands_v0.01"`` and ``"speech_commands_v0.02"`` + (default: ``"speech_commands_v0.02"``) + folder_in_archive (str, optional): + The top-level directory of the dataset. (default: ``"SpeechCommands"``) + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). + subset (str or None, optional): + Select a subset of the dataset [None, "training", "validation", "testing"]. None means + the whole dataset. "validation" and "testing" are defined in "validation_list.txt" and + "testing_list.txt", respectively, and "training" is the rest. Details for the files + "validation_list.txt" and "testing_list.txt" are explained in the README of the dataset + and in the introduction of Section 7 of the original paper and its reference 12. The + original paper can be found `here `_. (Default: ``None``) + """ + + def __init__(self, + root: Union[str, Path], + url: str = URL, + folder_in_archive: str = FOLDER_IN_ARCHIVE, + download: bool = False, + subset: Optional[str] = None, + ) -> None: + + assert subset is None or subset in ["training", "validation", "testing"], ( + "When `subset` not None, it must take a value from " + + "{'training', 'validation', 'testing'}." 
+ ) + + if url in [ + "speech_commands_v0.01", + "speech_commands_v0.02", + ]: + base_url = "https://storage.googleapis.com/download.tensorflow.org/data/" + ext_archive = ".tar.gz" + + url = os.path.join(base_url, url + ext_archive) + + # Get string representation of 'root' in case Path object is passed + root = os.fspath(root) + + basename = os.path.basename(url) + archive = os.path.join(root, basename) + + basename = basename.rsplit(".", 2)[0] + folder_in_archive = os.path.join(folder_in_archive, basename) + + self._path = os.path.join(root, folder_in_archive) + + if download: + if not os.path.isdir(self._path): + if not os.path.isfile(archive): + checksum = _CHECKSUMS.get(url, None) + download_url(url, root, hash_value=checksum, hash_type="md5") + extract_archive(archive, self._path) + + if subset == "validation": + self._walker = _load_list(self._path, "validation_list.txt") + elif subset == "testing": + self._walker = _load_list(self._path, "testing_list.txt") + elif subset == "training": + excludes = set(_load_list(self._path, "validation_list.txt", "testing_list.txt")) + walker = sorted(str(p) for p in Path(self._path).glob('*/*.wav')) + self._walker = [ + w for w in walker + if HASH_DIVIDER in w + and EXCEPT_FOLDER not in w + and os.path.normpath(w) not in excludes + ] + else: + walker = sorted(str(p) for p in Path(self._path).glob('*/*.wav')) + self._walker = [w for w in walker if HASH_DIVIDER in w and EXCEPT_FOLDER not in w] + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int]: + """Load the n-th sample from the dataset. + + Args: + n (int): The index of the sample to be loaded + + Returns: + (Tensor, int, str, str, int): + ``(waveform, sample_rate, label, speaker_id, utterance_number)`` + """ + fileid = self._walker[n] + return load_speechcommands_item(fileid, self._path) + + def __len__(self) -> int: + return len(self._walker) diff --git a/torchaudio/datasets/tedlium.py b/torchaudio/datasets/tedlium.py new file mode 100644 index 0000000000000000000000000000000000000000..fe6222a46f17959cc080f7f51f2cfaa504b90183 --- /dev/null +++ b/torchaudio/datasets/tedlium.py @@ -0,0 +1,195 @@ +import os +from typing import Tuple, Union +from pathlib import Path + +import torchaudio +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio.datasets.utils import ( + download_url, + extract_archive, +) + + +_RELEASE_CONFIGS = { + "release1": { + "folder_in_archive": "TEDLIUM_release1", + "url": "http://www.openslr.org/resources/7/TEDLIUM_release1.tar.gz", + "checksum": "30301975fd8c5cac4040c261c0852f57cfa8adbbad2ce78e77e4986957445f27", + "data_path": "", + "subset": "train", + "supported_subsets": ["train", "test", "dev"], + "dict": "TEDLIUM.150K.dic", + }, + "release2": { + "folder_in_archive": "TEDLIUM_release2", + "url": "http://www.openslr.org/resources/19/TEDLIUM_release2.tar.gz", + "checksum": "93281b5fcaaae5c88671c9d000b443cb3c7ea3499ad12010b3934ca41a7b9c58", + "data_path": "", + "subset": "train", + "supported_subsets": ["train", "test", "dev"], + "dict": "TEDLIUM.152k.dic", + }, + "release3": { + "folder_in_archive": "TEDLIUM_release-3", + "url": "http://www.openslr.org/resources/51/TEDLIUM_release-3.tgz", + "checksum": "ad1e454d14d1ad550bc2564c462d87c7a7ec83d4dc2b9210f22ab4973b9eccdb", + "data_path": "data/", + "subset": None, + "supported_subsets": [None], + "dict": "TEDLIUM.152k.dic", + }, +} + + +class TEDLIUM(Dataset): + """ + Create a Dataset for Tedlium. It supports releases 1,2 and 3. 
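+
+    Example (a sketch; fetches release 1 into an existing ``./data`` directory and assumes a backend that can decode ``.sph`` files):
+        >>> dataset = TEDLIUM("./data", release="release1", subset="train", download=True)
+        >>> waveform, sample_rate, transcript, talk_id, speaker_id, identifier = dataset[0]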
+ + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + release (str, optional): Release version. + Allowed values are ``"release1"``, ``"release2"`` or ``"release3"``. + (default: ``"release1"``). + subset (str, optional): The subset of dataset to use. Valid options are ``"train"``, ``"dev"``, + and ``"test"`` for releases 1&2, ``None`` for release3. Defaults to ``"train"`` or ``None``. + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). + audio_ext (str, optional): extension for audio file (default: ``"audio_ext"``) + """ + def __init__( + self, + root: Union[str, Path], + release: str = "release1", + subset: str = None, + download: bool = False, + audio_ext: str = ".sph" + ) -> None: + self._ext_audio = audio_ext + if release in _RELEASE_CONFIGS.keys(): + folder_in_archive = _RELEASE_CONFIGS[release]["folder_in_archive"] + url = _RELEASE_CONFIGS[release]["url"] + subset = subset if subset else _RELEASE_CONFIGS[release]["subset"] + else: + # Raise warning + raise RuntimeError( + "The release {} does not match any of the supported tedlium releases{} ".format( + release, _RELEASE_CONFIGS.keys(), + ) + ) + if subset not in _RELEASE_CONFIGS[release]["supported_subsets"]: + # Raise warning + raise RuntimeError( + "The subset {} does not match any of the supported tedlium subsets{} ".format( + subset, _RELEASE_CONFIGS[release]["supported_subsets"], + ) + ) + + # Get string representation of 'root' in case Path object is passed + root = os.fspath(root) + + basename = os.path.basename(url) + archive = os.path.join(root, basename) + + basename = basename.split(".")[0] + + self._path = os.path.join(root, folder_in_archive, _RELEASE_CONFIGS[release]["data_path"]) + if subset in ["train", "dev", "test"]: + self._path = os.path.join(self._path, subset) + + if download: + if not os.path.isdir(self._path): + if not os.path.isfile(archive): + checksum = _RELEASE_CONFIGS[release]["checksum"] + download_url(url, root, hash_value=checksum) + extract_archive(archive) + + # Create list for all samples + self._filelist = [] + stm_path = os.path.join(self._path, "stm") + for file in sorted(os.listdir(stm_path)): + if file.endswith(".stm"): + stm_path = os.path.join(self._path, "stm", file) + with open(stm_path) as f: + l = len(f.readlines()) + file = file.replace(".stm", "") + self._filelist.extend((file, line) for line in range(l)) + # Create dict path for later read + self._dict_path = os.path.join(root, folder_in_archive, _RELEASE_CONFIGS[release]["dict"]) + self._phoneme_dict = None + + def _load_tedlium_item(self, fileid: str, line: int, path: str) -> Tuple[Tensor, int, str, int, int, int]: + """Loads a TEDLIUM dataset sample given a file name and corresponding sentence name. 
+ + Args: + fileid (str): File id to identify both text and audio files corresponding to the sample + line (int): Line identifier for the sample inside the text file + path (str): Dataset root path + + Returns: + (Tensor, int, str, int, int, int): + ``(waveform, sample_rate, transcript, talk_id, speaker_id, identifier)`` + """ + transcript_path = os.path.join(path, "stm", fileid) + with open(transcript_path + ".stm") as f: + transcript = f.readlines()[line] + talk_id, _, speaker_id, start_time, end_time, identifier, transcript = transcript.split(" ", 6) + + wave_path = os.path.join(path, "sph", fileid) + waveform, sample_rate = self._load_audio(wave_path + self._ext_audio, start_time=start_time, end_time=end_time) + + return (waveform, sample_rate, transcript, talk_id, speaker_id, identifier) + + def _load_audio(self, path: str, start_time: float, end_time: float, sample_rate: int = 16000) -> [Tensor, int]: + """Default load function used in TEDLIUM dataset, you can overwrite this function to customize functionality + and load individual sentences from a full ted audio talk file. + + Args: + path (str): Path to audio file + start_time (int): Time in seconds where the sample sentence stars + end_time (int): Time in seconds where the sample sentence finishes + sample_rate (float, optional): Sampling rate + + Returns: + [Tensor, int]: Audio tensor representation and sample rate + """ + start_time = int(float(start_time) * sample_rate) + end_time = int(float(end_time) * sample_rate) + + kwargs = {"frame_offset": start_time, "num_frames": end_time - start_time} + + return torchaudio.load(path, **kwargs) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]: + """Load the n-th sample from the dataset. + + Args: + n (int): The index of the sample to be loaded + + Returns: + tuple: ``(waveform, sample_rate, transcript, talk_id, speaker_id, identifier)`` + """ + fileid, line = self._filelist[n] + return self._load_tedlium_item(fileid, line, self._path) + + def __len__(self) -> int: + """TEDLIUM dataset custom function overwritting len default behaviour. + + Returns: + int: TEDLIUM dataset length + """ + return len(self._filelist) + + @property + def phoneme_dict(self): + """dict[str, tuple[str]]: Phonemes. Mapping from word to tuple of phonemes. + Note that some words have empty phonemes. + """ + # Read phoneme dictionary + if not self._phoneme_dict: + self._phoneme_dict = {} + with open(self._dict_path, "r", encoding="utf-8") as f: + for line in f.readlines(): + content = line.strip().split() + self._phoneme_dict[content[0]] = tuple(content[1:]) # content[1:] can be empty list + return self._phoneme_dict.copy() diff --git a/torchaudio/datasets/utils.py b/torchaudio/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e105eb64897aa182481332f1e7cd07edaf8d01b9 --- /dev/null +++ b/torchaudio/datasets/utils.py @@ -0,0 +1,284 @@ +import hashlib +import logging +import os +import tarfile +import threading +import urllib +import urllib.request +import zipfile +from queue import Queue +from typing import Any, Iterable, List, Optional + +import torch +from torch.utils.data import Dataset +from torch.utils.model_zoo import tqdm + +from torchaudio._internal.module_utils import deprecated + + +def stream_url(url: str, + start_byte: Optional[int] = None, + block_size: int = 32 * 1024, + progress_bar: bool = True) -> Iterable: + """Stream url by chunk + + Args: + url (str): Url. 
+ start_byte (int or None, optional): Start streaming at that point (Default: ``None``). + block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``). + progress_bar (bool, optional): Display a progress bar (Default: ``True``). + """ + + # If we already have the whole file, there is no need to download it again + req = urllib.request.Request(url, method="HEAD") + with urllib.request.urlopen(req) as response: + url_size = int(response.info().get("Content-Length", -1)) + if url_size == start_byte: + return + + req = urllib.request.Request(url) + if start_byte: + req.headers["Range"] = "bytes={}-".format(start_byte) + + with urllib.request.urlopen(req) as upointer, tqdm( + unit="B", + unit_scale=True, + unit_divisor=1024, + total=url_size, + disable=not progress_bar, + ) as pbar: + + num_bytes = 0 + while True: + chunk = upointer.read(block_size) + if not chunk: + break + yield chunk + num_bytes += len(chunk) + pbar.update(len(chunk)) + + +def download_url(url: str, + download_folder: str, + filename: Optional[str] = None, + hash_value: Optional[str] = None, + hash_type: str = "sha256", + progress_bar: bool = True, + resume: bool = False) -> None: + """Download file to disk. + + Args: + url (str): Url. + download_folder (str): Folder to download file. + filename (str or None, optional): Name of downloaded file. If None, it is inferred from the url + (Default: ``None``). + hash_value (str or None, optional): Hash for url (Default: ``None``). + hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``). + progress_bar (bool, optional): Display a progress bar (Default: ``True``). + resume (bool, optional): Enable resuming download (Default: ``False``). + """ + + req = urllib.request.Request(url, method="HEAD") + req_info = urllib.request.urlopen(req).info() + + # Detect filename + filename = filename or req_info.get_filename() or os.path.basename(url) + filepath = os.path.join(download_folder, filename) + if resume and os.path.exists(filepath): + mode = "ab" + local_size: Optional[int] = os.path.getsize(filepath) + + elif not resume and os.path.exists(filepath): + raise RuntimeError( + "{} already exists. Delete the file manually and retry.".format(filepath) + ) + else: + mode = "wb" + local_size = None + + if hash_value and local_size == int(req_info.get("Content-Length", -1)): + with open(filepath, "rb") as file_obj: + if validate_file(file_obj, hash_value, hash_type): + return + raise RuntimeError( + "The hash of {} does not match. Delete the file manually and retry.".format( + filepath + ) + ) + + with open(filepath, mode) as fpointer: + for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar): + fpointer.write(chunk) + + with open(filepath, "rb") as file_obj: + if hash_value and not validate_file(file_obj, hash_value, hash_type): + raise RuntimeError( + "The hash of {} does not match. Delete the file manually and retry.".format( + filepath + ) + ) + + +def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool: + """Validate a given file object with its hash. + + Args: + file_obj: File object to read from. + hash_value (str): Hash for url. + hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``). + + Returns: + bool: return True if its a valid file, else False. 
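+
+    Example (illustrative; the file name and hash value below are placeholders):
+        >>> with open("archive.tar.gz", "rb") as file_obj:
+        ...     is_valid = validate_file(file_obj, hash_value="<expected-md5>", hash_type="md5")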
+ """ + + if hash_type == "sha256": + hash_func = hashlib.sha256() + elif hash_type == "md5": + hash_func = hashlib.md5() + else: + raise ValueError + + while True: + # Read by chunk to avoid filling memory + chunk = file_obj.read(1024 ** 2) + if not chunk: + break + hash_func.update(chunk) + + return hash_func.hexdigest() == hash_value + + +def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]: + """Extract archive. + Args: + from_path (str): the path of the archive. + to_path (str or None, optional): the root path of the extraced files (directory of from_path) + (Default: ``None``) + overwrite (bool, optional): overwrite existing files (Default: ``False``) + + Returns: + List[str]: List of paths to extracted files even if not overwritten. + + Examples: + >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz' + >>> from_path = './validation.tar.gz' + >>> to_path = './' + >>> torchaudio.datasets.utils.download_from_url(url, from_path) + >>> torchaudio.datasets.utils.extract_archive(from_path, to_path) + """ + + if to_path is None: + to_path = os.path.dirname(from_path) + + try: + with tarfile.open(from_path, "r") as tar: + logging.info("Opened tar file {}.".format(from_path)) + files = [] + for file_ in tar: # type: Any + file_path = os.path.join(to_path, file_.name) + if file_.isfile(): + files.append(file_path) + if os.path.exists(file_path): + logging.info("{} already extracted.".format(file_path)) + if not overwrite: + continue + tar.extract(file_, to_path) + return files + except tarfile.ReadError: + pass + + try: + with zipfile.ZipFile(from_path, "r") as zfile: + logging.info("Opened zip file {}.".format(from_path)) + files = zfile.namelist() + for file_ in files: + file_path = os.path.join(to_path, file_) + if os.path.exists(file_path): + logging.info("{} already extracted.".format(file_path)) + if not overwrite: + continue + zfile.extract(file_, to_path) + return files + except zipfile.BadZipFile: + pass + + raise NotImplementedError("We currently only support tar.gz, tgz, and zip achives.") + + +class _DiskCache(Dataset): + """ + Wrap a dataset so that, whenever a new item is returned, it is saved to disk. + """ + + def __init__(self, dataset: Dataset, location: str = ".cached") -> None: + self.dataset = dataset + self.location = location + + self._id = id(self) + self._cache: List = [None] * len(dataset) + + def __getitem__(self, n: int) -> Any: + if self._cache[n]: + f = self._cache[n] + return torch.load(f) + + f = str(self._id) + "-" + str(n) + f = os.path.join(self.location, f) + item = self.dataset[n] + + self._cache[n] = f + os.makedirs(self.location, exist_ok=True) + torch.save(item, f) + + return item + + def __len__(self) -> int: + return len(self.dataset) + + +@deprecated('', version='0.11') +def diskcache_iterator(dataset: Dataset, location: str = ".cached") -> Dataset: + return _DiskCache(dataset, location) + + +class _ThreadedIterator(threading.Thread): + """ + Prefetch the next queue_length items from iterator in a background thread. 
+ + Example: + >> for i in bg_iterator(range(10)): + >> print(i) + """ + + class _End: + pass + + def __init__(self, generator: Iterable, maxsize: int) -> None: + threading.Thread.__init__(self) + self.queue: Queue = Queue(maxsize) + self.generator = generator + self.daemon = True + self.start() + + def run(self) -> None: + for item in self.generator: + self.queue.put(item) + self.queue.put(self._End) + + def __iter__(self) -> Any: + return self + + def __next__(self) -> Any: + next_item = self.queue.get() + if next_item == self._End: + raise StopIteration + return next_item + + # Required for Python 2.7 compatibility + def next(self) -> Any: + return self.__next__() + + +@deprecated('', version='0.11') +def bg_iterator(iterable: Iterable, maxsize: int) -> Any: + return _ThreadedIterator(iterable, maxsize=maxsize) diff --git a/torchaudio/datasets/vctk.py b/torchaudio/datasets/vctk.py new file mode 100644 index 0000000000000000000000000000000000000000..65ec854a082a77222042e8a9551570ce9afa4b3b --- /dev/null +++ b/torchaudio/datasets/vctk.py @@ -0,0 +1,275 @@ +import os +import warnings +from pathlib import Path +from typing import Tuple, Union + +from torch import Tensor +from torch.utils.data import Dataset + +import torchaudio +from torchaudio.datasets.utils import ( + download_url, + extract_archive, +) + +URL = "https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip" +FOLDER_IN_ARCHIVE = "VCTK-Corpus" +_CHECKSUMS = { + "https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip": "8a6ba2946b36fcbef0212cad601f4bfa" +} + + +def load_vctk_item(fileid: str, + path: str, + ext_audio: str, + ext_txt: str, + folder_audio: str, + folder_txt: str, + downsample: bool = False) -> Tuple[Tensor, int, str, str, str]: + speaker_id, utterance_id = fileid.split("_") + + # Read text + file_txt = os.path.join(path, folder_txt, speaker_id, fileid + ext_txt) + with open(file_txt) as file_text: + utterance = file_text.readlines()[0] + + # Read wav + file_audio = os.path.join(path, folder_audio, speaker_id, fileid + ext_audio) + waveform, sample_rate = torchaudio.load(file_audio) + if downsample: + # TODO Remove this parameter after deprecation + F = torchaudio.functional + T = torchaudio.transforms + # rate + sample = T.Resample(sample_rate, 16000, resampling_method='sinc_interpolation') + waveform = sample(waveform) + # dither + waveform = F.dither(waveform, noise_shaping=True) + + return waveform, sample_rate, utterance, speaker_id, utterance_id + + +class VCTK(Dataset): + """Create a Dataset for VCTK. + + Note: + * **This dataset is no longer publicly available.** Please use :py:class:`VCTK_092` + * Directory ``p315`` is ignored because there is no corresponding text files. + For more information about the dataset visit: https://datashare.is.ed.ac.uk/handle/10283/3443 + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + url (str, optional): Not used as the dataset is no longer publicly available. + folder_in_archive (str, optional): + The top-level directory of the dataset. (default: ``"VCTK-Corpus"``) + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). + Giving ``download=True`` will result in error as the dataset is no longer + publicly available. + downsample (bool, optional): Not used. 
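+
+    Example (illustrative; requires a pre-existing local copy under ``./data/VCTK-Corpus``, as the archive can no longer be downloaded):
+        >>> dataset = VCTK("./data")
+        >>> waveform, sample_rate, utterance, speaker_id, utterance_id = dataset[0]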
+ """ + + _folder_txt = "txt" + _folder_audio = "wav48" + _ext_txt = ".txt" + _ext_audio = ".wav" + _except_folder = "p315" + + def __init__(self, + root: Union[str, Path], + url: str = URL, + folder_in_archive: str = FOLDER_IN_ARCHIVE, + download: bool = False, + downsample: bool = False) -> None: + + warnings.warn( + 'VCTK class has been deprecated and will be removed in 0.11 release. ' + 'Please use VCTK_092.' + ) + + if downsample: + warnings.warn( + "In the next version, transforms will not be part of the dataset. " + "Please use `downsample=False` to enable this behavior now, " + "and suppress this warning." + ) + + self.downsample = downsample + # Get string representation of 'root' in case Path object is passed + root = os.fspath(root) + + archive = os.path.basename(url) + archive = os.path.join(root, archive) + self._path = os.path.join(root, folder_in_archive) + + if download: + raise RuntimeError( + "This Dataset is no longer available. " + "Please use `VCTK_092` class to download the latest version." + ) + + if not os.path.isdir(self._path): + raise RuntimeError( + "Dataset not found. Please use `VCTK_092` class " + "with `download=True` to donwload the latest version." + ) + + walker = sorted(str(p.stem) for p in Path(self._path).glob('**/*' + self._ext_audio)) + walker = filter(lambda w: self._except_folder not in w, walker) + self._walker = list(walker) + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str]: + """Load the n-th sample from the dataset. + + Args: + n (int): The index of the sample to be loaded + + Returns: + tuple: ``(waveform, sample_rate, utterance, speaker_id, utterance_id)`` + """ + fileid = self._walker[n] + item = load_vctk_item( + fileid, + self._path, + self._ext_audio, + self._ext_txt, + self._folder_audio, + self._folder_txt, + ) + + # TODO Upon deprecation, uncomment line below and remove following code + # return item + + waveform, sample_rate, utterance, speaker_id, utterance_id = item + return waveform, sample_rate, utterance, speaker_id, utterance_id + + def __len__(self) -> int: + return len(self._walker) + + +SampleType = Tuple[Tensor, int, str, str, str] + + +class VCTK_092(Dataset): + """Create VCTK 0.92 Dataset + + Args: + root (str): Root directory where the dataset's top level directory is found. + mic_id (str, optional): Microphone ID. Either ``"mic1"`` or ``"mic2"``. (default: ``"mic2"``) + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). + url (str, optional): The URL to download the dataset from. + (default: ``"https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"``) + audio_ext (str, optional): Custom audio extension if dataset is converted to non-default audio format. + + Note: + * All the speeches from speaker ``p315`` will be skipped due to the lack of the corresponding text files. + * All the speeches from ``p280`` will be skipped for ``mic_id="mic2"`` due to the lack of the audio files. + * Some of the speeches from speaker ``p362`` will be skipped due to the lack of the audio files. + * See Also: https://datashare.is.ed.ac.uk/handle/10283/3443 + """ + + def __init__( + self, + root: str, + mic_id: str = "mic2", + download: bool = False, + url: str = URL, + audio_ext=".flac", + ): + if mic_id not in ["mic1", "mic2"]: + raise RuntimeError( + f'`mic_id` has to be either "mic1" or "mic2". 
Found: {mic_id}' + ) + + archive = os.path.join(root, "VCTK-Corpus-0.92.zip") + + self._path = os.path.join(root, "VCTK-Corpus-0.92") + self._txt_dir = os.path.join(self._path, "txt") + self._audio_dir = os.path.join(self._path, "wav48_silence_trimmed") + self._mic_id = mic_id + self._audio_ext = audio_ext + + if download: + if not os.path.isdir(self._path): + if not os.path.isfile(archive): + checksum = _CHECKSUMS.get(url, None) + download_url(url, root, hash_value=checksum, hash_type="md5") + extract_archive(archive, self._path) + + if not os.path.isdir(self._path): + raise RuntimeError( + "Dataset not found. Please use `download=True` to download it." + ) + + # Extracting speaker IDs from the folder structure + self._speaker_ids = sorted(os.listdir(self._txt_dir)) + self._sample_ids = [] + + """ + Due to some insufficient data complexity in the 0.92 version of this dataset, + we start traversing the audio folder structure in accordance with the text folder. + As some of the audio files are missing of either ``mic_1`` or ``mic_2`` but the + text is present for the same, we first check for the existence of the audio file + before adding it to the ``sample_ids`` list. + + Once the ``audio_ids`` are loaded into memory we can quickly access the list for + different parameters required by the user. + """ + for speaker_id in self._speaker_ids: + if speaker_id == "p280" and mic_id == "mic2": + continue + utterance_dir = os.path.join(self._txt_dir, speaker_id) + for utterance_file in sorted( + f for f in os.listdir(utterance_dir) if f.endswith(".txt") + ): + utterance_id = os.path.splitext(utterance_file)[0] + audio_path_mic = os.path.join( + self._audio_dir, + speaker_id, + f"{utterance_id}_{mic_id}{self._audio_ext}", + ) + if speaker_id == "p362" and not os.path.isfile(audio_path_mic): + continue + self._sample_ids.append(utterance_id.split("_")) + + def _load_text(self, file_path) -> str: + with open(file_path) as file_path: + return file_path.readlines()[0] + + def _load_audio(self, file_path) -> Tuple[Tensor, int]: + return torchaudio.load(file_path) + + def _load_sample(self, speaker_id: str, utterance_id: str, mic_id: str) -> SampleType: + transcript_path = os.path.join( + self._txt_dir, speaker_id, f"{speaker_id}_{utterance_id}.txt" + ) + audio_path = os.path.join( + self._audio_dir, + speaker_id, + f"{speaker_id}_{utterance_id}_{mic_id}{self._audio_ext}", + ) + + # Reading text + transcript = self._load_text(transcript_path) + + # Reading FLAC + waveform, sample_rate = self._load_audio(audio_path) + + return (waveform, sample_rate, transcript, speaker_id, utterance_id) + + def __getitem__(self, n: int) -> SampleType: + """Load the n-th sample from the dataset. 
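+
+        Example:
+            Hypothetical usage; ``"./data"`` is a placeholder root that is assumed
+            to already contain the downloaded and extracted corpus.
+
+            >>> from torchaudio.datasets import VCTK_092
+            >>> dataset = VCTK_092("./data", mic_id="mic2")
+            >>> waveform, sample_rate, transcript, speaker_id, utterance_id = dataset[0]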
+ + Args: + n (int): The index of the sample to be loaded + + Returns: + (Tensor, int, str, str, str): + ``(waveform, sample_rate, transcript, speaker_id, utterance_id)`` + """ + speaker_id, utterance_id = self._sample_ids[n] + return self._load_sample(speaker_id, utterance_id, self._mic_id) + + def __len__(self) -> int: + return len(self._sample_ids) diff --git a/torchaudio/datasets/yesno.py b/torchaudio/datasets/yesno.py new file mode 100644 index 0000000000000000000000000000000000000000..f33c11852c87dc8c408736f551cc39c3ef05871e --- /dev/null +++ b/torchaudio/datasets/yesno.py @@ -0,0 +1,87 @@ +import os +from pathlib import Path +from typing import List, Tuple, Union + +from torch import Tensor +from torch.utils.data import Dataset + +import torchaudio +from torchaudio.datasets.utils import ( + download_url, + extract_archive, +) + + +_RELEASE_CONFIGS = { + "release1": { + "folder_in_archive": "waves_yesno", + "url": "http://www.openslr.org/resources/1/waves_yesno.tar.gz", + "checksum": "c3f49e0cca421f96b75b41640749167b52118f232498667ca7a5f9416aef8e73", + } +} + + +class YESNO(Dataset): + """Create a Dataset for YesNo. + + Args: + root (str or Path): Path to the directory where the dataset is found or downloaded. + url (str, optional): The URL to download the dataset from. + (default: ``"http://www.openslr.org/resources/1/waves_yesno.tar.gz"``) + folder_in_archive (str, optional): + The top-level directory of the dataset. (default: ``"waves_yesno"``) + download (bool, optional): + Whether to download the dataset if it is not found at root path. (default: ``False``). + """ + + def __init__( + self, + root: Union[str, Path], + url: str = _RELEASE_CONFIGS["release1"]["url"], + folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"], + download: bool = False + ) -> None: + + self._parse_filesystem(root, url, folder_in_archive, download) + + def _parse_filesystem(self, root: str, url: str, folder_in_archive: str, download: bool) -> None: + root = Path(root) + archive = os.path.basename(url) + archive = root / archive + + self._path = root / folder_in_archive + if download: + if not os.path.isdir(self._path): + if not os.path.isfile(archive): + checksum = _RELEASE_CONFIGS["release1"]["checksum"] + download_url(url, root, hash_value=checksum) + extract_archive(archive) + + if not os.path.isdir(self._path): + raise RuntimeError( + "Dataset not found. Please use `download=True` to download it." + ) + + self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*.wav")) + + def _load_item(self, fileid: str, path: str): + labels = [int(c) for c in fileid.split("_")] + file_audio = os.path.join(path, fileid + ".wav") + waveform, sample_rate = torchaudio.load(file_audio) + return waveform, sample_rate, labels + + def __getitem__(self, n: int) -> Tuple[Tensor, int, List[int]]: + """Load the n-th sample from the dataset. 
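+
+        Example:
+            Hypothetical usage; ``"./data"`` is a placeholder root directory.
+
+            >>> from torchaudio.datasets import YESNO
+            >>> dataset = YESNO("./data", download=True)
+            >>> waveform, sample_rate, labels = dataset[0]
+            >>> labels  # eight binary flags, e.g. [0, 0, 1, 1, 0, 1, 1, 0]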
+ + Args: + n (int): The index of the sample to be loaded + + Returns: + (Tensor, int, List[int]): ``(waveform, sample_rate, labels)`` + """ + fileid = self._walker[n] + item = self._load_item(fileid, self._path) + return item + + def __len__(self) -> int: + return len(self._walker) diff --git a/torchaudio/functional/__init__.py b/torchaudio/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc3173cd2184db9217a302ac3546ae018f40dfc4 --- /dev/null +++ b/torchaudio/functional/__init__.py @@ -0,0 +1,105 @@ +from .functional import ( + amplitude_to_DB, + angle, + complex_norm, + compute_deltas, + compute_kaldi_pitch, + create_dct, + create_fb_matrix, + melscale_fbanks, + linear_fbanks, + DB_to_amplitude, + detect_pitch_frequency, + inverse_spectrogram, + griffinlim, + magphase, + mask_along_axis, + mask_along_axis_iid, + mu_law_encoding, + mu_law_decoding, + phase_vocoder, + sliding_window_cmn, + spectrogram, + spectral_centroid, + apply_codec, + resample, + edit_distance, + pitch_shift, + rnnt_loss, +) +from .filtering import ( + allpass_biquad, + band_biquad, + bandpass_biquad, + bandreject_biquad, + bass_biquad, + biquad, + contrast, + dither, + dcshift, + deemph_biquad, + equalizer_biquad, + filtfilt, + flanger, + gain, + highpass_biquad, + lfilter, + lowpass_biquad, + overdrive, + phaser, + riaa_biquad, + treble_biquad, + vad, +) + +__all__ = [ + 'amplitude_to_DB', + 'angle', + 'complex_norm', + 'compute_deltas', + 'compute_kaldi_pitch', + 'create_dct', + 'create_fb_matrix', + 'melscale_fbanks', + 'linear_fbanks', + 'DB_to_amplitude', + 'detect_pitch_frequency', + 'griffinlim', + 'magphase', + 'mask_along_axis', + 'mask_along_axis_iid', + 'mu_law_encoding', + 'mu_law_decoding', + 'phase_vocoder', + 'sliding_window_cmn', + 'spectrogram', + 'inverse_spectrogram', + 'spectral_centroid', + 'allpass_biquad', + 'band_biquad', + 'bandpass_biquad', + 'bandreject_biquad', + 'bass_biquad', + 'biquad', + 'contrast', + 'dither', + 'dcshift', + 'deemph_biquad', + 'equalizer_biquad', + 'filtfilt', + 'flanger', + 'gain', + 'highpass_biquad', + 'lfilter', + 'lowpass_biquad', + 'overdrive', + 'phaser', + 'riaa_biquad', + 'treble_biquad', + 'vad', + 'apply_codec', + 'resample', + 'edit_distance', + 'pitch_shift', + 'rnnt_loss', +] diff --git a/torchaudio/functional/filtering.py b/torchaudio/functional/filtering.py new file mode 100644 index 0000000000000000000000000000000000000000..0f5d24f67ed1afd1fa61ce764d24a6b5e9e6bc9a --- /dev/null +++ b/torchaudio/functional/filtering.py @@ -0,0 +1,1636 @@ +import math +import warnings +from typing import Optional + +import torch +from torch import Tensor + + +def _dB2Linear(x: float) -> float: + return math.exp(x * math.log(10) / 20.0) + + +def _generate_wave_table( + wave_type: str, + data_type: str, + table_size: int, + min: float, + max: float, + phase: float, + device: torch.device, +) -> Tensor: + r"""A helper function for phaser. Generates a table with given parameters. 
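+
+    Example (illustrative call with arbitrary values; this helper is internal to the module):
+        >>> import torch
+        >>> lfo = _generate_wave_table(
+        ...     "SINE", "FLOAT", table_size=8, min=0.0, max=1.0, phase=0.0,
+        ...     device=torch.device("cpu"))
+        >>> lfo.shape
+        torch.Size([8])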
+ + Args: + wave_type (str): SINE or TRIANGULAR + data_type (str): desired data_type ( `INT` or `FLOAT` ) + table_size (int): desired table size + min (float): desired min value + max (float): desired max value + phase (float): desired phase + device (torch.device): Torch device on which table must be generated + Returns: + Tensor: A 1D tensor with wave table values + """ + + phase_offset = int(phase / math.pi / 2 * table_size + 0.5) + + t = torch.arange(table_size, device=device, dtype=torch.int32) + + point = (t + phase_offset) % table_size + + d = torch.zeros_like(point, device=device, dtype=torch.float64) + + if wave_type == "SINE": + d = (torch.sin(point.to(torch.float64) / table_size * 2 * math.pi) + 1) / 2 + elif wave_type == "TRIANGLE": + d = point.to(torch.float64) * 2 / table_size + value = torch.div(4 * point, table_size, rounding_mode='floor') + d[value == 0] = d[value == 0] + 0.5 + d[value == 1] = 1.5 - d[value == 1] + d[value == 2] = 1.5 - d[value == 2] + d[value == 3] = d[value == 3] - 1.5 + + d = d * (max - min) + min + + if data_type == "INT": + mask = d < 0 + d[mask] = d[mask] - 0.5 + d[~mask] = d[~mask] + 0.5 + d = d.to(torch.int32) + elif data_type == "FLOAT": + d = d.to(torch.float32) + + return d + + +def allpass_biquad( + waveform: Tensor, sample_rate: int, central_freq: float, Q: float = 0.707 +) -> Tensor: + r"""Design two-pole all-pass filter. Similar to SoX implementation. + + Args: + waveform(torch.Tensor): audio waveform of dimension of `(..., time)` + sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) + central_freq (float or torch.Tensor): central frequency (in Hz) + Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``) + + Returns: + Tensor: Waveform of dimension of `(..., time)` + + Reference: + - http://sox.sourceforge.net/sox.html + - https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF + """ + dtype = waveform.dtype + device = waveform.device + central_freq = torch.as_tensor(central_freq, dtype=dtype, device=device) + Q = torch.as_tensor(Q, dtype=dtype, device=device) + + w0 = 2 * math.pi * central_freq / sample_rate + + alpha = torch.sin(w0) / 2 / Q + + b0 = 1 - alpha + b1 = -2 * torch.cos(w0) + b2 = 1 + alpha + a0 = 1 + alpha + a1 = -2 * torch.cos(w0) + a2 = 1 - alpha + return biquad(waveform, b0, b1, b2, a0, a1, a2) + + +def band_biquad( + waveform: Tensor, + sample_rate: int, + central_freq: float, + Q: float = 0.707, + noise: bool = False, +) -> Tensor: + r"""Design two-pole band filter. Similar to SoX implementation. + + Args: + waveform (Tensor): audio waveform of dimension of `(..., time)` + sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) + central_freq (float or torch.Tensor): central frequency (in Hz) + Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``). + noise (bool, optional) : If ``True``, uses the alternate mode for un-pitched audio (e.g. percussion). + If ``False``, uses mode oriented to pitched audio, i.e. voice, singing, + or instrumental music (Default: ``False``). 
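+
+    Example (hypothetical; the input is one second of random noise scaled to ``[-1, 1]``):
+        >>> import torch
+        >>> from torchaudio.functional import band_biquad
+        >>> waveform = torch.rand(1, 16000) * 2 - 1
+        >>> filtered = band_biquad(waveform, 16000, central_freq=1000.0, Q=0.707)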
+ + Returns: + Tensor: Waveform of dimension of `(..., time)` + + Reference: + - http://sox.sourceforge.net/sox.html + - https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF + """ + dtype = waveform.dtype + device = waveform.device + central_freq = torch.as_tensor(central_freq, dtype=dtype, device=device) + Q = torch.as_tensor(Q, dtype=dtype, device=device) + + w0 = 2 * math.pi * central_freq / sample_rate + bw_Hz = central_freq / Q + + a0 = 1.0 + a2 = torch.exp(-2 * math.pi * bw_Hz / sample_rate) + a1 = -4 * a2 / (1 + a2) * torch.cos(w0) + + b0 = torch.sqrt(1 - a1 * a1 / (4 * a2)) * (1 - a2) + + if noise: + mult = torch.sqrt(((1 + a2) * (1 + a2) - a1 * a1) * (1 - a2) / (1 + a2)) / b0 + b0 = mult * b0 + + b1 = 0.0 + b2 = 0.0 + + return biquad(waveform, b0, b1, b2, a0, a1, a2) + + +def bandpass_biquad( + waveform: Tensor, + sample_rate: int, + central_freq: float, + Q: float = 0.707, + const_skirt_gain: bool = False, +) -> Tensor: + r"""Design two-pole band-pass filter. Similar to SoX implementation. + + Args: + waveform (Tensor): audio waveform of dimension of `(..., time)` + sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) + central_freq (float or torch.Tensor): central frequency (in Hz) + Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``) + const_skirt_gain (bool, optional) : If ``True``, uses a constant skirt gain (peak gain = Q). + If ``False``, uses a constant 0dB peak gain. (Default: ``False``) + + Returns: + Tensor: Waveform of dimension of `(..., time)` + + Reference: + - http://sox.sourceforge.net/sox.html + - https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF + """ + dtype = waveform.dtype + device = waveform.device + central_freq = torch.as_tensor(central_freq, dtype=dtype, device=device) + Q = torch.as_tensor(Q, dtype=dtype, device=device) + + w0 = 2 * math.pi * central_freq / sample_rate + alpha = torch.sin(w0) / 2 / Q + + temp = torch.sin(w0) / 2 if const_skirt_gain else alpha + b0 = temp + b1 = 0.0 + b2 = -temp + a0 = 1 + alpha + a1 = -2 * torch.cos(w0) + a2 = 1 - alpha + return biquad(waveform, b0, b1, b2, a0, a1, a2) + + +def bandreject_biquad( + waveform: Tensor, sample_rate: int, central_freq: float, Q: float = 0.707 +) -> Tensor: + r"""Design two-pole band-reject filter. Similar to SoX implementation. + + Args: + waveform (Tensor): audio waveform of dimension of `(..., time)` + sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) + central_freq (float or torch.Tensor): central frequency (in Hz) + Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``) + + Returns: + Tensor: Waveform of dimension of `(..., time)` + + Reference: + - http://sox.sourceforge.net/sox.html + - https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF + """ + dtype = waveform.dtype + device = waveform.device + central_freq = torch.as_tensor(central_freq, dtype=dtype, device=device) + Q = torch.as_tensor(Q, dtype=dtype, device=device) + + w0 = 2 * math.pi * central_freq / sample_rate + alpha = torch.sin(w0) / 2 / Q + + b0 = 1.0 + b1 = -2 * torch.cos(w0) + b2 = 1.0 + a0 = 1 + alpha + a1 = -2 * torch.cos(w0) + a2 = 1 - alpha + return biquad(waveform, b0, b1, b2, a0, a1, a2) + + +def bass_biquad( + waveform: Tensor, + sample_rate: int, + gain: float, + central_freq: float = 100, + Q: float = 0.707, +) -> Tensor: + r"""Design a bass tone-control effect. Similar to SoX implementation. 
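+
+    Example (hypothetical; boosts content below roughly 100 Hz by 6 dB):
+        >>> import torch
+        >>> from torchaudio.functional import bass_biquad
+        >>> waveform = torch.rand(1, 16000) * 2 - 1
+        >>> boosted = bass_biquad(waveform, 16000, gain=6.0)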
+ + Args: + waveform (Tensor): audio waveform of dimension of `(..., time)` + sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) + gain (float or torch.Tensor): desired gain at the boost (or attenuation) in dB. + central_freq (float or torch.Tensor, optional): central frequency (in Hz). (Default: ``100``) + Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``). + + Returns: + Tensor: Waveform of dimension of `(..., time)` + + Reference: + - http://sox.sourceforge.net/sox.html + - https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF + """ + dtype = waveform.dtype + device = waveform.device + central_freq = torch.as_tensor(central_freq, dtype=dtype, device=device) + Q = torch.as_tensor(Q, dtype=dtype, device=device) + gain = torch.as_tensor(gain, dtype=dtype, device=device) + + w0 = 2 * math.pi * central_freq / sample_rate + alpha = torch.sin(w0) / 2 / Q + A = torch.exp(gain / 40 * math.log(10)) + + temp1 = 2 * torch.sqrt(A) * alpha + temp2 = (A - 1) * torch.cos(w0) + temp3 = (A + 1) * torch.cos(w0) + + b0 = A * ((A + 1) - temp2 + temp1) + b1 = 2 * A * ((A - 1) - temp3) + b2 = A * ((A + 1) - temp2 - temp1) + a0 = (A + 1) + temp2 + temp1 + a1 = -2 * ((A - 1) + temp3) + a2 = (A + 1) + temp2 - temp1 + + return biquad(waveform, b0 / a0, b1 / a0, b2 / a0, a0 / a0, a1 / a0, a2 / a0) + + +def biquad( + waveform: Tensor, b0: float, b1: float, b2: float, a0: float, a1: float, a2: float +) -> Tensor: + r"""Perform a biquad filter of input tensor. Initial conditions set to 0. + https://en.wikipedia.org/wiki/Digital_biquad_filter + + Args: + waveform (Tensor): audio waveform of dimension of `(..., time)` + b0 (float or torch.Tensor): numerator coefficient of current input, x[n] + b1 (float or torch.Tensor): numerator coefficient of input one time step ago x[n-1] + b2 (float or torch.Tensor): numerator coefficient of input two time steps ago x[n-2] + a0 (float or torch.Tensor): denominator coefficient of current output y[n], typically 1 + a1 (float or torch.Tensor): denominator coefficient of current output y[n-1] + a2 (float or torch.Tensor): denominator coefficient of current output y[n-2] + + Returns: + Tensor: Waveform with dimension of `(..., time)` + """ + + device = waveform.device + dtype = waveform.dtype + + b0 = torch.as_tensor(b0, dtype=dtype, device=device).view(1) + b1 = torch.as_tensor(b1, dtype=dtype, device=device).view(1) + b2 = torch.as_tensor(b2, dtype=dtype, device=device).view(1) + a0 = torch.as_tensor(a0, dtype=dtype, device=device).view(1) + a1 = torch.as_tensor(a1, dtype=dtype, device=device).view(1) + a2 = torch.as_tensor(a2, dtype=dtype, device=device).view(1) + + output_waveform = lfilter( + waveform, + torch.cat([a0, a1, a2]), + torch.cat([b0, b1, b2]), + ) + return output_waveform + + +def contrast(waveform: Tensor, enhancement_amount: float = 75.0) -> Tensor: + r"""Apply contrast effect. Similar to SoX implementation. 
+ Comparable with compression, this effect modifies an audio signal to make it sound louder + + Args: + waveform (Tensor): audio waveform of dimension of `(..., time)` + enhancement_amount (float, optional): controls the amount of the enhancement + Allowed range of values for enhancement_amount : 0-100 + Note that enhancement_amount = 0 still gives a significant contrast enhancement + + Returns: + Tensor: Waveform of dimension of `(..., time)` + + Reference: + - http://sox.sourceforge.net/sox.html + """ + + if not 0 <= enhancement_amount <= 100: + raise ValueError("Allowed range of values for enhancement_amount : 0-100") + + contrast = enhancement_amount / 750.0 + + temp1 = waveform * (math.pi / 2) + temp2 = contrast * torch.sin(temp1 * 4) + output_waveform = torch.sin(temp1 + temp2) + + return output_waveform + + +def dcshift( + waveform: Tensor, shift: float, limiter_gain: Optional[float] = None +) -> Tensor: + r"""Apply a DC shift to the audio. Similar to SoX implementation. + This can be useful to remove a DC offset + (caused perhaps by a hardware problem in the recording chain) from the audio + + Args: + waveform (Tensor): audio waveform of dimension of `(..., time)` + shift (float): indicates the amount to shift the audio + Allowed range of values for shift : -2.0 to +2.0 + limiter_gain (float of None, optional): It is used only on peaks to prevent clipping + It should have a value much less than 1 (e.g. 0.05 or 0.02) + + Returns: + Tensor: Waveform of dimension of `(..., time)` + + Reference: + - http://sox.sourceforge.net/sox.html + """ + output_waveform = waveform + limiter_threshold = 0.0 + + if limiter_gain is not None: + limiter_threshold = 1.0 - (abs(shift) - limiter_gain) + + if limiter_gain is not None and shift > 0: + mask = waveform > limiter_threshold + temp = ( + (waveform[mask] - limiter_threshold) + * limiter_gain + / (1 - limiter_threshold) + ) + output_waveform[mask] = (temp + limiter_threshold + shift).clamp( + max=limiter_threshold + ) + output_waveform[~mask] = (waveform[~mask] + shift).clamp(min=-1, max=1) + elif limiter_gain is not None and shift < 0: + mask = waveform < -limiter_threshold + temp = ( + (waveform[mask] + limiter_threshold) + * limiter_gain + / (1 - limiter_threshold) + ) + output_waveform[mask] = (temp - limiter_threshold + shift).clamp( + min=-limiter_threshold + ) + output_waveform[~mask] = (waveform[~mask] + shift).clamp(min=-1, max=1) + else: + output_waveform = (waveform + shift).clamp(min=-1, max=1) + + return output_waveform + + +def deemph_biquad(waveform: Tensor, sample_rate: int) -> Tensor: + r"""Apply ISO 908 CD de-emphasis (shelving) IIR filter. Similar to SoX implementation. 
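+
+    Example (hypothetical; only 44100 Hz and 48000 Hz inputs are accepted):
+        >>> import torch
+        >>> from torchaudio.functional import deemph_biquad
+        >>> waveform = torch.rand(1, 44100) * 2 - 1
+        >>> deemphasized = deemph_biquad(waveform, 44100)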
+ + Args: + waveform (Tensor): audio waveform of dimension of `(..., time)` + sample_rate (int): sampling rate of the waveform, Allowed sample rate ``44100`` or ``48000`` + + Returns: + Tensor: Waveform of dimension of `(..., time)` + + Reference: + - http://sox.sourceforge.net/sox.html + - https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF + """ + + if sample_rate == 44100: + central_freq = 5283 + width_slope = 0.4845 + gain = -9.477 + elif sample_rate == 48000: + central_freq = 5356 + width_slope = 0.479 + gain = -9.62 + else: + raise ValueError("Sample rate must be 44100 (audio-CD) or 48000 (DAT)") + + w0 = 2 * math.pi * central_freq / sample_rate + A = math.exp(gain / 40.0 * math.log(10)) + alpha = math.sin(w0) / 2 * math.sqrt((A + 1 / A) * (1 / width_slope - 1) + 2) + + temp1 = 2 * math.sqrt(A) * alpha + temp2 = (A - 1) * math.cos(w0) + temp3 = (A + 1) * math.cos(w0) + + b0 = A * ((A + 1) + temp2 + temp1) + b1 = -2 * A * ((A - 1) + temp3) + b2 = A * ((A + 1) + temp2 - temp1) + a0 = (A + 1) - temp2 + temp1 + a1 = 2 * ((A - 1) - temp3) + a2 = (A + 1) - temp2 - temp1 + + return biquad(waveform, b0, b1, b2, a0, a1, a2) + + +def _add_noise_shaping(dithered_waveform: Tensor, waveform: Tensor) -> Tensor: + r"""Noise shaping is calculated by error: + error[n] = dithered[n] - original[n] + noise_shaped_waveform[n] = dithered[n] + error[n-1] + """ + wf_shape = waveform.size() + waveform = waveform.reshape(-1, wf_shape[-1]) + + dithered_shape = dithered_waveform.size() + dithered_waveform = dithered_waveform.reshape(-1, dithered_shape[-1]) + + error = dithered_waveform - waveform + + # add error[n-1] to dithered_waveform[n], so offset the error by 1 index + zeros = torch.zeros(1, dtype=error.dtype, device=error.device) + for index in range(error.size()[0]): + err = error[index] + error_offset = torch.cat((zeros, err)) + error[index] = error_offset[: waveform.size()[1]] + + noise_shaped = dithered_waveform + error + return noise_shaped.reshape(dithered_shape[:-1] + noise_shaped.shape[-1:]) + + +def _apply_probability_distribution( + waveform: Tensor, density_function: str = "TPDF" +) -> Tensor: + r"""Apply a probability distribution function on a waveform. + + Triangular probability density function (TPDF) dither noise has a + triangular distribution; values in the center of the range have a higher + probability of occurring. + + Rectangular probability density function (RPDF) dither noise has a + uniform distribution; any value in the specified range has the same + probability of occurring. + + Gaussian probability density function (GPDF) has a normal distribution. + The relationship of probabilities of results follows a bell-shaped, + or Gaussian curve, typical of dither generated by analog sources. 
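+
+    This helper backs the public :func:`dither` defined below; a hypothetical call
+    through that wrapper (the input is random noise scaled to ``[-1, 1]``):
+        >>> import torch
+        >>> from torchaudio.functional import dither
+        >>> waveform = torch.rand(1, 16000) * 2 - 1
+        >>> dithered = dither(waveform, density_function="RPDF", noise_shaping=True)
+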
+ Args: + waveform (Tensor): Tensor of audio of dimension (..., time) + density_function (str, optional): The density function of a + continuous random variable (Default: ``"TPDF"``) + Options: Triangular Probability Density Function - `TPDF` + Rectangular Probability Density Function - `RPDF` + Gaussian Probability Density Function - `GPDF` + Returns: + Tensor: waveform dithered with TPDF + """ + + # pack batch + shape = waveform.size() + waveform = waveform.reshape(-1, shape[-1]) + + channel_size = waveform.size()[0] - 1 + time_size = waveform.size()[-1] - 1 + + random_channel = ( + int( + torch.randint( + channel_size, + [ + 1, + ], + ).item() + ) + if channel_size > 0 + else 0 + ) + random_time = ( + int( + torch.randint( + time_size, + [ + 1, + ], + ).item() + ) + if time_size > 0 + else 0 + ) + + number_of_bits = 16 + up_scaling = 2 ** (number_of_bits - 1) - 2 + signal_scaled = waveform * up_scaling + down_scaling = 2 ** (number_of_bits - 1) + + signal_scaled_dis = waveform + if density_function == "RPDF": + RPDF = waveform[random_channel][random_time] - 0.5 + + signal_scaled_dis = signal_scaled + RPDF + elif density_function == "GPDF": + # TODO Replace by distribution code once + # https://github.com/pytorch/pytorch/issues/29843 is resolved + # gaussian = torch.distributions.normal.Normal(torch.mean(waveform, -1), 1).sample() + + num_rand_variables = 6 + + gaussian = waveform[random_channel][random_time] + for ws in num_rand_variables * [time_size]: + rand_chan = int( + torch.randint( + channel_size, + [ + 1, + ], + ).item() + ) + gaussian += waveform[rand_chan][ + int( + torch.randint( + ws, + [ + 1, + ], + ).item() + ) + ] + + signal_scaled_dis = signal_scaled + gaussian + else: + # dtype needed for https://github.com/pytorch/pytorch/issues/32358 + TPDF = torch.bartlett_window( + time_size + 1, dtype=signal_scaled.dtype, device=signal_scaled.device + ) + TPDF = TPDF.repeat((channel_size + 1), 1) + signal_scaled_dis = signal_scaled + TPDF + + quantised_signal_scaled = torch.round(signal_scaled_dis) + quantised_signal = quantised_signal_scaled / down_scaling + + # unpack batch + return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:]) + + +def dither( + waveform: Tensor, density_function: str = "TPDF", noise_shaping: bool = False +) -> Tensor: + r"""Dither increases the perceived dynamic range of audio stored at a + particular bit-depth by eliminating nonlinear truncation distortion + (i.e. adding minimally perceived noise to mask distortion caused by quantization). + + Args: + waveform (Tensor): Tensor of audio of dimension (..., time) + density_function (str, optional): + The density function of a continuous random variable. One of + ``"TPDF"`` (Triangular Probability Density Function), + ``"RPDF"`` (Rectangular Probability Density Function) or + ``"GPDF"`` (Gaussian Probability Density Function) (Default: ``"TPDF"``). + noise_shaping (bool, optional): a filtering process that shapes the spectral + energy of quantisation error (Default: ``False``) + + Returns: + Tensor: waveform dithered + """ + dithered = _apply_probability_distribution( + waveform, density_function=density_function + ) + + if noise_shaping: + return _add_noise_shaping(dithered, waveform) + else: + return dithered + + +def equalizer_biquad( + waveform: Tensor, + sample_rate: int, + center_freq: float, + gain: float, + Q: float = 0.707, +) -> Tensor: + r"""Design biquad peaking equalizer filter and perform filtering. Similar to SoX implementation. 
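+
+    Example (hypothetical; boosts a band around 1 kHz by 3 dB):
+        >>> import torch
+        >>> from torchaudio.functional import equalizer_biquad
+        >>> waveform = torch.rand(1, 16000) * 2 - 1
+        >>> equalized = equalizer_biquad(waveform, 16000, center_freq=1000.0, gain=3.0)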
+ + Args: + waveform (Tensor): audio waveform of dimension of `(..., time)` + sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) + center_freq (float): filter's central frequency + gain (float or torch.Tensor): desired gain at the boost (or attenuation) in dB + Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``) + + Returns: + Tensor: Waveform of dimension of `(..., time)` + """ + dtype = waveform.dtype + device = waveform.device + center_freq = torch.as_tensor(center_freq, dtype=dtype, device=device) + Q = torch.as_tensor(Q, dtype=dtype, device=device) + gain = torch.as_tensor(gain, dtype=dtype, device=device) + + w0 = 2 * math.pi * center_freq / sample_rate + A = torch.exp(gain / 40.0 * math.log(10)) + alpha = torch.sin(w0) / 2 / Q + + b0 = 1 + alpha * A + b1 = -2 * torch.cos(w0) + b2 = 1 - alpha * A + a0 = 1 + alpha / A + a1 = -2 * torch.cos(w0) + a2 = 1 - alpha / A + return biquad(waveform, b0, b1, b2, a0, a1, a2) + + +def filtfilt( + waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, +) -> Tensor: + r"""Apply an IIR filter forward and backward to a waveform. + + Inspired by https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.filtfilt.html + + Args: + waveform (Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1. + a_coeffs (Tensor): denominator coefficients of difference equation of dimension of either + 1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`. + Lower delay coefficients are first, e.g. ``[a0, a1, a2, ...]``. + Must be same size as b_coeffs (pad with 0's as necessary). + b_coeffs (Tensor): numerator coefficients of difference equation of dimension of either + 1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`. + Lower delay coefficients are first, e.g. ``[b0, b1, b2, ...]``. + Must be same size as a_coeffs (pad with 0's as necessary). + clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``) + + Returns: + Tensor: Waveform with dimension of either `(..., num_filters, time)` if ``a_coeffs`` and ``b_coeffs`` + are 2D Tensors, or `(..., time)` otherwise. + """ + forward_filtered = lfilter(waveform, a_coeffs, b_coeffs, clamp=False, batching=True) + backward_filtered = lfilter( + forward_filtered.flip(-1), a_coeffs, b_coeffs, clamp=clamp, batching=True, + ).flip(-1) + return backward_filtered + + +def flanger( + waveform: Tensor, + sample_rate: int, + delay: float = 0.0, + depth: float = 2.0, + regen: float = 0.0, + width: float = 71.0, + speed: float = 0.5, + phase: float = 25.0, + modulation: str = "sinusoidal", + interpolation: str = "linear", +) -> Tensor: + r"""Apply a flanger effect to the audio. Similar to SoX implementation. + + Args: + waveform (Tensor): audio waveform of dimension of `(..., channel, time)` . + Max 4 channels allowed + sample_rate (int): sampling rate of the waveform, e.g. 
44100 (Hz) + delay (float, optional): desired delay in milliseconds(ms) + Allowed range of values are 0 to 30 + depth (float, optional): desired delay depth in milliseconds(ms) + Allowed range of values are 0 to 10 + regen (float, optional): desired regen(feedback gain) in dB + Allowed range of values are -95 to 95 + width (float, optional): desired width(delay gain) in dB + Allowed range of values are 0 to 100 + speed (float, optional): modulation speed in Hz + Allowed range of values are 0.1 to 10 + phase (float, optional): percentage phase-shift for multi-channel + Allowed range of values are 0 to 100 + modulation (str, optional): Use either "sinusoidal" or "triangular" modulation. (Default: ``sinusoidal``) + interpolation (str, optional): Use either "linear" or "quadratic" for delay-line interpolation. + (Default: ``linear``) + + Returns: + Tensor: Waveform of dimension of `(..., channel, time)` + + Reference: + - http://sox.sourceforge.net/sox.html + + - Scott Lehman, `Effects Explained`_, + + .. _Effects Explained: + https://web.archive.org/web/20051125072557/http://www.harmony-central.com/Effects/effects-explained.html + """ + + if modulation not in ("sinusoidal", "triangular"): + raise ValueError("Only 'sinusoidal' or 'triangular' modulation allowed") + + if interpolation not in ("linear", "quadratic"): + raise ValueError("Only 'linear' or 'quadratic' interpolation allowed") + + actual_shape = waveform.shape + device, dtype = waveform.device, waveform.dtype + + if actual_shape[-2] > 4: + raise ValueError("Max 4 channels allowed") + + # convert to 3D (batch, channels, time) + waveform = waveform.view(-1, actual_shape[-2], actual_shape[-1]) + + # Scaling + feedback_gain = regen / 100 + delay_gain = width / 100 + channel_phase = phase / 100 + delay_min = delay / 1000 + delay_depth = depth / 1000 + + n_channels = waveform.shape[-2] + + if modulation == "sinusoidal": + wave_type = "SINE" + else: + wave_type = "TRIANGLE" + + # Balance output: + in_gain = 1.0 / (1 + delay_gain) + delay_gain = delay_gain / (1 + delay_gain) + + # Balance feedback loop: + delay_gain = delay_gain * (1 - abs(feedback_gain)) + + delay_buf_length = int((delay_min + delay_depth) * sample_rate + 0.5) + delay_buf_length = delay_buf_length + 2 + + delay_bufs = torch.zeros( + waveform.shape[0], n_channels, delay_buf_length, dtype=dtype, device=device + ) + delay_last = torch.zeros(waveform.shape[0], n_channels, dtype=dtype, device=device) + + lfo_length = int(sample_rate / speed) + + table_min = math.floor(delay_min * sample_rate + 0.5) + table_max = delay_buf_length - 2.0 + + lfo = _generate_wave_table( + wave_type=wave_type, + data_type="FLOAT", + table_size=lfo_length, + min=float(table_min), + max=float(table_max), + phase=3 * math.pi / 2, + device=device, + ) + + output_waveform = torch.zeros_like(waveform, dtype=dtype, device=device) + + delay_buf_pos = 0 + lfo_pos = 0 + channel_idxs = torch.arange(0, n_channels, device=device) + + for i in range(waveform.shape[-1]): + + delay_buf_pos = (delay_buf_pos + delay_buf_length - 1) % delay_buf_length + + cur_channel_phase = (channel_idxs * lfo_length * channel_phase + 0.5).to( + torch.int64 + ) + delay_tensor = lfo[(lfo_pos + cur_channel_phase) % lfo_length] + frac_delay = torch.frac(delay_tensor) + delay_tensor = torch.floor(delay_tensor) + + int_delay = delay_tensor.to(torch.int64) + + temp = waveform[:, :, i] + + delay_bufs[:, :, delay_buf_pos] = temp + delay_last * feedback_gain + + delayed_0 = delay_bufs[ + :, channel_idxs, (delay_buf_pos + int_delay) % 
delay_buf_length + ] + + int_delay = int_delay + 1 + + delayed_1 = delay_bufs[ + :, channel_idxs, (delay_buf_pos + int_delay) % delay_buf_length + ] + + int_delay = int_delay + 1 + + if interpolation == "linear": + delayed = delayed_0 + (delayed_1 - delayed_0) * frac_delay + else: + delayed_2 = delay_bufs[ + :, channel_idxs, (delay_buf_pos + int_delay) % delay_buf_length + ] + + int_delay = int_delay + 1 + + delayed_2 = delayed_2 - delayed_0 + delayed_1 = delayed_1 - delayed_0 + a = delayed_2 * 0.5 - delayed_1 + b = delayed_1 * 2 - delayed_2 * 0.5 + + delayed = delayed_0 + (a * frac_delay + b) * frac_delay + + delay_last = delayed + output_waveform[:, :, i] = waveform[:, :, i] * in_gain + delayed * delay_gain + + lfo_pos = (lfo_pos + 1) % lfo_length + + return output_waveform.clamp(min=-1, max=1).view(actual_shape) + + +def gain(waveform: Tensor, gain_db: float = 1.0) -> Tensor: + r"""Apply amplification or attenuation to the whole waveform. + + Args: + waveform (Tensor): Tensor of audio of dimension (..., time). + gain_db (float, optional) Gain adjustment in decibels (dB) (Default: ``1.0``). + + Returns: + Tensor: the whole waveform amplified by gain_db. + """ + if gain_db == 0: + return waveform + + ratio = 10 ** (gain_db / 20) + + return waveform * ratio + + +def highpass_biquad( + waveform: Tensor, sample_rate: int, cutoff_freq: float, Q: float = 0.707 +) -> Tensor: + r"""Design biquad highpass filter and perform filtering. Similar to SoX implementation. + + Args: + waveform (Tensor): audio waveform of dimension of `(..., time)` + sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) + cutoff_freq (float or torch.Tensor): filter cutoff frequency + Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``) + + Returns: + Tensor: Waveform dimension of `(..., time)` + """ + dtype = waveform.dtype + device = waveform.device + cutoff_freq = torch.as_tensor(cutoff_freq, dtype=dtype, device=device) + Q = torch.as_tensor(Q, dtype=dtype, device=device) + + w0 = 2 * math.pi * cutoff_freq / sample_rate + alpha = torch.sin(w0) / 2.0 / Q + + b0 = (1 + torch.cos(w0)) / 2 + b1 = -1 - torch.cos(w0) + b2 = b0 + a0 = 1 + alpha + a1 = -2 * torch.cos(w0) + a2 = 1 - alpha + return biquad(waveform, b0, b1, b2, a0, a1, a2) + + +def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: Tensor, padded_output_waveform: Tensor): + n_order = a_coeffs_flipped.size(1) + a_coeffs_flipped = a_coeffs_flipped.unsqueeze(2) + for i_sample, o0 in enumerate(input_signal_windows.permute(2, 0, 1)): + windowed_output_signal = padded_output_waveform[ + :, :, i_sample:i_sample + n_order + ] + o0 -= (windowed_output_signal.transpose(0, 1) @ a_coeffs_flipped)[..., 0].t() + padded_output_waveform[:, :, i_sample + n_order - 1] = o0 + + +try: + _lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop +except RuntimeError as err: + assert str(err) == 'No such operator torchaudio::_lfilter_core_loop' + _lfilter_core_cpu_loop = _lfilter_core_generic_loop + + +def _lfilter_core( + waveform: Tensor, + a_coeffs: Tensor, + b_coeffs: Tensor, +) -> Tensor: + + assert a_coeffs.size() == b_coeffs.size() + assert len(waveform.size()) == 3 + assert waveform.device == a_coeffs.device + assert b_coeffs.device == a_coeffs.device + + n_batch, n_channel, n_sample = waveform.size() + n_order = a_coeffs.size(1) + assert n_order > 0 + + # Pad the input and create output + + padded_waveform = torch.nn.functional.pad(waveform, [n_order - 1, 0]) + padded_output_waveform = 
torch.zeros_like(padded_waveform) + + # Set up the coefficients matrix + # Flip coefficients' order + a_coeffs_flipped = a_coeffs.flip(1) + b_coeffs_flipped = b_coeffs.flip(1) + + # calculate windowed_input_signal in parallel using convolution + input_signal_windows = torch.nn.functional.conv1d( + padded_waveform, + b_coeffs_flipped.unsqueeze(1), + groups=n_channel + ) + + input_signal_windows.div_(a_coeffs[:, :1]) + a_coeffs_flipped.div_(a_coeffs[:, :1]) + + if input_signal_windows.device == torch.device('cpu') and\ + a_coeffs_flipped.device == torch.device('cpu') and\ + padded_output_waveform.device == torch.device('cpu'): + _lfilter_core_cpu_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform) + else: + _lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform) + + output = padded_output_waveform[:, :, n_order - 1:] + return output + + +try: + _lfilter = torch.ops.torchaudio._lfilter +except RuntimeError as err: + assert str(err) == 'No such operator torchaudio::_lfilter' + _lfilter = _lfilter_core + + +def lfilter( + waveform: Tensor, + a_coeffs: Tensor, + b_coeffs: Tensor, + clamp: bool = True, + batching: bool = True +) -> Tensor: + r"""Perform an IIR filter by evaluating difference equation. + + Note: + To avoid numerical problems, small filter order is preferred. + Using double precision could also minimize numerical precision errors. + + Args: + waveform (Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1. + a_coeffs (Tensor): denominator coefficients of difference equation of dimension of either + 1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`. + Lower delays coefficients are first, e.g. ``[a0, a1, a2, ...]``. + Must be same size as b_coeffs (pad with 0's as necessary). + b_coeffs (Tensor): numerator coefficients of difference equation of dimension of either + 1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`. + Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``. + Must be same size as a_coeffs (pad with 0's as necessary). + clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``) + batching (bool, optional): Effective only when coefficients are 2D. If ``True``, then waveform should be at + least 2D, and the size of second axis from last should equals to ``num_filters``. + The output can be expressed as ``output[..., i, :] = lfilter(waveform[..., i, :], + a_coeffs[i], b_coeffs[i], clamp=clamp, batching=False)``. (Default: ``True``) + + Returns: + Tensor: Waveform with dimension of either `(..., num_filters, time)` if ``a_coeffs`` and ``b_coeffs`` + are 2D Tensors, or `(..., time)` otherwise. 
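+
+    Example (hypothetical single-pole low-pass, i.e. ``y[n] = 0.05 * x[n] + 0.95 * y[n-1]``):
+        >>> import torch
+        >>> from torchaudio.functional import lfilter
+        >>> x = torch.rand(1, 100) * 2 - 1
+        >>> a_coeffs = torch.tensor([1.0, -0.95])
+        >>> b_coeffs = torch.tensor([0.05, 0.0])
+        >>> y = lfilter(x, a_coeffs, b_coeffs)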
+ """ + assert a_coeffs.size() == b_coeffs.size() + assert a_coeffs.ndim <= 2 + + if a_coeffs.ndim > 1: + if batching: + assert waveform.ndim > 1 + assert waveform.shape[-2] == a_coeffs.shape[0] + else: + waveform = torch.stack([waveform] * a_coeffs.shape[0], -2) + else: + a_coeffs = a_coeffs.unsqueeze(0) + b_coeffs = b_coeffs.unsqueeze(0) + + # pack batch + shape = waveform.size() + waveform = waveform.reshape(-1, a_coeffs.shape[0], shape[-1]) + output = _lfilter(waveform, a_coeffs, b_coeffs) + + if clamp: + output = torch.clamp(output, min=-1.0, max=1.0) + + # unpack batch + output = output.reshape(shape[:-1] + output.shape[-1:]) + + return output + + +def lowpass_biquad( + waveform: Tensor, sample_rate: int, cutoff_freq: float, Q: float = 0.707 +) -> Tensor: + r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation. + + Args: + waveform (torch.Tensor): audio waveform of dimension of `(..., time)` + sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) + cutoff_freq (float or torch.Tensor): filter cutoff frequency + Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``) + + Returns: + Tensor: Waveform of dimension of `(..., time)` + """ + dtype = waveform.dtype + device = waveform.device + cutoff_freq = torch.as_tensor(cutoff_freq, dtype=dtype, device=device) + Q = torch.as_tensor(Q, dtype=dtype, device=device) + + w0 = 2 * math.pi * cutoff_freq / sample_rate + alpha = torch.sin(w0) / 2 / Q + + b0 = (1 - torch.cos(w0)) / 2 + b1 = 1 - torch.cos(w0) + b2 = b0 + a0 = 1 + alpha + a1 = -2 * torch.cos(w0) + a2 = 1 - alpha + return biquad(waveform, b0, b1, b2, a0, a1, a2) + + +def _overdrive_core_loop_generic( + waveform: Tensor, + temp: Tensor, + last_in: Tensor, + last_out: Tensor, + output_waveform: Tensor +): + for i in range(waveform.shape[-1]): + last_out = temp[:, i] - last_in + 0.995 * last_out + last_in = temp[:, i] + output_waveform[:, i] = waveform[:, i] * 0.5 + last_out * 0.75 + + +try: + _overdrive_core_loop_cpu = torch.ops.torchaudio._overdrive_core_loop +except RuntimeError as err: + assert str(err) == 'No such operator torchaudio::_overdrive_core_loop' + _overdrive_core_loop_cpu = _overdrive_core_loop_generic + + +def overdrive(waveform: Tensor, gain: float = 20, colour: float = 20) -> Tensor: + r"""Apply a overdrive effect to the audio. Similar to SoX implementation. + This effect applies a non linear distortion to the audio signal. 
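+
+    Example (hypothetical; ``gain`` and ``colour`` are given arbitrary mid-range values):
+        >>> import torch
+        >>> from torchaudio.functional import overdrive
+        >>> waveform = torch.rand(1, 16000) * 2 - 1
+        >>> driven = overdrive(waveform, gain=30.0, colour=40.0)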
+ + Args: + waveform (Tensor): audio waveform of dimension of `(..., time)` + gain (float, optional): desired gain at the boost (or attenuation) in dB + Allowed range of values are 0 to 100 + colour (float, optional): controls the amount of even harmonic content in the over-driven output + Allowed range of values are 0 to 100 + + Returns: + Tensor: Waveform of dimension of `(..., time)` + + Reference: + - http://sox.sourceforge.net/sox.html + """ + actual_shape = waveform.shape + device, dtype = waveform.device, waveform.dtype + + # convert to 2D (..,time) + waveform = waveform.view(-1, actual_shape[-1]) + + gain = _dB2Linear(gain) + colour = colour / 200 + last_in = torch.zeros(waveform.shape[:-1], dtype=dtype, device=device) + last_out = torch.zeros(waveform.shape[:-1], dtype=dtype, device=device) + + temp = waveform * gain + colour + + mask1 = temp < -1 + temp[mask1] = torch.tensor(-2.0 / 3.0, dtype=dtype, device=device) + # Wrapping the constant with Tensor is required for Torchscript + + mask2 = temp > 1 + temp[mask2] = torch.tensor(2.0 / 3.0, dtype=dtype, device=device) + + mask3 = ~mask1 & ~mask2 + temp[mask3] = temp[mask3] - (temp[mask3] ** 3) * (1.0 / 3) + + output_waveform = torch.zeros_like(waveform, dtype=dtype, device=device) + + # Uses CPU optimized loop function if available for CPU device + if device == torch.device('cpu'): + _overdrive_core_loop_cpu(waveform, temp, last_in, last_out, output_waveform) + else: + _overdrive_core_loop_generic(waveform, temp, last_in, last_out, output_waveform) + + return output_waveform.clamp(min=-1, max=1).view(actual_shape) + + +def phaser( + waveform: Tensor, + sample_rate: int, + gain_in: float = 0.4, + gain_out: float = 0.74, + delay_ms: float = 3.0, + decay: float = 0.4, + mod_speed: float = 0.5, + sinusoidal: bool = True, +) -> Tensor: + r"""Apply a phasing effect to the audio. Similar to SoX implementation. + + Args: + waveform (Tensor): audio waveform of dimension of `(..., time)` + sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) + gain_in (float, optional): desired input gain at the boost (or attenuation) in dB + Allowed range of values are 0 to 1 + gain_out (float, optional): desired output gain at the boost (or attenuation) in dB + Allowed range of values are 0 to 1e9 + delay_ms (float, optional): desired delay in milliseconds + Allowed range of values are 0 to 5.0 + decay (float, optional): desired decay relative to gain-in + Allowed range of values are 0 to 0.99 + mod_speed (float, optional): modulation speed in Hz + Allowed range of values are 0.1 to 2 + sinusoidal (bool, optional): If ``True``, uses sinusoidal modulation (preferable for multiple instruments) + If ``False``, uses triangular modulation (gives single instruments a sharper phasing effect) + (Default: ``True``) + + Returns: + Tensor: Waveform of dimension of `(..., time)` + + Reference: + - http://sox.sourceforge.net/sox.html + - Scott Lehman, `Effects Explained`_. + + .. 
_Effects Explained: + https://web.archive.org/web/20051125072557/http://www.harmony-central.com/Effects/effects-explained.html + """ + actual_shape = waveform.shape + device, dtype = waveform.device, waveform.dtype + + # convert to 2D (channels,time) + waveform = waveform.view(-1, actual_shape[-1]) + + delay_buf_len = int((delay_ms * 0.001 * sample_rate) + 0.5) + delay_buf = torch.zeros( + waveform.shape[0], delay_buf_len, dtype=dtype, device=device + ) + + mod_buf_len = int(sample_rate / mod_speed + 0.5) + + if sinusoidal: + wave_type = "SINE" + else: + wave_type = "TRIANGLE" + + mod_buf = _generate_wave_table( + wave_type=wave_type, + data_type="INT", + table_size=mod_buf_len, + min=1.0, + max=float(delay_buf_len), + phase=math.pi / 2, + device=device, + ) + + delay_pos = 0 + mod_pos = 0 + + output_waveform_pre_gain_list = [] + waveform = waveform * gain_in + delay_buf = delay_buf * decay + waveform_list = [waveform[:, i] for i in range(waveform.size(1))] + delay_buf_list = [delay_buf[:, i] for i in range(delay_buf.size(1))] + mod_buf_list = [mod_buf[i] for i in range(mod_buf.size(0))] + + for i in range(waveform.shape[-1]): + idx = int((delay_pos + mod_buf_list[mod_pos]) % delay_buf_len) + mod_pos = (mod_pos + 1) % mod_buf_len + delay_pos = (delay_pos + 1) % delay_buf_len + temp = (waveform_list[i]) + (delay_buf_list[idx]) + delay_buf_list[delay_pos] = temp * decay + output_waveform_pre_gain_list.append(temp) + + output_waveform = torch.stack(output_waveform_pre_gain_list, dim=1).to( + dtype=dtype, device=device + ) + output_waveform.mul_(gain_out) + + return output_waveform.clamp(min=-1, max=1).view(actual_shape) + + +def riaa_biquad(waveform: Tensor, sample_rate: int) -> Tensor: + r"""Apply RIAA vinyl playback equalization. Similar to SoX implementation. + + Args: + waveform (Tensor): audio waveform of dimension of `(..., time)` + sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz). + Allowed sample rates in Hz : ``44100``,``48000``,``88200``,``96000`` + + Returns: + Tensor: Waveform of dimension of `(..., time)` + + Reference: + - http://sox.sourceforge.net/sox.html + - https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF + """ + + if sample_rate == 44100: + zeros = [-0.2014898, 0.9233820] + poles = [0.7083149, 0.9924091] + + elif sample_rate == 48000: + zeros = [-0.1766069, 0.9321590] + poles = [0.7396325, 0.9931330] + + elif sample_rate == 88200: + zeros = [-0.1168735, 0.9648312] + poles = [0.8590646, 0.9964002] + + elif sample_rate == 96000: + zeros = [-0.1141486, 0.9676817] + poles = [0.8699137, 0.9966946] + + else: + raise ValueError("Sample rate must be 44.1k, 48k, 88.2k, or 96k") + + # polynomial coefficients with roots zeros[0] and zeros[1] + b0 = 1.0 + b1 = -(zeros[0] + zeros[1]) + b2 = zeros[0] * zeros[1] + + # polynomial coefficients with roots poles[0] and poles[1] + a0 = 1.0 + a1 = -(poles[0] + poles[1]) + a2 = poles[0] * poles[1] + + # Normalize to 0dB at 1kHz + y = 2 * math.pi * 1000 / sample_rate + b_re = b0 + b1 * math.cos(-y) + b2 * math.cos(-2 * y) + a_re = a0 + a1 * math.cos(-y) + a2 * math.cos(-2 * y) + b_im = b1 * math.sin(-y) + b2 * math.sin(-2 * y) + a_im = a1 * math.sin(-y) + a2 * math.sin(-2 * y) + g = 1 / math.sqrt((b_re ** 2 + b_im ** 2) / (a_re ** 2 + a_im ** 2)) + + b0 *= g + b1 *= g + b2 *= g + + return biquad(waveform, b0, b1, b2, a0, a1, a2) + + +def treble_biquad( + waveform: Tensor, + sample_rate: int, + gain: float, + central_freq: float = 3000, + Q: float = 0.707, +) -> Tensor: + r"""Design a treble tone-control effect. 
Similar to SoX implementation. + + Args: + waveform (Tensor): audio waveform of dimension of `(..., time)` + sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) + gain (float or torch.Tensor): desired gain at the boost (or attenuation) in dB. + central_freq (float or torch.Tensor, optional): central frequency (in Hz). (Default: ``3000``) + Q (float or torch.Tensor, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``). + + Returns: + Tensor: Waveform of dimension of `(..., time)` + + Reference: + - http://sox.sourceforge.net/sox.html + - https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF + """ + dtype = waveform.dtype + device = waveform.device + central_freq = torch.as_tensor(central_freq, dtype=dtype, device=device) + Q = torch.as_tensor(Q, dtype=dtype, device=device) + gain = torch.as_tensor(gain, dtype=dtype, device=device) + + w0 = 2 * math.pi * central_freq / sample_rate + alpha = torch.sin(w0) / 2 / Q + A = torch.exp(gain / 40 * math.log(10)) + + temp1 = 2 * torch.sqrt(A) * alpha + temp2 = (A - 1) * torch.cos(w0) + temp3 = (A + 1) * torch.cos(w0) + + b0 = A * ((A + 1) + temp2 + temp1) + b1 = -2 * A * ((A - 1) + temp3) + b2 = A * ((A + 1) + temp2 - temp1) + a0 = (A + 1) - temp2 + temp1 + a1 = 2 * ((A - 1) - temp3) + a2 = (A + 1) - temp2 - temp1 + + return biquad(waveform, b0, b1, b2, a0, a1, a2) + + +def _measure( + measure_len_ws: int, + samples: Tensor, + spectrum: Tensor, + noise_spectrum: Tensor, + spectrum_window: Tensor, + spectrum_start: int, + spectrum_end: int, + cepstrum_window: Tensor, + cepstrum_start: int, + cepstrum_end: int, + noise_reduction_amount: float, + measure_smooth_time_mult: float, + noise_up_time_mult: float, + noise_down_time_mult: float, + index_ns: int, + boot_count: int, +) -> float: + + assert spectrum.size()[-1] == noise_spectrum.size()[-1] + + samplesLen_ns = samples.size()[-1] + dft_len_ws = spectrum.size()[-1] + + dftBuf = torch.zeros(dft_len_ws) + + _index_ns = torch.tensor( + [index_ns] + [(index_ns + i) % samplesLen_ns for i in range(1, measure_len_ws)] + ) + dftBuf[:measure_len_ws] = samples[_index_ns] * spectrum_window[:measure_len_ws] + + # memset(c->dftBuf + i, 0, (p->dft_len_ws - i) * sizeof(*c->dftBuf)); + dftBuf[measure_len_ws:dft_len_ws].zero_() + + # lsx_safe_rdft((int)p->dft_len_ws, 1, c->dftBuf); + _dftBuf = torch.fft.rfft(dftBuf) + + # memset(c->dftBuf, 0, p->spectrum_start * sizeof(*c->dftBuf)); + _dftBuf[:spectrum_start].zero_() + + mult: float = ( + boot_count / (1.0 + boot_count) if boot_count >= 0 else measure_smooth_time_mult + ) + + _d = _dftBuf[spectrum_start:spectrum_end].abs() + spectrum[spectrum_start:spectrum_end].mul_(mult).add_(_d * (1 - mult)) + _d = spectrum[spectrum_start:spectrum_end] ** 2 + + _zeros = torch.zeros(spectrum_end - spectrum_start) + _mult = ( + _zeros + if boot_count >= 0 + else torch.where( + _d > noise_spectrum[spectrum_start:spectrum_end], + torch.tensor(noise_up_time_mult), # if + torch.tensor(noise_down_time_mult), # else + ) + ) + + noise_spectrum[spectrum_start:spectrum_end].mul_(_mult).add_(_d * (1 - _mult)) + _d = torch.sqrt( + torch.max( + _zeros, + _d - noise_reduction_amount * noise_spectrum[spectrum_start:spectrum_end], + ) + ) + + _cepstrum_Buf: Tensor = torch.zeros(dft_len_ws >> 1) + _cepstrum_Buf[spectrum_start:spectrum_end] = _d * cepstrum_window + _cepstrum_Buf[spectrum_end:dft_len_ws >> 1].zero_() + + # lsx_safe_rdft((int)p->dft_len_ws >> 1, 1, c->dftBuf); + _cepstrum_Buf = torch.fft.rfft(_cepstrum_Buf) + + result: float = float( + 
torch.sum(_cepstrum_Buf[cepstrum_start:cepstrum_end].abs().pow(2))
+    )
+    result = (
+        math.log(result / (cepstrum_end - cepstrum_start)) if result > 0 else -math.inf
+    )
+    return max(0, 21 + result)
+
+
+def vad(
+    waveform: Tensor,
+    sample_rate: int,
+    trigger_level: float = 7.0,
+    trigger_time: float = 0.25,
+    search_time: float = 1.0,
+    allowed_gap: float = 0.25,
+    pre_trigger_time: float = 0.0,
+    # Fine-tuning parameters
+    boot_time: float = 0.35,
+    noise_up_time: float = 0.1,
+    noise_down_time: float = 0.01,
+    noise_reduction_amount: float = 1.35,
+    measure_freq: float = 20.0,
+    measure_duration: Optional[float] = None,
+    measure_smooth_time: float = 0.4,
+    hp_filter_freq: float = 50.0,
+    lp_filter_freq: float = 6000.0,
+    hp_lifter_freq: float = 150.0,
+    lp_lifter_freq: float = 2000.0,
+) -> Tensor:
+    r"""Voice Activity Detector. Similar to SoX implementation.
+    Attempts to trim silence and quiet background sounds from the ends of recordings of speech.
+    The algorithm currently uses a simple cepstral power measurement to detect voice,
+    so may be fooled by other things, especially music.
+
+    The effect can trim only from the front of the audio,
+    so in order to trim from the back, the reverse effect must also be used.
+
+    Args:
+        waveform (Tensor): Tensor of audio of dimension `(channels, time)` or `(time)`
+            Tensor of shape `(channels, time)` is treated as a multi-channel recording
+            of the same event and the resulting output will be trimmed to the earliest
+            voice activity in any channel.
+        sample_rate (int): Sample rate of audio signal.
+        trigger_level (float, optional): The measurement level used to trigger activity detection.
+            This may need to be changed depending on the noise level, signal level,
+            and other characteristics of the input audio. (Default: 7.0)
+        trigger_time (float, optional): The time constant (in seconds)
+            used to help ignore short bursts of sound. (Default: 0.25)
+        search_time (float, optional): The amount of audio (in seconds)
+            to search for quieter/shorter bursts of audio to include prior
+            to the detected trigger point. (Default: 1.0)
+        allowed_gap (float, optional): The allowed gap (in seconds) between
+            quieter/shorter bursts of audio to include prior
+            to the detected trigger point. (Default: 0.25)
+        pre_trigger_time (float, optional): The amount of audio (in seconds) to preserve
+            before the trigger point and any found quieter/shorter bursts. (Default: 0.0)
+        boot_time (float, optional): The algorithm (internally) uses adaptive noise
+            estimation/reduction in order to detect the start of the wanted audio.
+            This option sets the time for the initial noise estimate. (Default: 0.35)
+        noise_up_time (float, optional): Time constant used by the adaptive noise estimator
+            for when the noise level is increasing. (Default: 0.1)
+        noise_down_time (float, optional): Time constant used by the adaptive noise estimator
+            for when the noise level is decreasing. (Default: 0.01)
+        noise_reduction_amount (float, optional): Amount of noise reduction to use in
+            the detection algorithm (e.g. 0, 0.5, ...). (Default: 1.35)
+        measure_freq (float, optional): Frequency of the algorithm’s
+            processing/measurements. (Default: 20.0)
+        measure_duration (float or None, optional): Measurement duration.
+            (Default: Twice the measurement period; i.e. with overlap.)
+        measure_smooth_time (float, optional): Time constant used to smooth
+            spectral measurements. (Default: 0.4)
+        hp_filter_freq (float, optional): "Brick-wall" frequency of high-pass filter applied
+            at the input to the detector algorithm. (Default: 50.0)
+        lp_filter_freq (float, optional): "Brick-wall" frequency of low-pass filter applied
+            at the input to the detector algorithm. (Default: 6000.0)
+        hp_lifter_freq (float, optional): "Brick-wall" frequency of high-pass lifter used
+            in the detector algorithm. (Default: 150.0)
+        lp_lifter_freq (float, optional): "Brick-wall" frequency of low-pass lifter used
+            in the detector algorithm. (Default: 2000.0)
+
+    Returns:
+        Tensor: Tensor of audio of dimension `(..., time)`.
+
+    Reference:
+        - http://sox.sourceforge.net/sox.html
+    """
+
+    if waveform.ndim > 2:
+        warnings.warn(
+            "Expected input tensor dimension of 1 for single channel"
+            f" or 2 for multi-channel. Got {waveform.ndim} instead. "
+            "Batch semantics is not supported. "
+            "Please refer to https://github.com/pytorch/audio/issues/1348"
+            " and https://github.com/pytorch/audio/issues/1468."
+        )
+
+    measure_duration: float = (
+        2.0 / measure_freq if measure_duration is None else measure_duration
+    )
+
+    measure_len_ws = int(sample_rate * measure_duration + 0.5)
+    measure_len_ns = measure_len_ws
+    # for (dft_len_ws = 16; dft_len_ws < measure_len_ws; dft_len_ws <<= 1);
+    dft_len_ws = 16
+    while dft_len_ws < measure_len_ws:
+        dft_len_ws *= 2
+
+    measure_period_ns = int(sample_rate / measure_freq + 0.5)
+    measures_len = math.ceil(search_time * measure_freq)
+    search_pre_trigger_len_ns = measures_len * measure_period_ns
+    gap_len = int(allowed_gap * measure_freq + 0.5)
+
+    fixed_pre_trigger_len_ns = int(pre_trigger_time * sample_rate + 0.5)
+    samplesLen_ns = (
+        fixed_pre_trigger_len_ns + search_pre_trigger_len_ns + measure_len_ns
+    )
+
+    spectrum_window = torch.zeros(measure_len_ws)
+    for i in range(measure_len_ws):
+        # sox.h:741 define SOX_SAMPLE_MIN (sox_sample_t)SOX_INT_MIN(32)
+        spectrum_window[i] = 2.0 / math.sqrt(float(measure_len_ws))
+    # lsx_apply_hann(spectrum_window, (int)measure_len_ws);
+    spectrum_window *= torch.hann_window(measure_len_ws, dtype=torch.float)
+
+    spectrum_start: int = int(hp_filter_freq / sample_rate * dft_len_ws + 0.5)
+    spectrum_start: int = max(spectrum_start, 1)
+    spectrum_end: int = int(lp_filter_freq / sample_rate * dft_len_ws + 0.5)
+    spectrum_end: int = min(spectrum_end, dft_len_ws // 2)
+
+    cepstrum_window = torch.zeros(spectrum_end - spectrum_start)
+    for i in range(spectrum_end - spectrum_start):
+        cepstrum_window[i] = 2.0 / math.sqrt(float(spectrum_end) - spectrum_start)
+    # lsx_apply_hann(cepstrum_window,(int)(spectrum_end - spectrum_start));
+    cepstrum_window *= torch.hann_window(
+        spectrum_end - spectrum_start, dtype=torch.float
+    )
+
+    cepstrum_start = math.ceil(sample_rate * 0.5 / lp_lifter_freq)
+    cepstrum_end = math.floor(sample_rate * 0.5 / hp_lifter_freq)
+    cepstrum_end = min(cepstrum_end, dft_len_ws // 4)
+
+    assert cepstrum_end > cepstrum_start
+
+    noise_up_time_mult = math.exp(-1.0 / (noise_up_time * measure_freq))
+    noise_down_time_mult = math.exp(-1.0 / (noise_down_time * measure_freq))
+    measure_smooth_time_mult = math.exp(-1.0 / (measure_smooth_time * measure_freq))
+    trigger_meas_time_mult = math.exp(-1.0 / (trigger_time * measure_freq))
+
+    boot_count_max = int(boot_time * measure_freq - 0.5)
+    measure_timer_ns = measure_len_ns
+    boot_count = measures_index = flushedLen_ns = samplesIndex_ns = 0
+
+    # pack batch
+    shape = waveform.size()
+    waveform = waveform.view(-1, shape[-1])
+
+    n_channels, ilen = waveform.size()
+
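+    # Per-channel detector state (a port of the corresponding SoX buffers): `samples` is a ring
+    # buffer of the most recent `samplesLen_ns` input samples, `spectrum` / `noise_spectrum` hold
+    # the running signal and noise spectral estimates, and `measures` keeps the last
+    # `measures_len` cepstral power measurements.
+    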
mean_meas = torch.zeros(n_channels) + samples = torch.zeros(n_channels, samplesLen_ns) + spectrum = torch.zeros(n_channels, dft_len_ws) + noise_spectrum = torch.zeros(n_channels, dft_len_ws) + measures = torch.zeros(n_channels, measures_len) + + has_triggered: bool = False + num_measures_to_flush: int = 0 + pos: int = 0 + + while pos < ilen and not has_triggered: + measure_timer_ns -= 1 + for i in range(n_channels): + samples[i, samplesIndex_ns] = waveform[i, pos] + # if (!p->measure_timer_ns) { + if measure_timer_ns == 0: + index_ns: int = ( + samplesIndex_ns + samplesLen_ns - measure_len_ns + ) % samplesLen_ns + meas: float = _measure( + measure_len_ws=measure_len_ws, + samples=samples[i], + spectrum=spectrum[i], + noise_spectrum=noise_spectrum[i], + spectrum_window=spectrum_window, + spectrum_start=spectrum_start, + spectrum_end=spectrum_end, + cepstrum_window=cepstrum_window, + cepstrum_start=cepstrum_start, + cepstrum_end=cepstrum_end, + noise_reduction_amount=noise_reduction_amount, + measure_smooth_time_mult=measure_smooth_time_mult, + noise_up_time_mult=noise_up_time_mult, + noise_down_time_mult=noise_down_time_mult, + index_ns=index_ns, + boot_count=boot_count, + ) + measures[i, measures_index] = meas + mean_meas[i] = mean_meas[i] * trigger_meas_time_mult + meas * ( + 1.0 - trigger_meas_time_mult + ) + + has_triggered = has_triggered or (mean_meas[i] >= trigger_level) + if has_triggered: + n: int = measures_len + k: int = measures_index + jTrigger: int = n + jZero: int = n + j: int = 0 + + for j in range(n): + if (measures[i, k] >= trigger_level) and ( + j <= jTrigger + gap_len + ): + jZero = jTrigger = j + elif (measures[i, k] == 0) and (jTrigger >= jZero): + jZero = j + k = (k + n - 1) % n + j = min(j, jZero) + # num_measures_to_flush = range_limit(j, num_measures_to_flush, n); + num_measures_to_flush = min(max(num_measures_to_flush, j), n) + # end if has_triggered + # end if (measure_timer_ns == 0): + # end for + samplesIndex_ns += 1 + pos += 1 + # end while + if samplesIndex_ns == samplesLen_ns: + samplesIndex_ns = 0 + if measure_timer_ns == 0: + measure_timer_ns = measure_period_ns + measures_index += 1 + measures_index = measures_index % measures_len + if boot_count >= 0: + boot_count = -1 if boot_count == boot_count_max else boot_count + 1 + + if has_triggered: + flushedLen_ns = (measures_len - num_measures_to_flush) * measure_period_ns + samplesIndex_ns = (samplesIndex_ns + flushedLen_ns) % samplesLen_ns + + res = waveform[:, pos - samplesLen_ns + flushedLen_ns:] + # unpack batch + return res.view(shape[:-1] + res.shape[-1:]) diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..1444ecef07d6dd54c317a53263d1207af4bb7b0f --- /dev/null +++ b/torchaudio/functional/functional.py @@ -0,0 +1,1812 @@ +# -*- coding: utf-8 -*- + +from collections.abc import Sequence +import io +import math +import warnings +from typing import Optional, Tuple + +import torch +from torch import Tensor +from torchaudio._internal import module_utils as _mod_utils +import torchaudio + +__all__ = [ + "spectrogram", + "inverse_spectrogram", + "griffinlim", + "amplitude_to_DB", + "DB_to_amplitude", + "compute_deltas", + "compute_kaldi_pitch", + "create_fb_matrix", + "melscale_fbanks", + "linear_fbanks", + "create_dct", + "compute_deltas", + "detect_pitch_frequency", + "DB_to_amplitude", + "mu_law_encoding", + "mu_law_decoding", + "complex_norm", + "angle", + "magphase", + "phase_vocoder", + 
'mask_along_axis', + 'mask_along_axis_iid', + 'sliding_window_cmn', + "spectral_centroid", + "apply_codec", + "resample", + "edit_distance", + "pitch_shift", + "rnnt_loss", +] + + +def spectrogram( + waveform: Tensor, + pad: int, + window: Tensor, + n_fft: int, + hop_length: int, + win_length: int, + power: Optional[float], + normalized: bool, + center: bool = True, + pad_mode: str = "reflect", + onesided: bool = True, + return_complex: bool = True, +) -> Tensor: + r"""Create a spectrogram or a batch of spectrograms from a raw audio signal. + The spectrogram can be either magnitude-only or complex. + + Args: + waveform (Tensor): Tensor of audio of dimension `(..., time)` + pad (int): Two sided padding of signal + window (Tensor): Window tensor that is applied/multiplied to each frame/window + n_fft (int): Size of FFT + hop_length (int): Length of hop between STFT windows + win_length (int): Window size + power (float or None): Exponent for the magnitude spectrogram, + (must be > 0) e.g., 1 for energy, 2 for power, etc. + If None, then the complex spectrum is returned instead. + normalized (bool): Whether to normalize by magnitude after stft + center (bool, optional): whether to pad :attr:`waveform` on both sides so + that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. + Default: ``True`` + pad_mode (string, optional): controls the padding method used when + :attr:`center` is ``True``. Default: ``"reflect"`` + onesided (bool, optional): controls whether to return half of results to + avoid redundancy. Default: ``True`` + return_complex (bool, optional): + Indicates whether the resulting complex-valued Tensor should be represented with + native complex dtype, such as `torch.cfloat` and `torch.cdouble`, or real dtype + mimicking complex value with an extra dimension for real and imaginary parts. + (See also ``torch.view_as_real``.) + This argument is only effective when ``power=None``. It is ignored for + cases where ``power`` is a number as in those cases, the returned tensor is + power spectrogram, which is a real-valued tensor. + + Returns: + Tensor: Dimension `(..., freq, time)`, freq is + ``n_fft // 2 + 1`` and ``n_fft`` is the number of + Fourier bins, and time is the number of window hops (n_frame). + """ + if power is None and not return_complex: + warnings.warn( + "The use of pseudo complex type in spectrogram is now deprecated." + "Please migrate to native complex type by providing `return_complex=True`. " + "Please refer to https://github.com/pytorch/audio/issues/1337 " + "for more details about torchaudio's plan to migrate to native complex type." 
+ ) + + if pad > 0: + # TODO add "with torch.no_grad():" back when JIT supports it + waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant") + + # pack batch + shape = waveform.size() + waveform = waveform.reshape(-1, shape[-1]) + + # default values are consistent with librosa.core.spectrum._spectrogram + spec_f = torch.stft( + input=waveform, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=center, + pad_mode=pad_mode, + normalized=False, + onesided=onesided, + return_complex=True, + ) + + # unpack batch + spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-2:]) + + if normalized: + spec_f /= window.pow(2.).sum().sqrt() + if power is not None: + if power == 1.0: + return spec_f.abs() + return spec_f.abs().pow(power) + if not return_complex: + return torch.view_as_real(spec_f) + return spec_f + + +def inverse_spectrogram( + spectrogram: Tensor, + length: Optional[int], + pad: int, + window: Tensor, + n_fft: int, + hop_length: int, + win_length: int, + normalized: bool, + center: bool = True, + pad_mode: str = "reflect", + onesided: bool = True, +) -> Tensor: + r"""Create an inverse spectrogram or a batch of inverse spectrograms from the provided + complex-valued spectrogram. + + Args: + spectrogram (Tensor): Complex tensor of audio of dimension (..., freq, time). + length (int or None): The output length of the waveform. + pad (int): Two sided padding of signal. It is only effective when ``length`` is provided. + window (Tensor): Window tensor that is applied/multiplied to each frame/window + n_fft (int): Size of FFT + hop_length (int): Length of hop between STFT windows + win_length (int): Window size + normalized (bool): Whether the stft output was normalized by magnitude + center (bool, optional): whether the waveform was padded on both sides so + that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. + Default: ``True`` + pad_mode (string, optional): controls the padding method used when + :attr:`center` is ``True``. This parameter is provided for compatibility with the + spectrogram function and is not used. Default: ``"reflect"`` + onesided (bool, optional): controls whether spectrogram was done in onesided mode. + Default: ``True`` + + Returns: + Tensor: Dimension `(..., time)`. Least squares estimation of the original signal. + """ + + if spectrogram.dtype == torch.float32 or spectrogram.dtype == torch.float64: + warnings.warn( + "The use of pseudo complex type in inverse_spectrogram is now deprecated. " + "Please migrate to native complex type by using a complex tensor as input. " + "If the input is generated via spectrogram() function or transform, please use " + "return_complex=True as an argument to that function. " + "Please refer to https://github.com/pytorch/audio/issues/1337 " + "for more details about torchaudio's plan to migrate to native complex type." 
+ ) + spectrogram = torch.view_as_complex(spectrogram) + + if normalized: + spectrogram = spectrogram * window.pow(2.).sum().sqrt() + + # pack batch + shape = spectrogram.size() + spectrogram = spectrogram.reshape(-1, shape[-2], shape[-1]) + + # default values are consistent with librosa.core.spectrum._spectrogram + waveform = torch.istft( + input=spectrogram, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=center, + normalized=False, + onesided=onesided, + length=length + 2 * pad if length is not None else None, + return_complex=False, + ) + + if length is not None and pad > 0: + # remove padding from front and back + waveform = waveform[:, pad:-pad] + + # unpack batch + waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:]) + + return waveform + + +def _get_complex_dtype(real_dtype: torch.dtype): + if real_dtype == torch.double: + return torch.cdouble + if real_dtype == torch.float: + return torch.cfloat + if real_dtype == torch.half: + return torch.complex32 + raise ValueError(f'Unexpected dtype {real_dtype}') + + +def griffinlim( + specgram: Tensor, + window: Tensor, + n_fft: int, + hop_length: int, + win_length: int, + power: float, + n_iter: int, + momentum: float, + length: Optional[int], + rand_init: bool +) -> Tensor: + r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation. + + Implementation ported from + *librosa* [:footcite:`brian_mcfee-proc-scipy-2015`], *A fast Griffin-Lim algorithm* [:footcite:`6701851`] + and *Signal estimation from modified short-time Fourier transform* [:footcite:`1172092`]. + + Args: + specgram (Tensor): A magnitude-only STFT spectrogram of dimension `(..., freq, frames)` + where freq is ``n_fft // 2 + 1``. + window (Tensor): Window tensor that is applied/multiplied to each frame/window + n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins + hop_length (int): Length of hop between STFT windows. ( + Default: ``win_length // 2``) + win_length (int): Window size. (Default: ``n_fft``) + power (float): Exponent for the magnitude spectrogram, + (must be > 0) e.g., 1 for energy, 2 for power, etc. + n_iter (int): Number of iteration for phase recovery process. + momentum (float): The momentum parameter for fast Griffin-Lim. + Setting this to 0 recovers the original Griffin-Lim method. + Values near 1 can lead to faster convergence, but above 1 may not converge. + length (int or None): Array length of the expected output. + rand_init (bool): Initializes phase randomly if True, to zero otherwise. + + Returns: + Tensor: waveform of `(..., time)`, where time equals the ``length`` parameter if given. 
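+
+    Example (an illustrative sketch; the magnitude values below are random and only meant
+    to show the expected shapes and argument order)
+        >>> n_fft = 400
+        >>> window = torch.hann_window(n_fft)
+        >>> # a (channel, freq, frames) magnitude spectrogram with freq == n_fft // 2 + 1
+        >>> specgram = torch.rand(1, n_fft // 2 + 1, 100)
+        >>> waveform = griffinlim(
+        >>>     specgram, window, n_fft, n_fft // 2, n_fft, 2.0, 32, 0.99, None, True)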
+ """ + assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum) + assert momentum >= 0, 'momentum={} < 0'.format(momentum) + + # pack batch + shape = specgram.size() + specgram = specgram.reshape([-1] + list(shape[-2:])) + + specgram = specgram.pow(1 / power) + + # initialize the phase + if rand_init: + angles = torch.rand( + specgram.size(), + dtype=_get_complex_dtype(specgram.dtype), device=specgram.device) + else: + angles = torch.full( + specgram.size(), 1, + dtype=_get_complex_dtype(specgram.dtype), device=specgram.device) + + # And initialize the previous iterate to 0 + tprev = torch.tensor(0., dtype=specgram.dtype, device=specgram.device) + for _ in range(n_iter): + # Invert with our current estimate of the phases + inverse = torch.istft(specgram * angles, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + length=length) + + # Rebuild the spectrogram + rebuilt = torch.stft( + input=inverse, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=True, + pad_mode='reflect', + normalized=False, + onesided=True, + return_complex=True, + ) + + # Update our phase estimates + angles = rebuilt + if momentum: + angles = angles - tprev.mul_(momentum / (1 + momentum)) + angles = angles.div(angles.abs().add(1e-16)) + + # Store the previous iterate + tprev = rebuilt + + # Return the final phase estimates + waveform = torch.istft(specgram * angles, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + length=length) + + # unpack batch + waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:]) + + return waveform + + +def amplitude_to_DB( + x: Tensor, + multiplier: float, + amin: float, + db_multiplier: float, + top_db: Optional[float] = None +) -> Tensor: + r"""Turn a spectrogram from the power/amplitude scale to the decibel scale. + + The output of each tensor in a batch depends on the maximum value of that tensor, + and so may return different values for an audio clip split into snippets vs. a full clip. + + Args: + + x (Tensor): Input spectrogram(s) before being converted to decibel scale. Input should take + the form `(..., freq, time)`. Batched inputs should include a channel dimension and + have the form `(batch, channel, freq, time)`. + multiplier (float): Use 10. for power and 20. for amplitude + amin (float): Number to clamp ``x`` + db_multiplier (float): Log10(max(reference value and amin)) + top_db (float or None, optional): Minimum negative cut-off in decibels. A reasonable number + is 80. (Default: ``None``) + + Returns: + Tensor: Output tensor in decibel scale + """ + x_db = multiplier * torch.log10(torch.clamp(x, min=amin)) + x_db -= multiplier * db_multiplier + + if top_db is not None: + # Expand batch + shape = x_db.size() + packed_channels = shape[-3] if x_db.dim() > 2 else 1 + x_db = x_db.reshape(-1, packed_channels, shape[-2], shape[-1]) + + x_db = torch.max(x_db, (x_db.amax(dim=(-3, -2, -1)) - top_db).view(-1, 1, 1, 1)) + + # Repack batch + x_db = x_db.reshape(shape) + + return x_db + + +def DB_to_amplitude( + x: Tensor, + ref: float, + power: float +) -> Tensor: + r"""Turn a tensor from the decibel scale to the power/amplitude scale. + + Args: + x (Tensor): Input tensor before being converted to power/amplitude scale. + ref (float): Reference which the output will be scaled by. + power (float): If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude. + + Returns: + Tensor: Output tensor in power/amplitude scale. 
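+
+    Example (a small numeric sketch of the formula ``ref * (10 ** (0.1 * x)) ** power``)
+        >>> x_db = torch.tensor([20.0])
+        >>> DB_to_amplitude(x_db, ref=1.0, power=0.5)  # 10 ** (0.1 * 20) == 100, then 100 ** 0.5
+        tensor([10.])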
+ """ + return ref * torch.pow(torch.pow(10.0, 0.1 * x), power) + + +def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float: + r"""Convert Hz to Mels. + + Args: + freqs (float): Frequencies in Hz + mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``) + + Returns: + mels (float): Frequency in Mels + """ + + if mel_scale not in ['slaney', 'htk']: + raise ValueError('mel_scale should be one of "htk" or "slaney".') + + if mel_scale == "htk": + return 2595.0 * math.log10(1.0 + (freq / 700.0)) + + # Fill in the linear part + f_min = 0.0 + f_sp = 200.0 / 3 + + mels = (freq - f_min) / f_sp + + # Fill in the log-scale part + min_log_hz = 1000.0 + min_log_mel = (min_log_hz - f_min) / f_sp + logstep = math.log(6.4) / 27.0 + + if freq >= min_log_hz: + mels = min_log_mel + math.log(freq / min_log_hz) / logstep + + return mels + + +def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor: + """Convert mel bin numbers to frequencies. + + Args: + mels (Tensor): Mel frequencies + mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``) + + Returns: + freqs (Tensor): Mels converted in Hz + """ + + if mel_scale not in ['slaney', 'htk']: + raise ValueError('mel_scale should be one of "htk" or "slaney".') + + if mel_scale == "htk": + return 700.0 * (10.0**(mels / 2595.0) - 1.0) + + # Fill in the linear scale + f_min = 0.0 + f_sp = 200.0 / 3 + freqs = f_min + f_sp * mels + + # And now the nonlinear scale + min_log_hz = 1000.0 + min_log_mel = (min_log_hz - f_min) / f_sp + logstep = math.log(6.4) / 27.0 + + log_t = (mels >= min_log_mel) + freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel)) + + return freqs + + +def _create_triangular_filterbank( + all_freqs: Tensor, + f_pts: Tensor, +) -> Tensor: + """Create a triangular filter bank. + + Args: + all_freqs (Tensor): STFT freq points of size (`n_freqs`). + f_pts (Tensor): Filter mid points of size (`n_filter`). + + Returns: + fb (Tensor): The filter bank of size (`n_freqs`, `n_filter`). + """ + # Adopted from Librosa + # calculate the difference between each filter mid point and each stft freq point in hertz + f_diff = f_pts[1:] - f_pts[:-1] # (n_filter + 1) + slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_filter + 2) + # create overlapping triangles + zero = torch.zeros(1) + down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_filter) + up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_filter) + fb = torch.max(zero, torch.min(down_slopes, up_slopes)) + + return fb + + +def create_fb_matrix( + n_freqs: int, + f_min: float, + f_max: float, + n_mels: int, + sample_rate: int, + norm: Optional[str] = None, + mel_scale: str = "htk", +) -> Tensor: + r"""Create a frequency bin conversion matrix. + + Args: + n_freqs (int): Number of frequencies to highlight/apply + f_min (float): Minimum frequency (Hz) + f_max (float): Maximum frequency (Hz) + n_mels (int): Number of mel filterbanks + sample_rate (int): Sample rate of the audio waveform + norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band + (area normalization). (Default: ``None``) + mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``) + + Returns: + Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``) + meaning number of frequencies to highlight/apply to x the number of filterbanks. 
+ Each column is a filterbank so that assuming there is a matrix A of + size (..., ``n_freqs``), the applied result would be + ``A * create_fb_matrix(A.size(-1), ...)``. + """ + warnings.warn( + "The use of `create_fb_matrix` is now deprecated and will be removed in " + "the 0.11 release. " + "Please migrate your code to use `melscale_fbanks` instead. " + "For more information, please refer to https://github.com/pytorch/audio/issues/1574." + ) + + return melscale_fbanks( + n_freqs=n_freqs, + f_min=f_min, + f_max=f_max, + n_mels=n_mels, + sample_rate=sample_rate, + norm=norm, + mel_scale=mel_scale + ) + + +def melscale_fbanks( + n_freqs: int, + f_min: float, + f_max: float, + n_mels: int, + sample_rate: int, + norm: Optional[str] = None, + mel_scale: str = "htk", +) -> Tensor: + r"""Create a frequency bin conversion matrix. + + Note: + For the sake of the numerical compatibility with librosa, not all the coefficients + in the resulting filter bank has magnitude of 1. + + .. image:: https://download.pytorch.org/torchaudio/doc-assets/mel_fbanks.png + :alt: Visualization of generated filter bank + + Args: + n_freqs (int): Number of frequencies to highlight/apply + f_min (float): Minimum frequency (Hz) + f_max (float): Maximum frequency (Hz) + n_mels (int): Number of mel filterbanks + sample_rate (int): Sample rate of the audio waveform + norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band + (area normalization). (Default: ``None``) + mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``) + + Returns: + Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``) + meaning number of frequencies to highlight/apply to x the number of filterbanks. + Each column is a filterbank so that assuming there is a matrix A of + size (..., ``n_freqs``), the applied result would be + ``A * melscale_fbanks(A.size(-1), ...)``. + + """ + + if norm is not None and norm != "slaney": + raise ValueError("norm must be one of None or 'slaney'") + + # freq bins + all_freqs = torch.linspace(0, sample_rate // 2, n_freqs) + + # calculate mel freq bins + m_min = _hz_to_mel(f_min, mel_scale=mel_scale) + m_max = _hz_to_mel(f_max, mel_scale=mel_scale) + + m_pts = torch.linspace(m_min, m_max, n_mels + 2) + f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale) + + # create filterbank + fb = _create_triangular_filterbank(all_freqs, f_pts) + + if norm is not None and norm == "slaney": + # Slaney-style mel is scaled to be approx constant energy per channel + enorm = 2.0 / (f_pts[2:n_mels + 2] - f_pts[:n_mels]) + fb *= enorm.unsqueeze(0) + + if (fb.max(dim=0).values == 0.).any(): + warnings.warn( + "At least one mel filterbank has all zero values. " + f"The value for `n_mels` ({n_mels}) may be set too high. " + f"Or, the value for `n_freqs` ({n_freqs}) may be set too low." + ) + + return fb + + +def linear_fbanks( + n_freqs: int, + f_min: float, + f_max: float, + n_filter: int, + sample_rate: int, +) -> Tensor: + r"""Creates a linear triangular filterbank. + + Note: + For the sake of the numerical compatibility with librosa, not all the coefficients + in the resulting filter bank has magnitude of 1. + + .. 
image:: https://download.pytorch.org/torchaudio/doc-assets/lin_fbanks.png + :alt: Visualization of generated filter bank + + Args: + n_freqs (int): Number of frequencies to highlight/apply + f_min (float): Minimum frequency (Hz) + f_max (float): Maximum frequency (Hz) + n_filter (int): Number of (linear) triangular filter + sample_rate (int): Sample rate of the audio waveform + + Returns: + Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_filter``) + meaning number of frequencies to highlight/apply to x the number of filterbanks. + Each column is a filterbank so that assuming there is a matrix A of + size (..., ``n_freqs``), the applied result would be + ``A * linear_fbanks(A.size(-1), ...)``. + """ + # freq bins + all_freqs = torch.linspace(0, sample_rate // 2, n_freqs) + + # filter mid-points + f_pts = torch.linspace(f_min, f_max, n_filter + 2) + + # create filterbank + fb = _create_triangular_filterbank(all_freqs, f_pts) + + return fb + + +def create_dct( + n_mfcc: int, + n_mels: int, + norm: Optional[str] +) -> Tensor: + r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``), + normalized depending on norm. + + Args: + n_mfcc (int): Number of mfc coefficients to retain + n_mels (int): Number of mel filterbanks + norm (str or None): Norm to use (either 'ortho' or None) + + Returns: + Tensor: The transformation matrix, to be right-multiplied to + row-wise data of size (``n_mels``, ``n_mfcc``). + """ + # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II + n = torch.arange(float(n_mels)) + k = torch.arange(float(n_mfcc)).unsqueeze(1) + dct = torch.cos(math.pi / float(n_mels) * (n + 0.5) * k) # size (n_mfcc, n_mels) + if norm is None: + dct *= 2.0 + else: + assert norm == "ortho" + dct[0] *= 1.0 / math.sqrt(2.0) + dct *= math.sqrt(2.0 / float(n_mels)) + return dct.t() + + +def mu_law_encoding( + x: Tensor, + quantization_channels: int +) -> Tensor: + r"""Encode signal based on mu-law companding. For more info see the + `Wikipedia Entry `_ + + This algorithm assumes the signal has been scaled to between -1 and 1 and + returns a signal encoded with values from 0 to quantization_channels - 1. + + Args: + x (Tensor): Input tensor + quantization_channels (int): Number of channels + + Returns: + Tensor: Input after mu-law encoding + """ + mu = quantization_channels - 1.0 + if not x.is_floating_point(): + x = x.to(torch.float) + mu = torch.tensor(mu, dtype=x.dtype) + x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu) + x_mu = ((x_mu + 1) / 2 * mu + 0.5).to(torch.int64) + return x_mu + + +def mu_law_decoding( + x_mu: Tensor, + quantization_channels: int +) -> Tensor: + r"""Decode mu-law encoded signal. For more info see the + `Wikipedia Entry `_ + + This expects an input with values between 0 and quantization_channels - 1 + and returns a signal scaled between -1 and 1. + + Args: + x_mu (Tensor): Input tensor + quantization_channels (int): Number of channels + + Returns: + Tensor: Input after mu-law decoding + """ + mu = quantization_channels - 1.0 + if not x_mu.is_floating_point(): + x_mu = x_mu.to(torch.float) + mu = torch.tensor(mu, dtype=x_mu.dtype) + x = ((x_mu) / mu) * 2 - 1.0 + x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu + return x + + +@_mod_utils.deprecated( + "Please convert the input Tensor to complex type with `torch.view_as_complex` then " + "use `torch.abs`. 
" + "Please refer to https://github.com/pytorch/audio/issues/1337 " + "for more details about torchaudio's plan to migrate to native complex type.", + version="0.11", +) +def complex_norm( + complex_tensor: Tensor, + power: float = 1.0 +) -> Tensor: + r"""Compute the norm of complex tensor input. + + Args: + complex_tensor (Tensor): Tensor shape of `(..., complex=2)` + power (float, optional): Power of the norm. (Default: `1.0`). + + Returns: + Tensor: Power of the normed input tensor. Shape of `(..., )` + """ + + # Replace by torch.norm once issue is fixed + # https://github.com/pytorch/pytorch/issues/34279 + return complex_tensor.pow(2.).sum(-1).pow(0.5 * power) + + +@_mod_utils.deprecated( + "Please convert the input Tensor to complex type with `torch.view_as_complex` then " + "use `torch.angle`. " + "Please refer to https://github.com/pytorch/audio/issues/1337 " + "for more details about torchaudio's plan to migrate to native complex type.", + version="0.11", +) +def angle( + complex_tensor: Tensor +) -> Tensor: + r"""Compute the angle of complex tensor input. + + Args: + complex_tensor (Tensor): Tensor shape of `(..., complex=2)` + + Return: + Tensor: Angle of a complex tensor. Shape of `(..., )` + """ + return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0]) + + +@_mod_utils.deprecated( + "Please convert the input Tensor to complex type with `torch.view_as_complex` then " + "use `torch.abs` and `torch.angle`. " + "Please refer to https://github.com/pytorch/audio/issues/1337 " + "for more details about torchaudio's plan to migrate to native complex type.", + version="0.11", +) +def magphase( + complex_tensor: Tensor, + power: float = 1.0 +) -> Tuple[Tensor, Tensor]: + r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase. + + Args: + complex_tensor (Tensor): Tensor shape of `(..., complex=2)` + power (float, optional): Power of the norm. (Default: `1.0`) + + Returns: + (Tensor, Tensor): The magnitude and phase of the complex tensor + """ + mag = complex_norm(complex_tensor, power) + phase = angle(complex_tensor) + return mag, phase + + +def phase_vocoder( + complex_specgrams: Tensor, + rate: float, + phase_advance: Tensor +) -> Tensor: + r"""Given a STFT tensor, speed up in time without modifying pitch by a + factor of ``rate``. + + Args: + complex_specgrams (Tensor): + Either a real tensor of dimension of `(..., freq, num_frame, complex=2)` + or a tensor of dimension `(..., freq, num_frame)` with complex dtype. + rate (float): Speed-up factor + phase_advance (Tensor): Expected phase advance in each bin. Dimension of `(freq, 1)` + + Returns: + Tensor: + Stretched spectrogram. The resulting tensor is of the same dtype as the input + spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``. 
+ + Example - With Tensor of complex dtype + >>> freq, hop_length = 1025, 512 + >>> # (channel, freq, time) + >>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat) + >>> rate = 1.3 # Speed up by 30% + >>> phase_advance = torch.linspace( + >>> 0, math.pi * hop_length, freq)[..., None] + >>> x = phase_vocoder(complex_specgrams, rate, phase_advance) + >>> x.shape # with 231 == ceil(300 / 1.3) + torch.Size([2, 1025, 231]) + + Example - With Tensor of real dtype and extra dimension for complex field + >>> freq, hop_length = 1025, 512 + >>> # (channel, freq, time, complex=2) + >>> complex_specgrams = torch.randn(2, freq, 300, 2) + >>> rate = 1.3 # Speed up by 30% + >>> phase_advance = torch.linspace( + >>> 0, math.pi * hop_length, freq)[..., None] + >>> x = phase_vocoder(complex_specgrams, rate, phase_advance) + >>> x.shape # with 231 == ceil(300 / 1.3) + torch.Size([2, 1025, 231, 2]) + """ + if rate == 1.0: + return complex_specgrams + + if not complex_specgrams.is_complex(): + warnings.warn( + "The support for pseudo complex type in `torchaudio.functional.phase_vocoder` and " + "`torchaudio.transforms.TimeStretch` is now deprecated and will be removed " + "from 0.11 release." + "Please migrate to native complex type by converting the input tensor with " + "`torch.view_as_complex`. " + "Please refer to https://github.com/pytorch/audio/issues/1337 " + "for more details about torchaudio's plan to migrate to native complex type." + ) + if complex_specgrams.size(-1) != 2: + raise ValueError( + "complex_specgrams must be either native complex tensors or " + "real valued tensors with shape (..., 2)") + + is_complex = complex_specgrams.is_complex() + + if not is_complex: + complex_specgrams = torch.view_as_complex(complex_specgrams) + + # pack batch + shape = complex_specgrams.size() + complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:])) + + # Figures out the corresponding real dtype, i.e. complex128 -> float64, complex64 -> float32 + # Note torch.real is a view so it does not incur any memory copy. + real_dtype = torch.real(complex_specgrams).dtype + time_steps = torch.arange( + 0, + complex_specgrams.size(-1), + rate, + device=complex_specgrams.device, + dtype=real_dtype) + + alphas = time_steps % 1.0 + phase_0 = complex_specgrams[..., :1].angle() + + # Time Padding + complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 2]) + + # (new_bins, freq, 2) + complex_specgrams_0 = complex_specgrams.index_select(-1, time_steps.long()) + complex_specgrams_1 = complex_specgrams.index_select(-1, (time_steps + 1).long()) + + angle_0 = complex_specgrams_0.angle() + angle_1 = complex_specgrams_1.angle() + + norm_0 = complex_specgrams_0.abs() + norm_1 = complex_specgrams_1.abs() + + phase = angle_1 - angle_0 - phase_advance + phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi)) + + # Compute Phase Accum + phase = phase + phase_advance + phase = torch.cat([phase_0, phase[..., :-1]], dim=-1) + phase_acc = torch.cumsum(phase, -1) + + mag = alphas * norm_1 + (1 - alphas) * norm_0 + + complex_specgrams_stretch = torch.polar(mag, phase_acc) + + # unpack batch + complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-2] + complex_specgrams_stretch.shape[1:]) + + if not is_complex: + return torch.view_as_real(complex_specgrams_stretch) + return complex_specgrams_stretch + + +def mask_along_axis_iid( + specgrams: Tensor, + mask_param: int, + mask_value: float, + axis: int +) -> Tensor: + r""" + Apply a mask along ``axis``. 
Mask will be applied from indices ``[v_0, v_0 + v)``, where + ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``. + + Args: + specgrams (Tensor): Real spectrograms `(batch, channel, freq, time)` + mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] + mask_value (float): Value to assign to the masked columns + axis (int): Axis to apply masking on (2 -> frequency, 3 -> time) + + Returns: + Tensor: Masked spectrograms of dimensions `(batch, channel, freq, time)` + """ + + if axis not in [2, 3]: + raise ValueError('Only Frequency and Time masking are supported') + + device = specgrams.device + dtype = specgrams.dtype + + value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * mask_param + min_value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * (specgrams.size(axis) - value) + + # Create broadcastable mask + mask_start = min_value[..., None, None] + mask_end = (min_value + value)[..., None, None] + mask = torch.arange(0, specgrams.size(axis), device=device, dtype=dtype) + + # Per batch example masking + specgrams = specgrams.transpose(axis, -1) + specgrams = specgrams.masked_fill((mask >= mask_start) & (mask < mask_end), mask_value) + specgrams = specgrams.transpose(axis, -1) + + return specgrams + + +def mask_along_axis( + specgram: Tensor, + mask_param: int, + mask_value: float, + axis: int +) -> Tensor: + r""" + Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where + ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``. + All examples will have the same mask interval. + + Args: + specgram (Tensor): Real spectrogram `(channel, freq, time)` + mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] + mask_value (float): Value to assign to the masked columns + axis (int): Axis to apply masking on (1 -> frequency, 2 -> time) + + Returns: + Tensor: Masked spectrogram of dimensions `(channel, freq, time)` + """ + if axis not in [1, 2]: + raise ValueError('Only Frequency and Time masking are supported') + + # pack batch + shape = specgram.size() + specgram = specgram.reshape([-1] + list(shape[-2:])) + value = torch.rand(1) * mask_param + min_value = torch.rand(1) * (specgram.size(axis) - value) + + mask_start = (min_value.long()).squeeze() + mask_end = (min_value.long() + value.long()).squeeze() + mask = torch.arange(0, specgram.shape[axis], device=specgram.device, dtype=specgram.dtype) + mask = (mask >= mask_start) & (mask < mask_end) + if axis == 1: + mask = mask.unsqueeze(-1) + + assert mask_end - mask_start < mask_param + + specgram = specgram.masked_fill(mask, mask_value) + + # unpack batch + specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:]) + + return specgram + + +def compute_deltas( + specgram: Tensor, + win_length: int = 5, + mode: str = "replicate" +) -> Tensor: + r"""Compute delta coefficients of a tensor, usually a spectrogram: + + .. math:: + d_t = \frac{\sum_{n=1}^{\text{N}} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^{\text{N}} n^2} + + where :math:`d_t` is the deltas at time :math:`t`, + :math:`c_t` is the spectrogram coeffcients at time :math:`t`, + :math:`N` is ``(win_length-1)//2``. 
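+
+    For example, with the default ``win_length = 5`` (so :math:`N = 2`) the denominator is
+    :math:`2 (1^2 + 2^2) = 10` and the delta reduces to
+    :math:`d_t = \big( (c_{t+1} - c_{t-1}) + 2 (c_{t+2} - c_{t-2}) \big) / 10`.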
+ + Args: + specgram (Tensor): Tensor of audio of dimension `(..., freq, time)` + win_length (int, optional): The window length used for computing delta (Default: ``5``) + mode (str, optional): Mode parameter passed to padding (Default: ``"replicate"``) + + Returns: + Tensor: Tensor of deltas of dimension `(..., freq, time)` + + Example + >>> specgram = torch.randn(1, 40, 1000) + >>> delta = compute_deltas(specgram) + >>> delta2 = compute_deltas(delta) + """ + device = specgram.device + dtype = specgram.dtype + + # pack batch + shape = specgram.size() + specgram = specgram.reshape(1, -1, shape[-1]) + + assert win_length >= 3 + + n = (win_length - 1) // 2 + + # twice sum of integer squared + denom = n * (n + 1) * (2 * n + 1) / 3 + + specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode) + + kernel = torch.arange(-n, n + 1, 1, device=device, dtype=dtype).repeat(specgram.shape[1], 1, 1) + + output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom + + # unpack batch + output = output.reshape(shape) + + return output + + +def _compute_nccf( + waveform: Tensor, + sample_rate: int, + frame_time: float, + freq_low: int +) -> Tensor: + r""" + Compute Normalized Cross-Correlation Function (NCCF). + + .. math:: + \phi_i(m) = \frac{\sum_{n=b_i}^{b_i + N-1} w(n) w(m+n)}{\sqrt{E(b_i) E(m+b_i)}}, + + where + :math:`\phi_i(m)` is the NCCF at frame :math:`i` with lag :math:`m`, + :math:`w` is the waveform, + :math:`N` is the length of a frame, + :math:`b_i` is the beginning of frame :math:`i`, + :math:`E(j)` is the energy :math:`\sum_{n=j}^{j+N-1} w^2(n)`. + """ + + EPSILON = 10 ** (-9) + + # Number of lags to check + lags = int(math.ceil(sample_rate / freq_low)) + + frame_size = int(math.ceil(sample_rate * frame_time)) + + waveform_length = waveform.size()[-1] + num_of_frames = int(math.ceil(waveform_length / frame_size)) + + p = lags + num_of_frames * frame_size - waveform_length + waveform = torch.nn.functional.pad(waveform, (0, p)) + + # Compute lags + output_lag = [] + for lag in range(1, lags + 1): + s1 = waveform[..., :-lag].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :] + s2 = waveform[..., lag:].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :] + + output_frames = ( + (s1 * s2).sum(-1) + / (EPSILON + torch.norm(s1, p=2, dim=-1)).pow(2) + / (EPSILON + torch.norm(s2, p=2, dim=-1)).pow(2) + ) + + output_lag.append(output_frames.unsqueeze(-1)) + + nccf = torch.cat(output_lag, -1) + + return nccf + + +def _combine_max( + a: Tuple[Tensor, Tensor], + b: Tuple[Tensor, Tensor], + thresh: float = 0.99 +) -> Tuple[Tensor, Tensor]: + """ + Take value from first if bigger than a multiplicative factor of the second, elementwise. + """ + mask = (a[0] > thresh * b[0]) + values = mask * a[0] + ~mask * b[0] + indices = mask * a[1] + ~mask * b[1] + return values, indices + + +def _find_max_per_frame( + nccf: Tensor, + sample_rate: int, + freq_high: int +) -> Tensor: + r""" + For each frame, take the highest value of NCCF, + apply centered median smoothing, and convert to frequency. + + Note: If the max among all the lags is very close + to the first half of lags, then the latter is taken. 
+ """ + + lag_min = int(math.ceil(sample_rate / freq_high)) + + # Find near enough max that is smallest + + best = torch.max(nccf[..., lag_min:], -1) + + half_size = nccf.shape[-1] // 2 + half = torch.max(nccf[..., lag_min:half_size], -1) + + best = _combine_max(half, best) + indices = best[1] + + # Add back minimal lag + indices += lag_min + # Add 1 empirical calibration offset + indices += 1 + + return indices + + +def _median_smoothing( + indices: Tensor, + win_length: int +) -> Tensor: + r""" + Apply median smoothing to the 1D tensor over the given window. + """ + + # Centered windowed + pad_length = (win_length - 1) // 2 + + # "replicate" padding in any dimension + indices = torch.nn.functional.pad( + indices, (pad_length, 0), mode="constant", value=0. + ) + + indices[..., :pad_length] = torch.cat(pad_length * [indices[..., pad_length].unsqueeze(-1)], dim=-1) + roll = indices.unfold(-1, win_length, 1) + + values, _ = torch.median(roll, -1) + return values + + +def detect_pitch_frequency( + waveform: Tensor, + sample_rate: int, + frame_time: float = 10 ** (-2), + win_length: int = 30, + freq_low: int = 85, + freq_high: int = 3400, +) -> Tensor: + r"""Detect pitch frequency. + + It is implemented using normalized cross-correlation function and median smoothing. + + Args: + waveform (Tensor): Tensor of audio of dimension `(..., freq, time)` + sample_rate (int): The sample rate of the waveform (Hz) + frame_time (float, optional): Duration of a frame (Default: ``10 ** (-2)``). + win_length (int, optional): The window length for median smoothing (in number of frames) (Default: ``30``). + freq_low (int, optional): Lowest frequency that can be detected (Hz) (Default: ``85``). + freq_high (int, optional): Highest frequency that can be detected (Hz) (Default: ``3400``). + + Returns: + Tensor: Tensor of freq of dimension `(..., frame)` + """ + # pack batch + shape = list(waveform.size()) + waveform = waveform.reshape([-1] + shape[-1:]) + + nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low) + indices = _find_max_per_frame(nccf, sample_rate, freq_high) + indices = _median_smoothing(indices, win_length) + + # Convert indices to frequency + EPSILON = 10 ** (-9) + freq = sample_rate / (EPSILON + indices.to(torch.float)) + + # unpack batch + freq = freq.reshape(shape[:-1] + list(freq.shape[-1:])) + + return freq + + +def sliding_window_cmn( + specgram: Tensor, + cmn_window: int = 600, + min_cmn_window: int = 100, + center: bool = False, + norm_vars: bool = False, +) -> Tensor: + r""" + Apply sliding-window cepstral mean (and optionally variance) normalization per utterance. + + Args: + specgram (Tensor): Tensor of spectrogram of dimension `(..., time, freq)` + cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600) + min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start). + Only applicable if center == false, ignored if center==true (int, default = 100) + center (bool, optional): If true, use a window centered on the current frame + (to the extent possible, modulo end effects). If false, window is to the left. (bool, default = false) + norm_vars (bool, optional): If true, normalize variance to one. 
(bool, default = false) + + Returns: + Tensor: Tensor matching input shape `(..., freq, time)` + """ + input_shape = specgram.shape + num_frames, num_feats = input_shape[-2:] + specgram = specgram.view(-1, num_frames, num_feats) + num_channels = specgram.shape[0] + + dtype = specgram.dtype + device = specgram.device + last_window_start = last_window_end = -1 + cur_sum = torch.zeros(num_channels, num_feats, dtype=dtype, device=device) + cur_sumsq = torch.zeros(num_channels, num_feats, dtype=dtype, device=device) + cmn_specgram = torch.zeros( + num_channels, num_frames, num_feats, dtype=dtype, device=device) + for t in range(num_frames): + window_start = 0 + window_end = 0 + if center: + window_start = t - cmn_window // 2 + window_end = window_start + cmn_window + else: + window_start = t - cmn_window + window_end = t + 1 + if window_start < 0: + window_end -= window_start + window_start = 0 + if not center: + if window_end > t: + window_end = max(t + 1, min_cmn_window) + if window_end > num_frames: + window_start -= (window_end - num_frames) + window_end = num_frames + if window_start < 0: + window_start = 0 + if last_window_start == -1: + input_part = specgram[:, window_start: window_end - window_start, :] + cur_sum += torch.sum(input_part, 1) + if norm_vars: + cur_sumsq += torch.cumsum(input_part ** 2, 1)[:, -1, :] + else: + if window_start > last_window_start: + frame_to_remove = specgram[:, last_window_start, :] + cur_sum -= frame_to_remove + if norm_vars: + cur_sumsq -= (frame_to_remove ** 2) + if window_end > last_window_end: + frame_to_add = specgram[:, last_window_end, :] + cur_sum += frame_to_add + if norm_vars: + cur_sumsq += (frame_to_add ** 2) + window_frames = window_end - window_start + last_window_start = window_start + last_window_end = window_end + cmn_specgram[:, t, :] = specgram[:, t, :] - cur_sum / window_frames + if norm_vars: + if window_frames == 1: + cmn_specgram[:, t, :] = torch.zeros( + num_channels, num_feats, dtype=dtype, device=device) + else: + variance = cur_sumsq + variance = variance / window_frames + variance -= ((cur_sum ** 2) / (window_frames ** 2)) + variance = torch.pow(variance, -0.5) + cmn_specgram[:, t, :] *= variance + + cmn_specgram = cmn_specgram.view(input_shape[:-2] + (num_frames, num_feats)) + if len(input_shape) == 2: + cmn_specgram = cmn_specgram.squeeze(0) + return cmn_specgram + + +def spectral_centroid( + waveform: Tensor, + sample_rate: int, + pad: int, + window: Tensor, + n_fft: int, + hop_length: int, + win_length: int, +) -> Tensor: + r""" + Compute the spectral centroid for each channel along the time axis. + + The spectral centroid is defined as the weighted average of the + frequency values, weighted by their magnitude. 
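+
+    In symbols, for each time frame this computes
+
+    .. math::
+        c = \frac{\sum_{k} f_k \lvert S_k \rvert}{\sum_{k} \lvert S_k \rvert},
+
+    where :math:`f_k` is the frequency of the :math:`k`-th STFT bin and
+    :math:`\lvert S_k \rvert` is its magnitude.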
+ + Args: + waveform (Tensor): Tensor of audio of dimension `(..., time)` + sample_rate (int): Sample rate of the audio waveform + pad (int): Two sided padding of signal + window (Tensor): Window tensor that is applied/multiplied to each frame/window + n_fft (int): Size of FFT + hop_length (int): Length of hop between STFT windows + win_length (int): Window size + + Returns: + Tensor: Dimension `(..., time)` + """ + specgram = spectrogram(waveform, pad=pad, window=window, n_fft=n_fft, hop_length=hop_length, + win_length=win_length, power=1., normalized=False) + freqs = torch.linspace(0, sample_rate // 2, steps=1 + n_fft // 2, + device=specgram.device).reshape((-1, 1)) + freq_dim = -2 + return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim) + + +@_mod_utils.requires_sox() +def apply_codec( + waveform: Tensor, + sample_rate: int, + format: str, + channels_first: bool = True, + compression: Optional[float] = None, + encoding: Optional[str] = None, + bits_per_sample: Optional[int] = None, +) -> Tensor: + r""" + Apply codecs as a form of augmentation. + + Args: + waveform (Tensor): Audio data. Must be 2 dimensional. See also ```channels_first```. + sample_rate (int): Sample rate of the audio waveform. + format (str): File format. + channels_first (bool, optional): + When True, both the input and output Tensor have dimension `(channel, time)`. + Otherwise, they have dimension `(time, channel)`. + compression (float or None, optional): Used for formats other than WAV. + For more details see :py:func:`torchaudio.backend.sox_io_backend.save`. + encoding (str or None, optional): Changes the encoding for the supported formats. + For more details see :py:func:`torchaudio.backend.sox_io_backend.save`. + bits_per_sample (int or None, optional): Changes the bit depth for the supported formats. + For more details see :py:func:`torchaudio.backend.sox_io_backend.save`. + + Returns: + Tensor: Resulting Tensor. + If ``channels_first=True``, it has `(channel, time)` else `(time, channel)`. + """ + bytes = io.BytesIO() + torchaudio.backend.sox_io_backend.save(bytes, + waveform, + sample_rate, + channels_first, + compression, + format, + encoding, + bits_per_sample + ) + bytes.seek(0) + augmented, _ = torchaudio.sox_effects.sox_effects.apply_effects_file( + bytes, effects=[["rate", f"{sample_rate}"]], channels_first=channels_first, format=format) + return augmented + + +@_mod_utils.requires_kaldi() +def compute_kaldi_pitch( + waveform: torch.Tensor, + sample_rate: float, + frame_length: float = 25.0, + frame_shift: float = 10.0, + min_f0: float = 50, + max_f0: float = 400, + soft_min_f0: float = 10.0, + penalty_factor: float = 0.1, + lowpass_cutoff: float = 1000, + resample_frequency: float = 4000, + delta_pitch: float = 0.005, + nccf_ballast: float = 7000, + lowpass_filter_width: int = 1, + upsample_filter_width: int = 5, + max_frames_latency: int = 0, + frames_per_chunk: int = 0, + simulate_first_pass_online: bool = False, + recompute_frame: int = 500, + snip_edges: bool = True, +) -> torch.Tensor: + """Extract pitch based on method described in *A pitch extraction algorithm tuned + for automatic speech recognition* [:footcite:`6854049`]. + + This function computes the equivalent of `compute-kaldi-pitch-feats` from Kaldi. + + Args: + waveform (Tensor): + The input waveform of shape `(..., time)`. + sample_rate (float): + Sample rate of `waveform`. + frame_length (float, optional): + Frame length in milliseconds. (default: 25.0) + frame_shift (float, optional): + Frame shift in milliseconds. 
(default: 10.0) + min_f0 (float, optional): + Minimum F0 to search for (Hz) (default: 50.0) + max_f0 (float, optional): + Maximum F0 to search for (Hz) (default: 400.0) + soft_min_f0 (float, optional): + Minimum f0, applied in soft way, must not exceed min-f0 (default: 10.0) + penalty_factor (float, optional): + Cost factor for FO change. (default: 0.1) + lowpass_cutoff (float, optional): + Cutoff frequency for LowPass filter (Hz) (default: 1000) + resample_frequency (float, optional): + Frequency that we down-sample the signal to. Must be more than twice lowpass-cutoff. + (default: 4000) + delta_pitch( float, optional): + Smallest relative change in pitch that our algorithm measures. (default: 0.005) + nccf_ballast (float, optional): + Increasing this factor reduces NCCF for quiet frames (default: 7000) + lowpass_filter_width (int, optional): + Integer that determines filter width of lowpass filter, more gives sharper filter. + (default: 1) + upsample_filter_width (int, optional): + Integer that determines filter width when upsampling NCCF. (default: 5) + max_frames_latency (int, optional): + Maximum number of frames of latency that we allow pitch tracking to introduce into + the feature processing (affects output only if ``frames_per_chunk > 0`` and + ``simulate_first_pass_online=True``) (default: 0) + frames_per_chunk (int, optional): + The number of frames used for energy normalization. (default: 0) + simulate_first_pass_online (bool, optional): + If true, the function will output features that correspond to what an online decoder + would see in the first pass of decoding -- not the final version of the features, + which is the default. (default: False) + Relevant if ``frames_per_chunk > 0``. + recompute_frame (int, optional): + Only relevant for compatibility with online pitch extraction. + A non-critical parameter; the frame at which we recompute some of the forward pointers, + after revising our estimate of the signal energy. + Relevant if ``frames_per_chunk > 0``. (default: 500) + snip_edges (bool, optional): + If this is set to false, the incomplete frames near the ending edge won't be snipped, + so that the number of frames is the file size divided by the frame-shift. + This makes different types of features give the same number of frames. (default: True) + + Returns: + Tensor: Pitch feature. Shape: `(batch, frames 2)` where the last dimension + corresponds to pitch and NCCF. + """ + shape = waveform.shape + waveform = waveform.reshape(-1, shape[-1]) + result = torch.ops.torchaudio.kaldi_ComputeKaldiPitch( + waveform, sample_rate, frame_length, frame_shift, + min_f0, max_f0, soft_min_f0, penalty_factor, lowpass_cutoff, + resample_frequency, delta_pitch, nccf_ballast, + lowpass_filter_width, upsample_filter_width, max_frames_latency, + frames_per_chunk, simulate_first_pass_online, recompute_frame, + snip_edges, + ) + result = result.reshape(shape[:-1] + result.shape[-2:]) + return result + + +def _get_sinc_resample_kernel( + orig_freq: int, + new_freq: int, + gcd: int, + lowpass_filter_width: int, + rolloff: float, + resampling_method: str, + beta: Optional[float], + device: torch.device = torch.device("cpu"), + dtype: Optional[torch.dtype] = None): + + if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq): + raise Exception( + "Frequencies must be of integer type to ensure quality resampling computation. " + "To work around this, manually convert both frequencies to integer values " + "that maintain their resampling rate ratio before passing them into the function. 
" + "Example: To downsample a 44100 hz waveform by a factor of 8, use " + "`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5`. " + "For more information, please refer to https://github.com/pytorch/audio/issues/1487." + ) + + if resampling_method not in ['sinc_interpolation', 'kaiser_window']: + raise ValueError('Invalid resampling method: {}'.format(resampling_method)) + + orig_freq = int(orig_freq) // gcd + new_freq = int(new_freq) // gcd + + assert lowpass_filter_width > 0 + kernels = [] + base_freq = min(orig_freq, new_freq) + # This will perform antialiasing filtering by removing the highest frequencies. + # At first I thought I only needed this when downsampling, but when upsampling + # you will get edge artifacts without this, as the edge is equivalent to zero padding, + # which will add high freq artifacts. + base_freq *= rolloff + + # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor) + # using the sinc interpolation formula: + # x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t)) + # We can then sample the function x(t) with a different sample rate: + # y[j] = x(j / new_freq) + # or, + # y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq)) + + # We see here that y[j] is the convolution of x[i] with a specific filter, for which + # we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zeros crossing. + # But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq]. + # Indeed: + # y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq)) + # = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq)) + # = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq)) + # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`. + # This will explain the F.conv1d after, with a stride of orig_freq. + width = math.ceil(lowpass_filter_width * orig_freq / base_freq) + # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e., + # they will have a lot of almost zero values to the left or to the right... + # There is probably a way to evaluate those filters more efficiently, but this is kept for + # future work. + idx_dtype = dtype if dtype is not None else torch.float64 + idx = torch.arange(-width, width + orig_freq, device=device, dtype=idx_dtype) + + for i in range(new_freq): + t = (-i / new_freq + idx / orig_freq) * base_freq + t = t.clamp_(-lowpass_filter_width, lowpass_filter_width) + + # we do not use built in torch windows here as we need to evaluate the window + # at specific positions, not over a regular grid. 
+ if resampling_method == "sinc_interpolation": + window = torch.cos(t * math.pi / lowpass_filter_width / 2)**2 + else: + # kaiser_window + if beta is None: + beta = 14.769656459379492 + beta_tensor = torch.tensor(float(beta)) + window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor) + t *= math.pi + kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t) + kernel.mul_(window) + kernels.append(kernel) + + scale = base_freq / orig_freq + kernels = torch.stack(kernels).view(new_freq, 1, -1).mul_(scale) + if dtype is None: + kernels = kernels.to(dtype=torch.float32) + return kernels, width + + +def _apply_sinc_resample_kernel( + waveform: Tensor, + orig_freq: int, + new_freq: int, + gcd: int, + kernel: Tensor, + width: int, +): + orig_freq = int(orig_freq) // gcd + new_freq = int(new_freq) // gcd + + # pack batch + shape = waveform.size() + waveform = waveform.view(-1, shape[-1]) + + num_wavs, length = waveform.shape + waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq)) + resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq) + resampled = resampled.transpose(1, 2).reshape(num_wavs, -1) + target_length = int(math.ceil(new_freq * length / orig_freq)) + resampled = resampled[..., :target_length] + + # unpack batch + resampled = resampled.view(shape[:-1] + resampled.shape[-1:]) + return resampled + + +def resample( + waveform: Tensor, + orig_freq: int, + new_freq: int, + lowpass_filter_width: int = 6, + rolloff: float = 0.99, + resampling_method: str = "sinc_interpolation", + beta: Optional[float] = None, +) -> Tensor: + r"""Resamples the waveform at the new frequency using bandlimited interpolation. + + https://ccrma.stanford.edu/~jos/resample/Theory_Ideal_Bandlimited_Interpolation.html + + Note: + ``transforms.Resample`` precomputes and reuses the resampling kernel, so using it will result in + more efficient computation if resampling multiple waveforms with the same resampling parameters. + + Args: + waveform (Tensor): The input signal of dimension `(..., time)` + orig_freq (int): The original frequency of the signal + new_freq (int): The desired frequency + lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper + but less efficient. (Default: ``6``) + rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist. + Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``) + resampling_method (str, optional): The resampling method to use. + Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``) + beta (float or None, optional): The shape parameter used for kaiser window. + + Returns: + Tensor: The waveform at the new frequency of dimension `(..., time).` + """ + + assert orig_freq > 0.0 and new_freq > 0.0 + + if orig_freq == new_freq: + return waveform + + gcd = math.gcd(int(orig_freq), int(new_freq)) + + kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff, + resampling_method, beta, waveform.device, waveform.dtype) + resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width) + return resampled + + +@torch.jit.unused +def edit_distance(seq1: Sequence, seq2: Sequence) -> int: + """ + Calculate the word level edit (Levenshtein) distance between two sequences. + + The function computes an edit distance allowing deletion, insertion and + substitution. 
The result is an integer. + + For most applications, the two input sequences should be the same type. If + two strings are given, the output is the edit distance between the two + strings (character edit distance). If two lists of strings are given, the + output is the edit distance between sentences (word edit distance). Users + may want to normalize the output by the length of the reference sequence. + + torchscipt is not supported for this function. + + Args: + seq1 (Sequence): the first sequence to compare. + seq2 (Sequence): the second sequence to compare. + Returns: + int: The distance between the first and second sequences. + """ + len_sent2 = len(seq2) + dold = list(range(len_sent2 + 1)) + dnew = [0 for _ in range(len_sent2 + 1)] + + for i in range(1, len(seq1) + 1): + dnew[0] = i + for j in range(1, len_sent2 + 1): + if seq1[i - 1] == seq2[j - 1]: + dnew[j] = dold[j - 1] + else: + substitution = dold[j - 1] + 1 + insertion = dnew[j - 1] + 1 + deletion = dold[j] + 1 + dnew[j] = min(substitution, insertion, deletion) + + dnew, dold = dold, dnew + + return int(dold[-1]) + + +def pitch_shift( + waveform: Tensor, + sample_rate: int, + n_steps: int, + bins_per_octave: int = 12, + n_fft: int = 512, + win_length: Optional[int] = None, + hop_length: Optional[int] = None, + window: Optional[Tensor] = None, +) -> Tensor: + """ + Shift the pitch of a waveform by ``n_steps`` steps. + + Args: + waveform (Tensor): The input waveform of shape `(..., time)`. + sample_rate (int): Sample rate of `waveform`. + n_steps (int): The (fractional) steps to shift `waveform`. + bins_per_octave (int, optional): The number of steps per octave (Default: ``12``). + n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins (Default: ``512``). + win_length (int or None, optional): Window size. If None, then ``n_fft`` is used. (Default: ``None``). + hop_length (int or None, optional): Length of hop between STFT windows. If None, then + ``win_length // 4`` is used (Default: ``None``). + window (Tensor or None, optional): Window tensor that is applied/multiplied to each frame/window. + If None, then ``torch.hann_window(win_length)`` is used (Default: ``None``). + + + Returns: + Tensor: The pitch-shifted audio waveform of shape `(..., time)`. 
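+
+    Example (illustrative; the clip length and step values below are arbitrary)
+        >>> waveform = torch.randn(1, 16000)  # hypothetical one-second clip at 16 kHz
+        >>> shifted = torchaudio.functional.pitch_shift(waveform, sample_rate=16000, n_steps=2)
+        >>> shifted.shape  # the output keeps the input length
+        torch.Size([1, 16000])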
+ """ + if hop_length is None: + hop_length = n_fft // 4 + if win_length is None: + win_length = n_fft + if window is None: + window = torch.hann_window(window_length=win_length, device=waveform.device) + + # pack batch + shape = waveform.size() + waveform = waveform.reshape(-1, shape[-1]) + + ori_len = shape[-1] + rate = 2.0 ** (-float(n_steps) / bins_per_octave) + spec_f = torch.stft(input=waveform, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=True, + pad_mode='reflect', + normalized=False, + onesided=True, + return_complex=True) + phase_advance = torch.linspace(0, math.pi * hop_length, spec_f.shape[-2], device=spec_f.device)[..., None] + spec_stretch = phase_vocoder(spec_f, rate, phase_advance) + len_stretch = int(round(ori_len / rate)) + waveform_stretch = torch.istft(spec_stretch, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + length=len_stretch) + waveform_shift = resample(waveform_stretch, int(sample_rate / rate), sample_rate) + shift_len = waveform_shift.size()[-1] + if shift_len > ori_len: + waveform_shift = waveform_shift[..., :ori_len] + else: + waveform_shift = torch.nn.functional.pad(waveform_shift, [0, ori_len - shift_len]) + + # unpack batch + waveform_shift = waveform_shift.view(shape[:-1] + waveform_shift.shape[-1:]) + return waveform_shift + + +def rnnt_loss( + logits: Tensor, + targets: Tensor, + logit_lengths: Tensor, + target_lengths: Tensor, + blank: int = -1, + clamp: float = -1, + reduction: str = "mean", +): + """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks* + [:footcite:`graves2012sequence`]. + The RNN Transducer loss extends the CTC loss by defining a distribution over output + sequences of all lengths, and by jointly modelling both input-output and output-output + dependencies. + + Args: + logits (Tensor): Tensor of dimension `(batch, max seq length, max target length + 1, class)` + containing output from joiner + targets (Tensor): Tensor of dimension `(batch, max target length)` containing targets with zero padded + logit_lengths (Tensor): Tensor of dimension `(batch)` containing lengths of each sequence from encoder + target_lengths (Tensor): Tensor of dimension `(batch)` containing lengths of targets for each sequence + blank (int, optional): blank label (Default: ``-1``) + clamp (float, optional): clamp for gradients (Default: ``-1``) + reduction (string, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``) + Returns: + Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size `(batch)`, + otherwise scalar. + """ + if reduction not in ['none', 'mean', 'sum']: + raise ValueError("reduction should be one of 'none', 'mean', or 'sum'") + + if blank < 0: # reinterpret blank index if blank < 0. 
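+        # For example, blank=-1 with C classes in the last dimension of logits
+        # resolves to index C - 1, i.e. the last class is used as the blank label.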
+ blank = logits.shape[-1] + blank + + costs, _ = torch.ops.torchaudio.rnnt_loss( + logits=logits, + targets=targets, + logit_lengths=logit_lengths, + target_lengths=target_lengths, + blank=blank, + clamp=clamp, + ) + + if reduction == 'mean': + return costs.mean() + elif reduction == 'sum': + return costs.sum() + + return costs diff --git a/torchaudio/kaldi_io.py b/torchaudio/kaldi_io.py new file mode 100644 index 0000000000000000000000000000000000000000..ba1689da2b12e303fd44e0a70de9fc325f2ff6bc --- /dev/null +++ b/torchaudio/kaldi_io.py @@ -0,0 +1,130 @@ +# To use this file, the dependency (https://github.com/vesis84/kaldi-io-for-python) +# needs to be installed. This is a light wrapper around kaldi_io that returns +# torch.Tensors. +from typing import Any, Callable, Iterable, Tuple + +import torch +from torch import Tensor +from torchaudio._internal import module_utils as _mod_utils + +if _mod_utils.is_module_available('kaldi_io', 'numpy'): + import numpy as np + import kaldi_io + + +__all__ = [ + 'read_vec_int_ark', + 'read_vec_flt_scp', + 'read_vec_flt_ark', + 'read_mat_scp', + 'read_mat_ark', +] + + +def _convert_method_output_to_tensor(file_or_fd: Any, + fn: Callable, + convert_contiguous: bool = False) -> Iterable[Tuple[str, Tensor]]: + r"""Takes a method invokes it. The output is converted to a tensor. + + Args: + file_or_fd (str/FileDescriptor): File name or file descriptor + fn (Callable): Function that has the signature (file name/descriptor) and converts it to + Iterable[Tuple[str, Tensor]]. + convert_contiguous (bool, optional): Determines whether the array should be converted into a + contiguous layout. (Default: ``False``) + + Returns: + Iterable[Tuple[str, Tensor]]: The string is the key and the tensor is vec/mat + """ + for key, np_arr in fn(file_or_fd): + if convert_contiguous: + np_arr = np.ascontiguousarray(np_arr) + yield key, torch.from_numpy(np_arr) + + +@_mod_utils.requires_module('kaldi_io', 'numpy') +def read_vec_int_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]: + r"""Create generator of (key,vector) tuples, which reads from the ark file/stream. + + Args: + file_or_fd (str/FileDescriptor): ark, gzipped ark, pipe or opened file descriptor + + Returns: + Iterable[Tuple[str, Tensor]]: The string is the key and the tensor is the vector read from file + + Example + >>> # read ark to a 'dictionary' + >>> d = { u:d for u,d in torchaudio.kaldi_io.read_vec_int_ark(file) } + """ + # Requires convert_contiguous to be True because elements from int32 vector are + # sorted in tuples: (sizeof(int32), value) so strides are (5,) instead of (4,) which will throw an error + # in from_numpy as it expects strides to be a multiple of 4 (int32). + return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_vec_int_ark, convert_contiguous=True) + + +@_mod_utils.requires_module('kaldi_io', 'numpy') +def read_vec_flt_scp(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]: + r"""Create generator of (key,vector) tuples, read according to Kaldi scp. 
+ + Args: + file_or_fd (str/FileDescriptor): scp, gzipped scp, pipe or opened file descriptor + + Returns: + Iterable[Tuple[str, Tensor]]: The string is the key and the tensor is the vector read from file + + Example + >>> # read scp to a 'dictionary' + >>> # d = { u:d for u,d in torchaudio.kaldi_io.read_vec_flt_scp(file) } + """ + return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_vec_flt_scp) + + +@_mod_utils.requires_module('kaldi_io', 'numpy') +def read_vec_flt_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]: + r"""Create generator of (key,vector) tuples, which reads from the ark file/stream. + + Args: + file_or_fd (str/FileDescriptor): ark, gzipped ark, pipe or opened file descriptor + + Returns: + Iterable[Tuple[str, Tensor]]: The string is the key and the tensor is the vector read from file + + Example + >>> # read ark to a 'dictionary' + >>> d = { u:d for u,d in torchaudio.kaldi_io.read_vec_flt_ark(file) } + """ + return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_vec_flt_ark) + + +@_mod_utils.requires_module('kaldi_io', 'numpy') +def read_mat_scp(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]: + r"""Create generator of (key,matrix) tuples, read according to Kaldi scp. + + Args: + file_or_fd (str/FileDescriptor): scp, gzipped scp, pipe or opened file descriptor + + Returns: + Iterable[Tuple[str, Tensor]]: The string is the key and the tensor is the matrix read from file + + Example + >>> # read scp to a 'dictionary' + >>> d = { u:d for u,d in torchaudio.kaldi_io.read_mat_scp(file) } + """ + return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_mat_scp) + + +@_mod_utils.requires_module('kaldi_io', 'numpy') +def read_mat_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]: + r"""Create generator of (key,matrix) tuples, which reads from the ark file/stream. + + Args: + file_or_fd (str/FileDescriptor): ark, gzipped ark, pipe or opened file descriptor + + Returns: + Iterable[Tuple[str, Tensor]]: The string is the key and the tensor is the matrix read from file + + Example + >>> # read ark to a 'dictionary' + >>> d = { u:d for u,d in torchaudio.kaldi_io.read_mat_ark(file) } + """ + return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_mat_ark) diff --git a/torchaudio/models/__init__.py b/torchaudio/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..956fc82e2f334ea75f9c003681fed5008bad7a37 --- /dev/null +++ b/torchaudio/models/__init__.py @@ -0,0 +1,31 @@ +from .wav2letter import Wav2Letter +from .wavernn import WaveRNN +from .conv_tasnet import ConvTasNet +from .deepspeech import DeepSpeech +from .tacotron2 import Tacotron2 +from .wav2vec2 import ( + Wav2Vec2Model, + wav2vec2_model, + wav2vec2_base, + wav2vec2_large, + wav2vec2_large_lv60k, + hubert_base, + hubert_large, + hubert_xlarge, +) + +__all__ = [ + 'Wav2Letter', + 'WaveRNN', + 'ConvTasNet', + 'DeepSpeech', + 'Wav2Vec2Model', + 'wav2vec2_model', + 'wav2vec2_base', + 'wav2vec2_large', + 'wav2vec2_large_lv60k', + 'hubert_base', + 'hubert_large', + 'hubert_xlarge', + 'Tacotron2', +] diff --git a/torchaudio/models/conv_tasnet.py b/torchaudio/models/conv_tasnet.py new file mode 100644 index 0000000000000000000000000000000000000000..39d3dae02f94bca3d7927d550f4e5826b851871c --- /dev/null +++ b/torchaudio/models/conv_tasnet.py @@ -0,0 +1,321 @@ +"""Implements Conv-TasNet with building blocks of it. 
+ +Based on https://github.com/naplab/Conv-TasNet/tree/e66d82a8f956a69749ec8a4ae382217faa097c5c +""" + +from typing import Tuple, Optional + +import torch + + +class ConvBlock(torch.nn.Module): + """1D Convolutional block. + + Args: + io_channels (int): The number of input/output channels, + hidden_channels (int): The number of channels in the internal layers, . + kernel_size (int): The convolution kernel size of the middle layer,
. + padding (int): Padding value of the convolution in the middle layer. + dilation (int, optional): Dilation value of the convolution in the middle layer. + no_redisual (bool, optional): Disable residual block/output. + + Note: + This implementation corresponds to the "non-causal" setting in the paper. + """ + + def __init__( + self, + io_channels: int, + hidden_channels: int, + kernel_size: int, + padding: int, + dilation: int = 1, + no_residual: bool = False, + ): + super().__init__() + + self.conv_layers = torch.nn.Sequential( + torch.nn.Conv1d( + in_channels=io_channels, out_channels=hidden_channels, kernel_size=1 + ), + torch.nn.PReLU(), + torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08), + torch.nn.Conv1d( + in_channels=hidden_channels, + out_channels=hidden_channels, + kernel_size=kernel_size, + padding=padding, + dilation=dilation, + groups=hidden_channels, + ), + torch.nn.PReLU(), + torch.nn.GroupNorm(num_groups=1, num_channels=hidden_channels, eps=1e-08), + ) + + self.res_out = ( + None + if no_residual + else torch.nn.Conv1d( + in_channels=hidden_channels, out_channels=io_channels, kernel_size=1 + ) + ) + self.skip_out = torch.nn.Conv1d( + in_channels=hidden_channels, out_channels=io_channels, kernel_size=1 + ) + + def forward( + self, input: torch.Tensor + ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: + feature = self.conv_layers(input) + if self.res_out is None: + residual = None + else: + residual = self.res_out(feature) + skip_out = self.skip_out(feature) + return residual, skip_out + + +class MaskGenerator(torch.nn.Module): + """TCN (Temporal Convolution Network) Separation Module + + Generates masks for separation. + + Args: + input_dim (int): Input feature dimension, . + num_sources (int): The number of sources to separate. + kernel_size (int): The convolution kernel size of conv blocks,
. + num_featrs (int): Input/output feature dimenstion of conv blocks, . + num_hidden (int): Intermediate feature dimention of conv blocks, + num_layers (int): The number of conv blocks in one stack, . + num_stacks (int): The number of conv block stacks, . + msk_activate (str): The activation function of the mask output. + + Note: + This implementation corresponds to the "non-causal" setting in the paper. + """ + + def __init__( + self, + input_dim: int, + num_sources: int, + kernel_size: int, + num_feats: int, + num_hidden: int, + num_layers: int, + num_stacks: int, + msk_activate: str, + ): + super().__init__() + + self.input_dim = input_dim + self.num_sources = num_sources + + self.input_norm = torch.nn.GroupNorm( + num_groups=1, num_channels=input_dim, eps=1e-8 + ) + self.input_conv = torch.nn.Conv1d( + in_channels=input_dim, out_channels=num_feats, kernel_size=1 + ) + + self.receptive_field = 0 + self.conv_layers = torch.nn.ModuleList([]) + for s in range(num_stacks): + for l in range(num_layers): + multi = 2 ** l + self.conv_layers.append( + ConvBlock( + io_channels=num_feats, + hidden_channels=num_hidden, + kernel_size=kernel_size, + dilation=multi, + padding=multi, + # The last ConvBlock does not need residual + no_residual=(l == (num_layers - 1) and s == (num_stacks - 1)), + ) + ) + self.receptive_field += ( + kernel_size if s == 0 and l == 0 else (kernel_size - 1) * multi + ) + self.output_prelu = torch.nn.PReLU() + self.output_conv = torch.nn.Conv1d( + in_channels=num_feats, out_channels=input_dim * num_sources, kernel_size=1, + ) + if msk_activate == "sigmoid": + self.mask_activate = torch.nn.Sigmoid() + elif msk_activate == "relu": + self.mask_activate = torch.nn.ReLU() + else: + raise ValueError(f"Unsupported activation {msk_activate}") + + def forward(self, input: torch.Tensor) -> torch.Tensor: + """Generate separation mask. + + Args: + input (torch.Tensor): 3D Tensor with shape [batch, features, frames] + + Returns: + Tensor: shape [batch, num_sources, features, frames] + """ + batch_size = input.shape[0] + feats = self.input_norm(input) + feats = self.input_conv(feats) + output = 0.0 + for layer in self.conv_layers: + residual, skip = layer(feats) + if residual is not None: # the last conv layer does not produce residual + feats = feats + residual + output = output + skip + output = self.output_prelu(output) + output = self.output_conv(output) + output = self.mask_activate(output) + return output.view(batch_size, self.num_sources, self.input_dim, -1) + + +class ConvTasNet(torch.nn.Module): + """Conv-TasNet: a fully-convolutional time-domain audio separation network + *Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation* + [:footcite:`Luo_2019`]. + + Args: + num_sources (int, optional): The number of sources to split. + enc_kernel_size (int, optional): The convolution kernel size of the encoder/decoder, . + enc_num_feats (int, optional): The feature dimensions passed to mask generator, . + msk_kernel_size (int, optional): The convolution kernel size of the mask generator,
. + msk_num_feats (int, optional): The input/output feature dimension of conv block in the mask generator, . + msk_num_hidden_feats (int, optional): The internal feature dimension of conv block of the mask generator, . + msk_num_layers (int, optional): The number of layers in one conv block of the mask generator, . + msk_num_stacks (int, optional): The numbr of conv blocks of the mask generator, . + msk_activate (str, optional): The activation function of the mask output (Default: ``sigmoid``). + + Note: + This implementation corresponds to the "non-causal" setting in the paper. + """ + + def __init__( + self, + num_sources: int = 2, + # encoder/decoder parameters + enc_kernel_size: int = 16, + enc_num_feats: int = 512, + # mask generator parameters + msk_kernel_size: int = 3, + msk_num_feats: int = 128, + msk_num_hidden_feats: int = 512, + msk_num_layers: int = 8, + msk_num_stacks: int = 3, + msk_activate: str = "sigmoid", + ): + super().__init__() + + self.num_sources = num_sources + self.enc_num_feats = enc_num_feats + self.enc_kernel_size = enc_kernel_size + self.enc_stride = enc_kernel_size // 2 + + self.encoder = torch.nn.Conv1d( + in_channels=1, + out_channels=enc_num_feats, + kernel_size=enc_kernel_size, + stride=self.enc_stride, + padding=self.enc_stride, + bias=False, + ) + self.mask_generator = MaskGenerator( + input_dim=enc_num_feats, + num_sources=num_sources, + kernel_size=msk_kernel_size, + num_feats=msk_num_feats, + num_hidden=msk_num_hidden_feats, + num_layers=msk_num_layers, + num_stacks=msk_num_stacks, + msk_activate=msk_activate, + ) + self.decoder = torch.nn.ConvTranspose1d( + in_channels=enc_num_feats, + out_channels=1, + kernel_size=enc_kernel_size, + stride=self.enc_stride, + padding=self.enc_stride, + bias=False, + ) + + def _align_num_frames_with_strides( + self, input: torch.Tensor + ) -> Tuple[torch.Tensor, int]: + """Pad input Tensor so that the end of the input tensor corresponds with + + 1. (if kernel size is odd) the center of the last convolution kernel + or 2. (if kernel size is even) the end of the first half of the last convolution kernel + + Assumption: + The resulting Tensor will be padded with the size of stride (== kernel_width // 2) + on the both ends in Conv1D + + |<--- k_1 --->| + | | |<-- k_n-1 -->| + | | | |<--- k_n --->| + | | | | | + | | | | | + | v v v | + |<---->|<--- input signal --->|<--->|<---->| + stride PAD stride + + Args: + input (torch.Tensor): 3D Tensor with shape (batch_size, channels==1, frames) + + Returns: + Tensor: Padded Tensor + int: Number of paddings performed + """ + batch_size, num_channels, num_frames = input.shape + is_odd = self.enc_kernel_size % 2 + num_strides = (num_frames - is_odd) // self.enc_stride + num_remainings = num_frames - (is_odd + num_strides * self.enc_stride) + if num_remainings == 0: + return input, 0 + + num_paddings = self.enc_stride - num_remainings + pad = torch.zeros( + batch_size, + num_channels, + num_paddings, + dtype=input.dtype, + device=input.device, + ) + return torch.cat([input, pad], 2), num_paddings + + def forward(self, input: torch.Tensor) -> torch.Tensor: + """Perform source separation. Generate audio source waveforms. + + Args: + input (torch.Tensor): 3D Tensor with shape [batch, channel==1, frames] + + Returns: + Tensor: 3D Tensor with shape [batch, channel==num_sources, frames] + """ + if input.ndim != 3 or input.shape[1] != 1: + raise ValueError( + f"Expected 3D tensor (batch, channel==1, frames). 
Found: {input.shape}" + ) + + # B: batch size + # L: input frame length + # L': padded input frame length + # F: feature dimension + # M: feature frame length + # S: number of sources + + padded, num_pads = self._align_num_frames_with_strides(input) # B, 1, L' + batch_size, num_padded_frames = padded.shape[0], padded.shape[2] + feats = self.encoder(padded) # B, F, M + masked = self.mask_generator(feats) * feats.unsqueeze(1) # B, S, F, M + masked = masked.view( + batch_size * self.num_sources, self.enc_num_feats, -1 + ) # B*S, F, M + decoded = self.decoder(masked) # B*S, 1, L' + output = decoded.view( + batch_size, self.num_sources, num_padded_frames + ) # B, S, L' + if num_pads > 0: + output = output[..., :-num_pads] # B, S, L + return output diff --git a/torchaudio/models/deepspeech.py b/torchaudio/models/deepspeech.py new file mode 100644 index 0000000000000000000000000000000000000000..41efc07d9e234be220ffa21b4e885895861f9cee --- /dev/null +++ b/torchaudio/models/deepspeech.py @@ -0,0 +1,91 @@ +import torch + +__all__ = ["DeepSpeech"] + + +class FullyConnected(torch.nn.Module): + """ + Args: + n_feature: Number of input features + n_hidden: Internal hidden unit size. + """ + + def __init__(self, + n_feature: int, + n_hidden: int, + dropout: float, + relu_max_clip: int = 20) -> None: + super(FullyConnected, self).__init__() + self.fc = torch.nn.Linear(n_feature, n_hidden, bias=True) + self.relu_max_clip = relu_max_clip + self.dropout = dropout + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc(x) + x = torch.nn.functional.relu(x) + x = torch.nn.functional.hardtanh(x, 0, self.relu_max_clip) + if self.dropout: + x = torch.nn.functional.dropout(x, self.dropout, self.training) + return x + + +class DeepSpeech(torch.nn.Module): + """ + DeepSpeech model architecture from *Deep Speech: Scaling up end-to-end speech recognition* + [:footcite:`hannun2014deep`]. + + Args: + n_feature: Number of input features + n_hidden: Internal hidden unit size. + n_class: Number of output classes + """ + + def __init__( + self, + n_feature: int, + n_hidden: int = 2048, + n_class: int = 40, + dropout: float = 0.0, + ) -> None: + super(DeepSpeech, self).__init__() + self.n_hidden = n_hidden + self.fc1 = FullyConnected(n_feature, n_hidden, dropout) + self.fc2 = FullyConnected(n_hidden, n_hidden, dropout) + self.fc3 = FullyConnected(n_hidden, n_hidden, dropout) + self.bi_rnn = torch.nn.RNN( + n_hidden, n_hidden, num_layers=1, nonlinearity="relu", bidirectional=True + ) + self.fc4 = FullyConnected(n_hidden, n_hidden, dropout) + self.out = torch.nn.Linear(n_hidden, n_class) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Tensor of dimension (batch, channel, time, feature). + Returns: + Tensor: Predictor tensor of dimension (batch, time, class). 
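+
+        Example (illustrative sizes only)
+            >>> model = DeepSpeech(n_feature=40, n_class=29)
+            >>> x = torch.randn(4, 1, 100, 40)  # (batch, channel, time, feature)
+            >>> model(x).shape
+            torch.Size([4, 100, 29])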
+ """ + # N x C x T x F + x = self.fc1(x) + # N x C x T x H + x = self.fc2(x) + # N x C x T x H + x = self.fc3(x) + # N x C x T x H + x = x.squeeze(1) + # N x T x H + x = x.transpose(0, 1) + # T x N x H + x, _ = self.bi_rnn(x) + # The fifth (non-recurrent) layer takes both the forward and backward units as inputs + x = x[:, :, :self.n_hidden] + x[:, :, self.n_hidden:] + # T x N x H + x = self.fc4(x) + # T x N x H + x = self.out(x) + # T x N x n_class + x = x.permute(1, 0, 2) + # N x T x n_class + x = torch.nn.functional.log_softmax(x, dim=2) + # N x T x n_class + return x diff --git a/torchaudio/models/tacotron2.py b/torchaudio/models/tacotron2.py new file mode 100644 index 0000000000000000000000000000000000000000..109d396a7f15ee58cb650ce9e582e4eec3ebf8e2 --- /dev/null +++ b/torchaudio/models/tacotron2.py @@ -0,0 +1,1109 @@ +# ***************************************************************************** +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the +# names of its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# ***************************************************************************** + +import warnings +from math import sqrt +from typing import Tuple, List, Optional, Union + +import torch +from torch import nn +from torch import Tensor +from torch.nn import functional as F + + +__all__ = [ + "Tacotron2", +] + + +def _get_linear_layer( + in_dim: int, out_dim: int, bias: bool = True, w_init_gain: str = "linear" +) -> torch.nn.Linear: + r"""Linear layer with xavier uniform initialization. + + Args: + in_dim (int): Size of each input sample. + out_dim (int): Size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias. (Default: ``True``) + w_init_gain (str, optional): Parameter passed to ``torch.nn.init.calculate_gain`` + for setting the gain parameter of ``xavier_uniform_``. (Default: ``linear``) + + Returns: + (torch.nn.Linear): The corresponding linear layer. 
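+
+    Example (illustrative sizes only)
+        >>> layer = _get_linear_layer(256, 128, w_init_gain="tanh")
+        >>> layer.weight.shape
+        torch.Size([128, 256])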
+ """ + linear = torch.nn.Linear(in_dim, out_dim, bias=bias) + torch.nn.init.xavier_uniform_( + linear.weight, gain=torch.nn.init.calculate_gain(w_init_gain) + ) + return linear + + +def _get_conv1d_layer( + in_channels: int, + out_channels: int, + kernel_size: int = 1, + stride: int = 1, + padding: Optional[Union[str, int, Tuple[int]]] = None, + dilation: int = 1, + bias: bool = True, + w_init_gain: str = "linear", +) -> torch.nn.Conv1d: + r"""1D convolution with xavier uniform initialization. + + Args: + in_channels (int): Number of channels in the input image. + out_channels (int): Number of channels produced by the convolution. + kernel_size (int, optional): Number of channels in the input image. (Default: ``1``) + stride (int, optional): Number of channels in the input image. (Default: ``1``) + padding (str, int or tuple, optional): Padding added to both sides of the input. + (Default: dilation * (kernel_size - 1) / 2) + dilation (int, optional): Number of channels in the input image. (Default: ``1``) + w_init_gain (str, optional): Parameter passed to ``torch.nn.init.calculate_gain`` + for setting the gain parameter of ``xavier_uniform_``. (Default: ``linear``) + + Returns: + (torch.nn.Conv1d): The corresponding Conv1D layer. + """ + if padding is None: + assert kernel_size % 2 == 1 + padding = int(dilation * (kernel_size - 1) / 2) + + conv1d = torch.nn.Conv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + + torch.nn.init.xavier_uniform_( + conv1d.weight, gain=torch.nn.init.calculate_gain(w_init_gain) + ) + + return conv1d + + +def _get_mask_from_lengths(lengths: Tensor) -> Tensor: + r"""Returns a binary mask based on ``lengths``. The ``i``-th row and ``j``-th column of the mask + is ``1`` if ``j`` is smaller than ``i``-th element of ``lengths. + + Args: + lengths (Tensor): The length of each element in the batch, with shape (n_batch, ). + + Returns: + mask (Tensor): The binary mask, with shape (n_batch, max of ``lengths``). + """ + max_len = torch.max(lengths).item() + ids = torch.arange(0, max_len, device=lengths.device, dtype=lengths.dtype) + mask = (ids < lengths.unsqueeze(1)).byte() + mask = torch.le(mask, 0) + return mask + + +class _LocationLayer(nn.Module): + r"""Location layer used in the Attention model. + + Args: + attention_n_filter (int): Number of filters for attention model. + attention_kernel_size (int): Kernel size for attention model. + attention_hidden_dim (int): Dimension of attention hidden representation. + """ + + def __init__( + self, + attention_n_filter: int, + attention_kernel_size: int, + attention_hidden_dim: int, + ): + super().__init__() + padding = int((attention_kernel_size - 1) / 2) + self.location_conv = _get_conv1d_layer( + 2, + attention_n_filter, + kernel_size=attention_kernel_size, + padding=padding, + bias=False, + stride=1, + dilation=1, + ) + self.location_dense = _get_linear_layer( + attention_n_filter, attention_hidden_dim, bias=False, w_init_gain="tanh" + ) + + def forward(self, attention_weights_cat: Tensor) -> Tensor: + r"""Location layer used in the Attention model. + + Args: + attention_weights_cat (Tensor): Cumulative and previous attention weights + with shape (n_batch, 2, max of ``text_lengths``). + + Returns: + processed_attention (Tensor): Cumulative and previous attention weights + with shape (n_batch, ``attention_hidden_dim``). 
+ """ + # (n_batch, attention_n_filter, text_lengths.max()) + processed_attention = self.location_conv(attention_weights_cat) + processed_attention = processed_attention.transpose(1, 2) + # (n_batch, text_lengths.max(), attention_hidden_dim) + processed_attention = self.location_dense(processed_attention) + return processed_attention + + +class _Attention(nn.Module): + r"""Locally sensitive attention model. + + Args: + attention_rnn_dim (int): Number of hidden units for RNN. + encoder_embedding_dim (int): Number of embedding dimensions in the Encoder. + attention_hidden_dim (int): Dimension of attention hidden representation. + attention_location_n_filter (int): Number of filters for Attention model. + attention_location_kernel_size (int): Kernel size for Attention model. + """ + + def __init__( + self, + attention_rnn_dim: int, + encoder_embedding_dim: int, + attention_hidden_dim: int, + attention_location_n_filter: int, + attention_location_kernel_size: int, + ) -> None: + super().__init__() + self.query_layer = _get_linear_layer( + attention_rnn_dim, attention_hidden_dim, bias=False, w_init_gain="tanh" + ) + self.memory_layer = _get_linear_layer( + encoder_embedding_dim, attention_hidden_dim, bias=False, w_init_gain="tanh" + ) + self.v = _get_linear_layer(attention_hidden_dim, 1, bias=False) + self.location_layer = _LocationLayer( + attention_location_n_filter, + attention_location_kernel_size, + attention_hidden_dim, + ) + self.score_mask_value = -float("inf") + + def _get_alignment_energies( + self, query: Tensor, processed_memory: Tensor, attention_weights_cat: Tensor + ) -> Tensor: + r"""Get the alignment vector. + + Args: + query (Tensor): Decoder output with shape (n_batch, n_mels * n_frames_per_step). + processed_memory (Tensor): Processed Encoder outputs + with shape (n_batch, max of ``text_lengths``, attention_hidden_dim). + attention_weights_cat (Tensor): Cumulative and previous attention weights + with shape (n_batch, 2, max of ``text_lengths``). + + Returns: + alignment (Tensor): attention weights, it is a tensor with shape (batch, max of ``text_lengths``). + """ + + processed_query = self.query_layer(query.unsqueeze(1)) + processed_attention_weights = self.location_layer(attention_weights_cat) + energies = self.v( + torch.tanh(processed_query + processed_attention_weights + processed_memory) + ) + + alignment = energies.squeeze(2) + return alignment + + def forward( + self, + attention_hidden_state: Tensor, + memory: Tensor, + processed_memory: Tensor, + attention_weights_cat: Tensor, + mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + r"""Pass the input through the Attention model. + + Args: + attention_hidden_state (Tensor): Attention rnn last output with shape (n_batch, ``attention_rnn_dim``). + memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``). + processed_memory (Tensor): Processed Encoder outputs + with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``). + attention_weights_cat (Tensor): Previous and cumulative attention weights + with shape (n_batch, current_num_frames * 2, max of ``text_lengths``). + mask (Tensor): Binary mask for padded data with shape (n_batch, current_num_frames). + + Returns: + attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``). + attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``). 
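+
+        Example (illustrative sizes only; a batch of 2 with 50 encoder frames)
+            >>> attention = _Attention(1024, 512, 128, 32, 31)
+            >>> context, weights = attention(
+            ...     torch.randn(2, 1024),                   # attention hidden state
+            ...     torch.randn(2, 50, 512),                # memory
+            ...     torch.randn(2, 50, 128),                # processed memory
+            ...     torch.randn(2, 2, 50),                  # previous/cumulative weights
+            ...     torch.zeros(2, 50, dtype=torch.bool),   # mask (nothing padded)
+            ... )
+            >>> context.shape, weights.shape
+            (torch.Size([2, 512]), torch.Size([2, 50]))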
+ """ + alignment = self._get_alignment_energies( + attention_hidden_state, processed_memory, attention_weights_cat + ) + + alignment = alignment.masked_fill(mask, self.score_mask_value) + + attention_weights = F.softmax(alignment, dim=1) + attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) + attention_context = attention_context.squeeze(1) + + return attention_context, attention_weights + + +class _Prenet(nn.Module): + r"""Prenet Module. It is consists of ``len(output_size)`` linear layers. + + Args: + in_dim (int): The size of each input sample. + output_sizes (list): The output dimension of each linear layers. + """ + + def __init__(self, in_dim: int, out_sizes: List[int]) -> None: + super().__init__() + in_sizes = [in_dim] + out_sizes[:-1] + self.layers = nn.ModuleList( + [ + _get_linear_layer(in_size, out_size, bias=False) + for (in_size, out_size) in zip(in_sizes, out_sizes) + ] + ) + + def forward(self, x: Tensor) -> Tensor: + r"""Pass the input through Prenet. + + Args: + x (Tensor): The input sequence to Prenet with shape (n_batch, in_dim). + + Return: + x (Tensor): Tensor with shape (n_batch, sizes[-1]) + """ + + for linear in self.layers: + x = F.dropout(F.relu(linear(x)), p=0.5, training=True) + return x + + +class _Postnet(nn.Module): + r"""Postnet Module. + + Args: + n_mels (int): Number of mel bins. + postnet_embedding_dim (int): Postnet embedding dimension. + postnet_kernel_size (int): Postnet kernel size. + postnet_n_convolution (int): Number of postnet convolutions. + """ + + def __init__( + self, + n_mels: int, + postnet_embedding_dim: int, + postnet_kernel_size: int, + postnet_n_convolution: int, + ): + super().__init__() + self.convolutions = nn.ModuleList() + + for i in range(postnet_n_convolution): + in_channels = n_mels if i == 0 else postnet_embedding_dim + out_channels = n_mels if i == (postnet_n_convolution - 1) else postnet_embedding_dim + init_gain = "linear" if i == (postnet_n_convolution - 1) else "tanh" + num_features = n_mels if i == (postnet_n_convolution - 1) else postnet_embedding_dim + self.convolutions.append( + nn.Sequential( + _get_conv1d_layer( + in_channels, + out_channels, + kernel_size=postnet_kernel_size, + stride=1, + padding=int((postnet_kernel_size - 1) / 2), + dilation=1, + w_init_gain=init_gain, + ), + nn.BatchNorm1d(num_features), + ) + ) + + self.n_convs = len(self.convolutions) + + def forward(self, x: Tensor) -> Tensor: + r"""Pass the input through Postnet. + + Args: + x (Tensor): The input sequence with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``). + + Return: + x (Tensor): Tensor with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``). + """ + + for i, conv in enumerate(self.convolutions): + if i < self.n_convs - 1: + x = F.dropout(torch.tanh(conv(x)), 0.5, training=self.training) + else: + x = F.dropout(conv(x), 0.5, training=self.training) + + return x + + +class _Encoder(nn.Module): + r"""Encoder Module. + + Args: + encoder_embedding_dim (int): Number of embedding dimensions in the encoder. + encoder_n_convolution (int): Number of convolution layers in the encoder. + encoder_kernel_size (int): The kernel size in the encoder. 
+ + Examples + >>> encoder = _Encoder(3, 512, 5) + >>> input = torch.rand(10, 20, 30) + >>> output = encoder(input) # shape: (10, 30, 512) + """ + + def __init__( + self, + encoder_embedding_dim: int, + encoder_n_convolution: int, + encoder_kernel_size: int, + ) -> None: + super().__init__() + + self.convolutions = nn.ModuleList() + for _ in range(encoder_n_convolution): + conv_layer = nn.Sequential( + _get_conv1d_layer( + encoder_embedding_dim, + encoder_embedding_dim, + kernel_size=encoder_kernel_size, + stride=1, + padding=int((encoder_kernel_size - 1) / 2), + dilation=1, + w_init_gain="relu", + ), + nn.BatchNorm1d(encoder_embedding_dim), + ) + self.convolutions.append(conv_layer) + + self.lstm = nn.LSTM( + encoder_embedding_dim, + int(encoder_embedding_dim / 2), + 1, + batch_first=True, + bidirectional=True, + ) + self.lstm.flatten_parameters() + + def forward(self, x: Tensor, input_lengths: Tensor) -> Tensor: + r"""Pass the input through the Encoder. + + Args: + x (Tensor): The input sequences with shape (n_batch, encoder_embedding_dim, n_seq). + input_lengths (Tensor): The length of each input sequence with shape (n_batch, ). + + Return: + x (Tensor): A tensor with shape (n_batch, n_seq, encoder_embedding_dim). + """ + + for conv in self.convolutions: + x = F.dropout(F.relu(conv(x)), 0.5, self.training) + + x = x.transpose(1, 2) + + input_lengths = input_lengths.cpu() + x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True) + + outputs, _ = self.lstm(x) + outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True) + + return outputs + + +class _Decoder(nn.Module): + r"""Decoder with Attention model. + + Args: + n_mels (int): number of mel bins + n_frames_per_step (int): number of frames processed per step, only 1 is supported + encoder_embedding_dim (int): the number of embedding dimensions in the encoder. 
+ decoder_rnn_dim (int): number of units in decoder LSTM + decoder_max_step (int): maximum number of output mel spectrograms + decoder_dropout (float): dropout probability for decoder LSTM + decoder_early_stopping (bool): stop decoding when all samples are finished + attention_rnn_dim (int): number of units in attention LSTM + attention_hidden_dim (int): dimension of attention hidden representation + attention_location_n_filter (int): number of filters for attention model + attention_location_kernel_size (int): kernel size for attention model + attention_dropout (float): dropout probability for attention LSTM + prenet_dim (int): number of ReLU units in prenet layers + gate_threshold (float): probability threshold for stop token + """ + + def __init__( + self, + n_mels: int, + n_frames_per_step: int, + encoder_embedding_dim: int, + decoder_rnn_dim: int, + decoder_max_step: int, + decoder_dropout: float, + decoder_early_stopping: bool, + attention_rnn_dim: int, + attention_hidden_dim: int, + attention_location_n_filter: int, + attention_location_kernel_size: int, + attention_dropout: float, + prenet_dim: int, + gate_threshold: float, + ) -> None: + + super().__init__() + self.n_mels = n_mels + self.n_frames_per_step = n_frames_per_step + self.encoder_embedding_dim = encoder_embedding_dim + self.attention_rnn_dim = attention_rnn_dim + self.decoder_rnn_dim = decoder_rnn_dim + self.prenet_dim = prenet_dim + self.decoder_max_step = decoder_max_step + self.gate_threshold = gate_threshold + self.attention_dropout = attention_dropout + self.decoder_dropout = decoder_dropout + self.decoder_early_stopping = decoder_early_stopping + + self.prenet = _Prenet(n_mels * n_frames_per_step, [prenet_dim, prenet_dim]) + + self.attention_rnn = nn.LSTMCell( + prenet_dim + encoder_embedding_dim, attention_rnn_dim + ) + + self.attention_layer = _Attention( + attention_rnn_dim, + encoder_embedding_dim, + attention_hidden_dim, + attention_location_n_filter, + attention_location_kernel_size, + ) + + self.decoder_rnn = nn.LSTMCell( + attention_rnn_dim + encoder_embedding_dim, decoder_rnn_dim, True + ) + + self.linear_projection = _get_linear_layer( + decoder_rnn_dim + encoder_embedding_dim, n_mels * n_frames_per_step + ) + + self.gate_layer = _get_linear_layer( + decoder_rnn_dim + encoder_embedding_dim, 1, bias=True, w_init_gain="sigmoid" + ) + + def _get_initial_frame(self, memory: Tensor) -> Tensor: + r"""Gets all zeros frames to use as the first decoder input. + + Args: + memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``). + + Returns: + decoder_input (Tensor): all zeros frames with shape + (n_batch, max of ``text_lengths``, ``n_mels * n_frames_per_step``). + """ + + n_batch = memory.size(0) + dtype = memory.dtype + device = memory.device + decoder_input = torch.zeros( + n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device + ) + return decoder_input + + def _initialize_decoder_states( + self, memory: Tensor + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + r"""Initializes attention rnn states, decoder rnn states, attention + weights, attention cumulative weights, attention context, stores memory + and stores processed memory. + + Args: + memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``). + + Returns: + attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``). 
+ attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``). + decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``). + decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``). + attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``). + attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``). + attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``). + processed_memory (Tensor): Processed encoder outputs + with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``). + """ + n_batch = memory.size(0) + max_time = memory.size(1) + dtype = memory.dtype + device = memory.device + + attention_hidden = torch.zeros( + n_batch, self.attention_rnn_dim, dtype=dtype, device=device + ) + attention_cell = torch.zeros( + n_batch, self.attention_rnn_dim, dtype=dtype, device=device + ) + + decoder_hidden = torch.zeros( + n_batch, self.decoder_rnn_dim, dtype=dtype, device=device + ) + decoder_cell = torch.zeros( + n_batch, self.decoder_rnn_dim, dtype=dtype, device=device + ) + + attention_weights = torch.zeros(n_batch, max_time, dtype=dtype, device=device) + attention_weights_cum = torch.zeros( + n_batch, max_time, dtype=dtype, device=device + ) + attention_context = torch.zeros( + n_batch, self.encoder_embedding_dim, dtype=dtype, device=device + ) + + processed_memory = self.attention_layer.memory_layer(memory) + + return ( + attention_hidden, + attention_cell, + decoder_hidden, + decoder_cell, + attention_weights, + attention_weights_cum, + attention_context, + processed_memory, + ) + + def _parse_decoder_inputs(self, decoder_inputs: Tensor) -> Tensor: + r"""Prepares decoder inputs. + + Args: + decoder_inputs (Tensor): Inputs used for teacher-forced training, i.e. mel-specs, + with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``) + + Returns: + inputs (Tensor): Processed decoder inputs with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``). 
+ """ + # (n_batch, n_mels, mel_specgram_lengths.max()) -> (n_batch, mel_specgram_lengths.max(), n_mels) + decoder_inputs = decoder_inputs.transpose(1, 2) + decoder_inputs = decoder_inputs.view( + decoder_inputs.size(0), + int(decoder_inputs.size(1) / self.n_frames_per_step), + -1, + ) + # (n_batch, mel_specgram_lengths.max(), n_mels) -> (mel_specgram_lengths.max(), n_batch, n_mels) + decoder_inputs = decoder_inputs.transpose(0, 1) + return decoder_inputs + + def _parse_decoder_outputs( + self, mel_specgram: Tensor, gate_outputs: Tensor, alignments: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: + r"""Prepares decoder outputs for output + + Args: + mel_specgram (Tensor): mel spectrogram with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``) + gate_outputs (Tensor): predicted stop token with shape (max of ``mel_specgram_lengths``, n_batch) + alignments (Tensor): sequence of attention weights from the decoder + with shape (max of ``mel_specgram_lengths``, n_batch, max of ``text_lengths``) + + Returns: + mel_specgram (Tensor): mel spectrogram with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``) + gate_outputs (Tensor): predicted stop token with shape (n_batch, max of ``mel_specgram_lengths``) + alignments (Tensor): sequence of attention weights from the decoder + with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``) + """ + # (mel_specgram_lengths.max(), n_batch, text_lengths.max()) + # -> (n_batch, mel_specgram_lengths.max(), text_lengths.max()) + alignments = alignments.transpose(0, 1).contiguous() + # (mel_specgram_lengths.max(), n_batch) -> (n_batch, mel_specgram_lengths.max()) + gate_outputs = gate_outputs.transpose(0, 1).contiguous() + # (mel_specgram_lengths.max(), n_batch, n_mels) -> (n_batch, mel_specgram_lengths.max(), n_mels) + mel_specgram = mel_specgram.transpose(0, 1).contiguous() + # decouple frames per step + shape = (mel_specgram.shape[0], -1, self.n_mels) + mel_specgram = mel_specgram.view(*shape) + # (n_batch, mel_specgram_lengths.max(), n_mels) -> (n_batch, n_mels, T_out) + mel_specgram = mel_specgram.transpose(1, 2) + + return mel_specgram, gate_outputs, alignments + + def decode( + self, + decoder_input: Tensor, + attention_hidden: Tensor, + attention_cell: Tensor, + decoder_hidden: Tensor, + decoder_cell: Tensor, + attention_weights: Tensor, + attention_weights_cum: Tensor, + attention_context: Tensor, + memory: Tensor, + processed_memory: Tensor, + mask: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + r"""Decoder step using stored states, attention and memory + + Args: + decoder_input (Tensor): Output of the Prenet with shape (n_batch, ``prenet_dim``). + attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``). + attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``). + decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``). + decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``). + attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``). + attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``). + attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``). + memory (Tensor): Encoder output with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``). 
+ processed_memory (Tensor): Processed Encoder outputs + with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``). + mask (Tensor): Binary mask for padded data with shape (n_batch, current_num_frames). + + Returns: + decoder_output: Predicted mel spectrogram for the current frame with shape (n_batch, ``n_mels``). + gate_prediction (Tensor): Prediction of the stop token with shape (n_batch, ``1``). + attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``). + attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``). + decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``). + decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``). + attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``). + attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``). + attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``). + """ + cell_input = torch.cat((decoder_input, attention_context), -1) + + attention_hidden, attention_cell = self.attention_rnn( + cell_input, (attention_hidden, attention_cell) + ) + attention_hidden = F.dropout( + attention_hidden, self.attention_dropout, self.training + ) + + attention_weights_cat = torch.cat( + (attention_weights.unsqueeze(1), attention_weights_cum.unsqueeze(1)), dim=1 + ) + attention_context, attention_weights = self.attention_layer( + attention_hidden, memory, processed_memory, attention_weights_cat, mask + ) + + attention_weights_cum += attention_weights + decoder_input = torch.cat((attention_hidden, attention_context), -1) + + decoder_hidden, decoder_cell = self.decoder_rnn( + decoder_input, (decoder_hidden, decoder_cell) + ) + decoder_hidden = F.dropout(decoder_hidden, self.decoder_dropout, self.training) + + decoder_hidden_attention_context = torch.cat( + (decoder_hidden, attention_context), dim=1 + ) + decoder_output = self.linear_projection(decoder_hidden_attention_context) + + gate_prediction = self.gate_layer(decoder_hidden_attention_context) + + return ( + decoder_output, + gate_prediction, + attention_hidden, + attention_cell, + decoder_hidden, + decoder_cell, + attention_weights, + attention_weights_cum, + attention_context, + ) + + def forward( + self, memory: Tensor, mel_specgram_truth: Tensor, memory_lengths: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: + r"""Decoder forward pass for training. + + Args: + memory (Tensor): Encoder outputs + with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``). + mel_specgram_truth (Tensor): Decoder ground-truth mel-specs for teacher forcing + with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``). + memory_lengths (Tensor): Encoder output lengths for attention masking + (the same as ``text_lengths``) with shape (n_batch, ). + + Returns: + mel_specgram (Tensor): Predicted mel spectrogram + with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``). + gate_outputs (Tensor): Predicted stop token for each timestep + with shape (n_batch, max of ``mel_specgram_lengths``). + alignments (Tensor): Sequence of attention weights from the decoder + with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``). 
+ """ + + decoder_input = self._get_initial_frame(memory).unsqueeze(0) + decoder_inputs = self._parse_decoder_inputs(mel_specgram_truth) + decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0) + decoder_inputs = self.prenet(decoder_inputs) + + mask = _get_mask_from_lengths(memory_lengths) + ( + attention_hidden, + attention_cell, + decoder_hidden, + decoder_cell, + attention_weights, + attention_weights_cum, + attention_context, + processed_memory, + ) = self._initialize_decoder_states(memory) + + mel_outputs, gate_outputs, alignments = [], [], [] + while len(mel_outputs) < decoder_inputs.size(0) - 1: + decoder_input = decoder_inputs[len(mel_outputs)] + ( + mel_output, + gate_output, + attention_hidden, + attention_cell, + decoder_hidden, + decoder_cell, + attention_weights, + attention_weights_cum, + attention_context, + ) = self.decode( + decoder_input, + attention_hidden, + attention_cell, + decoder_hidden, + decoder_cell, + attention_weights, + attention_weights_cum, + attention_context, + memory, + processed_memory, + mask, + ) + + mel_outputs += [mel_output.squeeze(1)] + gate_outputs += [gate_output.squeeze()] + alignments += [attention_weights] + + mel_specgram, gate_outputs, alignments = self._parse_decoder_outputs( + torch.stack(mel_outputs), torch.stack(gate_outputs), torch.stack(alignments) + ) + + return mel_specgram, gate_outputs, alignments + + def _get_go_frame(self, memory: Tensor) -> Tensor: + """Gets all zeros frames to use as the first decoder input + + args: + memory (Tensor): Encoder outputs + with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``). + + returns: + decoder_input (Tensor): All zeros frames with shape(n_batch, ``n_mels`` * ``n_frame_per_step``). + """ + + n_batch = memory.size(0) + dtype = memory.dtype + device = memory.device + decoder_input = torch.zeros( + n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device + ) + return decoder_input + + @torch.jit.export + def infer(self, + memory: Tensor, + memory_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Decoder inference + + Args: + memory (Tensor): Encoder outputs + with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``). + memory_lengths (Tensor): Encoder output lengths for attention masking + (the same as ``text_lengths``) with shape (n_batch, ). + + Returns: + mel_specgram (Tensor): Predicted mel spectrogram + with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``). + mel_specgram_lengths (Tensor): the length of the predicted mel spectrogram (n_batch, )) + gate_outputs (Tensor): Predicted stop token for each timestep + with shape (n_batch, max of ``mel_specgram_lengths``). + alignments (Tensor): Sequence of attention weights from the decoder + with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``). 
+        """
+        batch_size, device = memory.size(0), memory.device
+
+        decoder_input = self._get_go_frame(memory)
+
+        mask = _get_mask_from_lengths(memory_lengths)
+        (
+            attention_hidden,
+            attention_cell,
+            decoder_hidden,
+            decoder_cell,
+            attention_weights,
+            attention_weights_cum,
+            attention_context,
+            processed_memory,
+        ) = self._initialize_decoder_states(memory)
+
+        mel_specgram_lengths = torch.zeros([batch_size], dtype=torch.int32, device=device)
+        finished = torch.zeros([batch_size], dtype=torch.bool, device=device)
+        mel_specgrams: List[Tensor] = []
+        gate_outputs: List[Tensor] = []
+        alignments: List[Tensor] = []
+        for _ in range(self.decoder_max_step):
+            decoder_input = self.prenet(decoder_input)
+            (
+                mel_specgram,
+                gate_output,
+                attention_hidden,
+                attention_cell,
+                decoder_hidden,
+                decoder_cell,
+                attention_weights,
+                attention_weights_cum,
+                attention_context,
+            ) = self.decode(
+                decoder_input,
+                attention_hidden,
+                attention_cell,
+                decoder_hidden,
+                decoder_cell,
+                attention_weights,
+                attention_weights_cum,
+                attention_context,
+                memory,
+                processed_memory,
+                mask,
+            )
+
+            mel_specgrams.append(mel_specgram.unsqueeze(0))
+            gate_outputs.append(gate_output.transpose(0, 1))
+            alignments.append(attention_weights)
+            mel_specgram_lengths[~finished] += 1
+
+            finished |= torch.sigmoid(gate_output.squeeze(1)) > self.gate_threshold
+            if self.decoder_early_stopping and torch.all(finished):
+                break
+
+            decoder_input = mel_specgram
+
+        if len(mel_specgrams) == self.decoder_max_step:
+            warnings.warn(
+                "Reached max decoder steps. The generated spectrogram might not cover "
+                "the whole transcript.")
+
+        mel_specgrams = torch.cat(mel_specgrams, dim=0)
+        gate_outputs = torch.cat(gate_outputs, dim=0)
+        alignments = torch.cat(alignments, dim=0)
+
+        mel_specgrams, gate_outputs, alignments = self._parse_decoder_outputs(
+            mel_specgrams, gate_outputs, alignments
+        )
+
+        return mel_specgrams, mel_specgram_lengths, gate_outputs, alignments
+
+
+class Tacotron2(nn.Module):
+    r"""Tacotron2 model based on the implementation from
+    `Nvidia <https://github.com/NVIDIA/DeepLearningExamples/>`_.
+
+    The original implementation was introduced in
+    *Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions*
+    [:footcite:`shen2018natural`].
+
+    Args:
+        mask_padding (bool, optional): Use mask padding (Default: ``False``).
+        n_mels (int, optional): Number of mel bins (Default: ``80``).
+        n_symbol (int, optional): Number of symbols for the input text (Default: ``148``).
+        n_frames_per_step (int, optional): Number of frames processed per step, only 1 is supported (Default: ``1``).
+        symbol_embedding_dim (int, optional): Input embedding dimension (Default: ``512``).
+        encoder_n_convolution (int, optional): Number of encoder convolutions (Default: ``3``).
+        encoder_kernel_size (int, optional): Encoder kernel size (Default: ``5``).
+        encoder_embedding_dim (int, optional): Encoder embedding dimension (Default: ``512``).
+        decoder_rnn_dim (int, optional): Number of units in decoder LSTM (Default: ``1024``).
+        decoder_max_step (int, optional): Maximum number of output mel spectrograms (Default: ``2000``).
+        decoder_dropout (float, optional): Dropout probability for decoder LSTM (Default: ``0.1``).
+        decoder_early_stopping (bool, optional): Stop decoding once all samples are finished (Default: ``True``).
+        attention_rnn_dim (int, optional): Number of units in attention LSTM (Default: ``1024``).
+        attention_hidden_dim (int, optional): Dimension of attention hidden representation (Default: ``128``).
+ attention_location_n_filter (int, optional): Number of filters for attention model (Default: ``32``). + attention_location_kernel_size (int, optional): Kernel size for attention model (Default: ``31``). + attention_dropout (float, optional): Dropout probability for attention LSTM (Default: ``0.1``). + prenet_dim (int, optional): Number of ReLU units in prenet layers (Default: ``256``). + postnet_n_convolution (int, optional): Number of postnet convolutions (Default: ``5``). + postnet_kernel_size (int, optional): Postnet kernel size (Default: ``5``). + postnet_embedding_dim (int, optional): Postnet embedding dimension (Default: ``512``). + gate_threshold (float, optional): Probability threshold for stop token (Default: ``0.5``). + """ + + def __init__( + self, + mask_padding: bool = False, + n_mels: int = 80, + n_symbol: int = 148, + n_frames_per_step: int = 1, + symbol_embedding_dim: int = 512, + encoder_embedding_dim: int = 512, + encoder_n_convolution: int = 3, + encoder_kernel_size: int = 5, + decoder_rnn_dim: int = 1024, + decoder_max_step: int = 2000, + decoder_dropout: float = 0.1, + decoder_early_stopping: bool = True, + attention_rnn_dim: int = 1024, + attention_hidden_dim: int = 128, + attention_location_n_filter: int = 32, + attention_location_kernel_size: int = 31, + attention_dropout: float = 0.1, + prenet_dim: int = 256, + postnet_n_convolution: int = 5, + postnet_kernel_size: int = 5, + postnet_embedding_dim: int = 512, + gate_threshold: float = 0.5, + ) -> None: + super().__init__() + + self.mask_padding = mask_padding + self.n_mels = n_mels + self.n_frames_per_step = n_frames_per_step + self.embedding = nn.Embedding(n_symbol, symbol_embedding_dim) + std = sqrt(2.0 / (n_symbol + symbol_embedding_dim)) + val = sqrt(3.0) * std + self.embedding.weight.data.uniform_(-val, val) + self.encoder = _Encoder( + encoder_embedding_dim, encoder_n_convolution, encoder_kernel_size + ) + self.decoder = _Decoder( + n_mels, + n_frames_per_step, + encoder_embedding_dim, + decoder_rnn_dim, + decoder_max_step, + decoder_dropout, + decoder_early_stopping, + attention_rnn_dim, + attention_hidden_dim, + attention_location_n_filter, + attention_location_kernel_size, + attention_dropout, + prenet_dim, + gate_threshold, + ) + self.postnet = _Postnet( + n_mels, postnet_embedding_dim, postnet_kernel_size, postnet_n_convolution + ) + + def forward( + self, + tokens: Tensor, + token_lengths: Tensor, + mel_specgram: Tensor, + mel_specgram_lengths: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + r"""Pass the input through the Tacotron2 model. This is in teacher + forcing mode, which is generally used for training. + + The input ``tokens`` should be padded with zeros to length max of ``token_lengths``. + The input ``mel_specgram`` should be padded with zeros to length max of ``mel_specgram_lengths``. + + Args: + tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of token_lengths)`. + token_lengths (Tensor): The valid length of each sample in ``tokens`` with shape `(n_batch, )`. + mel_specgram (Tensor): The target mel spectrogram + with shape `(n_batch, n_mels, max of mel_specgram_lengths)`. + mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape `(n_batch, )`. + + Returns: + [Tensor, Tensor, Tensor, Tensor]: + Tensor + Mel spectrogram before Postnet with shape `(n_batch, n_mels, max of mel_specgram_lengths)`. + Tensor + Mel spectrogram after Postnet with shape `(n_batch, n_mels, max of mel_specgram_lengths)`. 
+ Tensor + The output for stop token at each time step with shape `(n_batch, max of mel_specgram_lengths)`. + Tensor + Sequence of attention weights from the decoder with + shape `(n_batch, max of mel_specgram_lengths, max of token_lengths)`. + """ + + embedded_inputs = self.embedding(tokens).transpose(1, 2) + + encoder_outputs = self.encoder(embedded_inputs, token_lengths) + mel_specgram, gate_outputs, alignments = self.decoder( + encoder_outputs, mel_specgram, memory_lengths=token_lengths + ) + + mel_specgram_postnet = self.postnet(mel_specgram) + mel_specgram_postnet = mel_specgram + mel_specgram_postnet + + if self.mask_padding: + mask = _get_mask_from_lengths(mel_specgram_lengths) + mask = mask.expand(self.n_mels, mask.size(0), mask.size(1)) + mask = mask.permute(1, 0, 2) + + mel_specgram.masked_fill_(mask, 0.0) + mel_specgram_postnet.masked_fill_(mask, 0.0) + gate_outputs.masked_fill_(mask[:, 0, :], 1e3) + + return mel_specgram, mel_specgram_postnet, gate_outputs, alignments + + @torch.jit.export + def infer(self, tokens: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]: + r"""Using Tacotron2 for inference. The input is a batch of encoded + sentences (``tokens``) and its corresponding lengths (``lengths``). The + output is the generated mel spectrograms, its corresponding lengths, and + the attention weights from the decoder. + + The input `tokens` should be padded with zeros to length max of ``lengths``. + + Args: + tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of lengths)`. + lengths (Tensor or None, optional): + The valid length of each sample in ``tokens`` with shape `(n_batch, )`. + If ``None``, it is assumed that the all the tokens are valid. Default: ``None`` + + Returns: + (Tensor, Tensor, Tensor): + Tensor + The predicted mel spectrogram with shape `(n_batch, n_mels, max of mel_specgram_lengths)`. + Tensor + The length of the predicted mel spectrogram with shape `(n_batch, )`. + Tensor + Sequence of attention weights from the decoder with shape + `(n_batch, max of mel_specgram_lengths, max of lengths)`. + """ + n_batch, max_length = tokens.shape + if lengths is None: + lengths = torch.tensor([max_length]).expand(n_batch).to(tokens.device, tokens.dtype) + + assert lengths is not None # For TorchScript compiler + + embedded_inputs = self.embedding(tokens).transpose(1, 2) + encoder_outputs = self.encoder(embedded_inputs, lengths) + mel_specgram, mel_specgram_lengths, _, alignments = self.decoder.infer( + encoder_outputs, lengths + ) + + mel_outputs_postnet = self.postnet(mel_specgram) + mel_outputs_postnet = mel_specgram + mel_outputs_postnet + + alignments = alignments.unfold(1, n_batch, n_batch).transpose(0, 2) + + return mel_outputs_postnet, mel_specgram_lengths, alignments diff --git a/torchaudio/models/wav2letter.py b/torchaudio/models/wav2letter.py new file mode 100644 index 0000000000000000000000000000000000000000..4d93e74392291324ab670ca7524182850edef589 --- /dev/null +++ b/torchaudio/models/wav2letter.py @@ -0,0 +1,74 @@ +from torch import Tensor +from torch import nn + +__all__ = [ + "Wav2Letter", +] + + +class Wav2Letter(nn.Module): + r"""Wav2Letter model architecture from *Wav2Letter: an End-to-End ConvNet-based Speech + Recognition System* [:footcite:`collobert2016wav2letter`]. + + :math:`\text{padding} = \frac{\text{ceil}(\text{kernel} - \text{stride})}{2}` + + Args: + num_classes (int, optional): Number of classes to be classified. 
(Default: ``40``) + input_type (str, optional): Wav2Letter can use as input: ``waveform``, ``power_spectrum`` + or ``mfcc`` (Default: ``waveform``). + num_features (int, optional): Number of input features that the network will receive (Default: ``1``). + """ + + def __init__(self, num_classes: int = 40, + input_type: str = "waveform", + num_features: int = 1) -> None: + super(Wav2Letter, self).__init__() + + acoustic_num_features = 250 if input_type == "waveform" else num_features + acoustic_model = nn.Sequential( + nn.Conv1d(in_channels=acoustic_num_features, out_channels=250, kernel_size=48, stride=2, padding=23), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=2000, kernel_size=32, stride=1, padding=16), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=2000, out_channels=2000, kernel_size=1, stride=1, padding=0), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=2000, out_channels=num_classes, kernel_size=1, stride=1, padding=0), + nn.ReLU(inplace=True) + ) + + if input_type == "waveform": + waveform_model = nn.Sequential( + nn.Conv1d(in_channels=num_features, out_channels=250, kernel_size=250, stride=160, padding=45), + nn.ReLU(inplace=True) + ) + self.acoustic_model = nn.Sequential(waveform_model, acoustic_model) + + if input_type in ["power_spectrum", "mfcc"]: + self.acoustic_model = acoustic_model + + def forward(self, x: Tensor) -> Tensor: + r""" + Args: + x (torch.Tensor): Tensor of dimension (batch_size, num_features, input_length). + + Returns: + Tensor: Predictor tensor of dimension (batch_size, number_of_classes, input_length). + """ + + x = self.acoustic_model(x) + x = nn.functional.log_softmax(x, dim=1) + return x diff --git a/torchaudio/models/wav2vec2/__init__.py b/torchaudio/models/wav2vec2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0538a4a3f8911ff216c2b86c9e5f5f28b445eb7 --- /dev/null +++ b/torchaudio/models/wav2vec2/__init__.py @@ -0,0 +1,23 @@ +from .model import ( + Wav2Vec2Model, + wav2vec2_model, + wav2vec2_base, + wav2vec2_large, + wav2vec2_large_lv60k, + hubert_base, + hubert_large, + hubert_xlarge, +) +from . 
import utils + +__all__ = [ + 'Wav2Vec2Model', + 'wav2vec2_model', + 'wav2vec2_base', + 'wav2vec2_large', + 'wav2vec2_large_lv60k', + 'hubert_base', + 'hubert_large', + 'hubert_xlarge', + 'utils', +] diff --git a/torchaudio/models/wav2vec2/components.py b/torchaudio/models/wav2vec2/components.py new file mode 100644 index 0000000000000000000000000000000000000000..7093fc9de590e2221e45770f356e8169ddb2e268 --- /dev/null +++ b/torchaudio/models/wav2vec2/components.py @@ -0,0 +1,717 @@ +import logging +from typing import Optional, Tuple, List + +import torch +from torch import Tensor, nn +from torch.nn import Module + +_LG = logging.getLogger(__name__) + + +class LayerNorm(nn.LayerNorm): + """Layer norm with transpose""" + def forward(self, input: Tensor) -> Tensor: + x = input.transpose(-2, -1) + x = nn.functional.layer_norm( + x, self.normalized_shape, self.weight, self.bias, self.eps) + x = x.transpose(-2, -1) + return x + + +class ConvLayerBlock(Module): + """Convolution unit of FeatureExtractor""" + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + bias: bool, + layer_norm: Optional[Module], + ): + super().__init__() + self.kernel_size = kernel_size + self.stride = stride + self.layer_norm = layer_norm + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + bias=bias, + ) + + def forward( + self, + x: Tensor, + length: Optional[Tensor], + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x (Tensor): Shape: ``[batch, in_channels, in_frame]``. + length (Tensor or None, optional): Shape ``[batch, ]``. + Returns: + Tensor: Shape ``[batch, out_channels, out_frames]``. + Optional[Tensor]: Shape ``[batch, ]``. + """ + x = self.conv(x) + if self.layer_norm is not None: + x = self.layer_norm(x) + x = nn.functional.gelu(x) + + if length is not None: + length = torch.div(length - self.kernel_size, self.stride, rounding_mode='floor') + 1 + # When input length is 0, the resulting length can be negative. So fix it here. + length = torch.max(torch.zeros_like(length), length) + return x, length + + +class FeatureExtractor(Module): + """Extract features from audio + + Args: + conv_layers (nn.ModuleList): + convolution layers + """ + def __init__( + self, + conv_layers: nn.ModuleList, + ): + super().__init__() + self.conv_layers = conv_layers + + def forward( + self, + x: Tensor, + length: Optional[Tensor], + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x (Tensor): + Input Tensor representing a batch of audio, + shape: ``[batch, time]``. + length (Tensor or None, optional): + Valid length of each input sample. shape: ``[batch, ]``. + + Returns: + Tensor: + The resulting feature, shape: ``[batch, frame, feature]`` + Optional[Tensor]: + Valid length of each output sample. shape: ``[batch, ]``. + """ + if x.ndim != 2: + raise ValueError( + "Expected the input Tensor to be 2D (batch, time), " + "but received {list(x.shape)}") + + x = x.unsqueeze(1) # (batch, channel==1, frame) + for layer in self.conv_layers: + x, length = layer(x, length) # (batch, feature, frame) + x = x.transpose(1, 2) # (batch, frame, feature) + return x, length + + +class FeatureProjection(Module): + """Layer that connects FeatureExtractor and Encoder + + Projects features to encoder dimension. + + Args: + in_features (int): Input feature dim. + out_features (int): Output feature dim. + dropout (float): Dropout probability. 
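+
+    Example:
+        A minimal sketch, assuming 512-dim extractor features projected to a 768-dim
+        encoder (the "base" configuration); the input tensor is random data for shape
+        illustration only.
+
+        >>> proj = FeatureProjection(in_features=512, out_features=768, dropout=0.1)
+        >>> proj(torch.randn(1, 100, 512)).shape
+        torch.Size([1, 100, 768])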
+ """ + def __init__( + self, + in_features: int, + out_features: int, + dropout: float, + ): + super().__init__() + self.layer_norm = nn.LayerNorm(in_features) + self.projection = nn.Linear(in_features, out_features,) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + """ + Args: + x (Tensor): + Feature Tensor. shape: ``[batch, frame, in_feature]`` + Returns: + Tensor: Projected features. ``[batch, frame, out_feature]``. + """ + x = self.layer_norm(x) + x = self.projection(x) + x = self.dropout(x) + return x + + +class ConvolutionalPositionalEmbedding(Module): + """Positional embedding which is placed at the beginning of Transformer. + + Args: + embed_dim (int): Feature dimension of the input Tensor. + kernel_size (int): The number of frames to be use. + groups (int): The number of groups in feature dimensions. + """ + def __init__( + self, + embed_dim: int, + kernel_size: int, + groups: int, + ): + super().__init__() + self.embed_dim = embed_dim + self.conv = nn.Conv1d( + in_channels=embed_dim, + out_channels=embed_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + groups=groups, + ) + self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + self.num_remove: int = 1 if kernel_size % 2 == 0 else 0 + + def __prepare_scriptable__(self): + for hook in self.conv._forward_pre_hooks.values(): + # The hook we want to remove is an instance of WeightNorm class, so + # normally we would do `if isinstance(...)` but this class is not accessible + # because of shadowing, so we check the module name directly. + # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3 + if ( + hook.__module__ == 'torch.nn.utils.weight_norm' and + hook.__class__.__name__ == 'WeightNorm' + ): + _LG.warning('Removing weight_norm from %s', self.__class__.__name__) + torch.nn.utils.remove_weight_norm(self.conv) + return self + + def forward(self, x): + """ + Args: + x (Tensor): shape ``[batch, frame, feature]``. + + Returns: + Tensor: The resulting feature. Shape ``[batch, frame, feature]``. + """ + x = x.transpose(-2, -1) + x = self.conv(x) + if self.num_remove > 0: + x = x[..., :-self.num_remove] + x = torch.nn.functional.gelu(x) + x = x.transpose(-2, -1) + return x + + +class SelfAttention(Module): + """Multihead Self Attention module + + Args: + embed_dim (int): Total dimension of the model. + num_heads (int): The number of heads. + dropout (float, optional): + Dropout probabiliry on attn_output_weights. Default: ``0.0`` + """ + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + ): + super().__init__() + head_dim = embed_dim // num_heads + if head_dim * num_heads != embed_dim: + raise ValueError(f"`embed_dim ({embed_dim})` is not divisible by `num_heads ({num_heads})`") + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = head_dim + + self.scaling = self.head_dim ** -0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + def forward( + self, + x: Tensor, + attention_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Args: + x (Tensor): shape: ``[batch_size, sequence_length, embed_dim]``. + attention_mask (Tensor or None, optional): + shape: ``[batch_size, 1, sequence_length, sequence_length]`` + + Returns: + Tensor: The resulting tensor. 
shape: ``[batch, sequence_length, embed_dim]`` + """ + if x.ndim != 3 or x.shape[2] != self.embed_dim: + raise ValueError( + f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). " + f"Found {x.shape}." + ) + batch_size, length, embed_dim = x.size() + if attention_mask is not None: + shape_ = (batch_size, 1, length, length) + if attention_mask.size() != shape_: + raise ValueError( + f"The expected attention mask shape is {shape_}. " + f"Found {attention_mask.size()}." + ) + + shape = (batch_size, length, self.num_heads, self.head_dim) + q = self.q_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd + k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1) # B, nH, Hd, L + v = self.v_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd + + weights = self.scaling * (q @ k) # B, nH, L, L + if attention_mask is not None: + weights += attention_mask + + weights = torch.nn.functional.softmax(weights, dim=-1) + weights = torch.nn.functional.dropout(weights, p=self.dropout, training=self.training) + + output = weights @ v # B, nH, L, Hd + output = output.transpose(2, 1).reshape(batch_size, length, embed_dim) + + output = self.out_proj(output) + return output + + +class FeedForward(Module): + """Layer that follows attention layer in encoder layer. + """ + def __init__( + self, + io_features: int, + intermediate_features: int, + intermediate_dropout: float, + output_dropout: float, + ): + super().__init__() + self.intermediate_dense = nn.Linear(io_features, intermediate_features) + self.intermediate_dropout = nn.Dropout(intermediate_dropout) + self.output_dense = nn.Linear(intermediate_features, io_features) + self.output_dropout = nn.Dropout(output_dropout) + + def forward(self, x): + """ + Args: + x (Tensor): shape: `(batch, sequence_length, io_features)` + Returns: + x (Tensor): shape: `(batch, sequence_length, io_features)` + """ + x = self.intermediate_dense(x) + x = torch.nn.functional.gelu(x) + x = self.intermediate_dropout(x) + + x = self.output_dense(x) + x = self.output_dropout(x) + return x + + +class EncoderLayer(Module): + """A layer unit in encoder. Combines multihead self attention and feed forward. 
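+
+    Args:
+        attention (torch.nn.Module): The self attention module (e.g. ``SelfAttention``).
+        dropout (float): Dropout probability applied to the attention output.
+        layer_norm_first (bool): If ``True``, apply layer norm before the attention and
+            feed forward blocks (pre-norm); otherwise apply it after (post-norm).
+        feed_forward (torch.nn.Module): The feed forward module (e.g. ``FeedForward``).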
+ """ + def __init__( + self, + attention: Module, + dropout: float, + layer_norm_first: bool, + feed_forward: Module, + ): + super().__init__() + self.attention = attention + self.dropout = nn.Dropout(dropout) + self.layer_norm = nn.LayerNorm(attention.embed_dim) + self.layer_norm_first = layer_norm_first + self.feed_forward = feed_forward + self.final_layer_norm = nn.LayerNorm(attention.embed_dim) + + def forward( + self, + x: Tensor, + attention_mask: Optional[Tensor] = None, + ): + """ + Args: + x (Tensor): shape: `(batch, sequence_length, embed_dim)` + attention_mask (Tensor or None, optional): + shape: `(batch, 1, sequence_length, sequence_length)` + """ + residual = x + + if self.layer_norm_first: + x = self.layer_norm(x) + + x = self.attention(x, attention_mask) + x = self.dropout(x) + x = residual + x + + if self.layer_norm_first: + x = x + self.feed_forward(self.final_layer_norm(x)) + else: + x = self.layer_norm(x) + x = self.final_layer_norm(x + self.feed_forward(x)) + return x + + +class Transformer(Module): + def __init__( + self, + pos_conv_embed: Module, + dropout: float, + layers: Module, + layer_norm_first: bool, + layer_drop: float, + ): + super().__init__() + self.pos_conv_embed = pos_conv_embed + self.layer_norm = nn.LayerNorm(pos_conv_embed.embed_dim) + self.layer_norm_first = layer_norm_first + self.layer_drop = layer_drop + self.dropout = nn.Dropout(dropout) + self.layers = layers + + def _preprocess(self, x: Tensor): + x = x + self.pos_conv_embed(x) + + if self.layer_norm_first: + x = self.layer_norm(x) + + x = self.dropout(x) + return x + + def forward( + self, + x: Tensor, + attention_mask: Optional[Tensor] = None, + ): + x = self._preprocess(x) + for layer in self.layers: + if not (self.training and torch.rand(1).item() <= self.layer_drop): + x = layer(x, attention_mask) + + if not self.layer_norm_first: + x = self.layer_norm(x) + + return x + + def get_intermediate_outputs( + self, + x: Tensor, + attention_mask: Optional[Tensor] = None, + num_layers: Optional[int] = None, + ) -> List[Tensor]: + if num_layers is not None: + if not 0 < num_layers <= len(self.layers): + raise ValueError(f'`num_layers` must be between [1, {len(self.layers)}]') + + ret: List[Tensor] = [] + x = self._preprocess(x) + for layer in self.layers: + x = layer(x, attention_mask) + ret.append(x) + if num_layers is not None and len(ret) >= num_layers: + return ret + return ret + + +class Encoder(Module): + def __init__( + self, + feature_projection: Module, + transformer: Module, + ): + super().__init__() + self.feature_projection = feature_projection + self.transformer = transformer + + def _preprocess( + self, + features: Tensor, + lengths: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + x = self.feature_projection(features) + + mask: Optional[Tensor] = None + if lengths is not None: + batch_size, max_len, _ = x.shape + # create mask for padded elements and zero-out them + mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None] + x[mask] = 0.0 + # extend the mask to attention shape and set weight + mask = -10000.0 * mask[:, None, None, :].to(dtype=features.dtype) + mask = mask.expand(batch_size, 1, max_len, max_len) + return x, mask + + def forward( + self, + features: Tensor, + lengths: Optional[Tensor] = None, + ) -> Tensor: + x, mask = self._preprocess(features, lengths) + x = self.transformer(x, attention_mask=mask) + return x + + def extract_features( + self, + features: Tensor, + lengths: Optional[Tensor] = None, + 
num_layers: Optional[int] = None, + ) -> List[Tensor]: + x, masks = self._preprocess(features, lengths) + return self.transformer.get_intermediate_outputs( + x, attention_mask=masks, num_layers=num_layers) + + +################################################################################ +def _get_feature_extractor( + norm_mode: str, + shapes: List[Tuple[int, int, int]], + bias: bool, +) -> FeatureExtractor: + """ + Args: + norm_mode (str): + Either "group_norm" or "layer_norm". + If "group_norm", then a single normalization is applied + in the first convolution block. Otherwise, all the convolution + blocks will have layer normalization. + This option corresponds to "extractor_mode" from fairseq. + Expected values are "group_norm" for Base arch, and + "layer_norm" for Large arch. + shapes (list of tuple of int): + Configuration of convolution layers. List of convolution configuration, + i.e. ``[(output_channel, kernel_size, stride), ...]`` + This option corresponds to "conv_feature_layers" from fairseq. + Expected values are + ``[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2`` + for all the architectures. + bias (bool): + Whether to include bias term to each convolution operation. + This option corresponds to "conv_bias" from fairseq. + Expected values are False for Base arch, and True for Large arch. + + See Also: + * Original implementation + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L666-L733 + * "extractor_mode" + - Def and base: + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L38-L45 + - Large: + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L52 + * "conv_feature_layers" + - Def, base and large: + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L94-L100 + * "conv_bias" + - Def and base: + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L101-L103 + - Large: + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L61 + """ + assert norm_mode in ["group_norm", "layer_norm"] + blocks = [] + in_channels = 1 + for i, (out_channels, kernel_size, stride) in enumerate(shapes): + normalization = None + if norm_mode == "group_norm" and i == 0: + normalization = nn.GroupNorm( + num_groups=out_channels, + num_channels=out_channels, + affine=True, + ) + elif norm_mode == "layer_norm": + normalization = LayerNorm( + normalized_shape=out_channels, + elementwise_affine=True, + ) + blocks.append( + ConvLayerBlock( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + bias=bias, + layer_norm=normalization, + ) + ) + in_channels = out_channels + return FeatureExtractor(nn.ModuleList(blocks)) + + +def _get_encoder( + in_features: int, + embed_dim: int, + dropout_input: float, + pos_conv_kernel: int, + pos_conv_groups: int, + num_layers: int, + num_heads: int, + attention_dropout: float, + ff_interm_features: int, + ff_interm_dropout: float, + dropout: float, + layer_norm_first: bool, + layer_drop: float, +) -> Encoder: + """ + Args: + in_features (int): The number of input features. + embed_dim (int): + The dimension of embedding. 
+ This option corresponds to "encoder_embed_dim" from fairseq. + Expected values are 768 for Base arch, and 1024 for Large arch. + dropout_input (float): + The dropout probability applied after the input feature is projected + to ``embed_dim``. + This option corresponds to "dropout_input" from fairseq. + Expected values are 0.1 for both Base and Large arch. + pos_conv_kernel (int): + The kernel size of convolutional positional embeddings. + This option corresponds to "conv_pos" from fairseq. + Expected values are 128 for both Base and Large arch. + pos_conv_groups (int): + The number of groups of convolutional positional embeddings. + This option corresponds to "conv_pos_groups" from fairseq. + Expected values are 16 for both Base and Large arch. + num_layers (int): + The number of self attention layers in transformer block. + This option corresponds to "encoder_layers" from fairseq. + Expected values are 12 for Base and 24 for Large arch. + num_heads (int): + The number of heads in self attention layers. + This option corresponds to "encoder_attention_heads" from fairseq. + Expected values are 12 for Base and 16 for Large arch. + attention_dropout (float): + The dropout probability applied after softmax in self-attention layer. + This option corresponds to "attention_dropout" from fairseq. + Expected values are 0.1 for Base and 0.0 for Large arch. + ff_interm_features (int): + The dimension of hidden features in feed forward layer. + This option corresponds to "encoder_ffn_embed_dim" from fairseq. + Expected values are 3072 for Base and 4096 for Large arch. + ff_interm_dropout (float): + The dropout probability applied in feedforward layer. + This option correspinds to "activation_dropout" from fairseq. + Expected values are 0.1 for both Base and Large arch. + dropout (float): + The dropout probability applied at the end of feed forward layer. + This option corresponds to "dropout" from fairseq. + Expected values are 0.1 for Base and 0.0 for Large arch. + layer_norm_first (bool): + Control the order of layer norm in transformer layer and each encoder layer. + If True, in transformer layer, layer norm is applied before features are fed + to encoder layers. In encoder layer, two layer norms are applied before and after + self attention. + If False, in transformer layer, layer norm is applied after features are fed + to encoder layers. In encoder layer, two layer norms are applied after self + attention, before and after feed forward. + This option corresponds to "layer_norm_first" from fairseq. + Expected values are False for Base and True for Large arch. + layer_drop (float): + Probability to drop each encoder layer during training. + This option corresponds to "layerdrop" from fairseq. + Expected values are 0.1 for both Base and Large arch. + + See Also: + * "encoder_embed_dim" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L49-L51 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L64 + * "dropout_input" + - Def, base and large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L75-L78 + * "conv_pos" + - Def, base and large + NOTE: The description is wrong. 
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L204-L207 + - Usage + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L756 + * "conv_pos_groups" + - Def, base and large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L208-L211 + * "encoder_layers" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L46-L48 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L63 + * "encoder_attention_heads" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L55-L57 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L66 + * "attention_dropout" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L66-L68 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L60 + * "encoder_ffn_embed_dim" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L52-L54 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L65 + * "activation_dropout" + - Def + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L69-L71 + - Base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L55 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L55 + * "dropout" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L63-L65 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L59 + * "layer_norm_first" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L91-L93 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L53 + * "layerdrop" + - Def + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L72-L74 + - Base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L54 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L54 + """ + feature_projection = FeatureProjection(in_features, embed_dim, dropout_input) + pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups) + + # Original impl + # 
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782 + encoder_layers = nn.ModuleList() + for _ in range(num_layers): + attention = SelfAttention( + embed_dim=embed_dim, + num_heads=num_heads, + dropout=attention_dropout, + ) + feed_forward = FeedForward( + io_features=embed_dim, + intermediate_features=ff_interm_features, + intermediate_dropout=ff_interm_dropout, + output_dropout=dropout, + ) + encoder_layers.append( + EncoderLayer( + attention=attention, + dropout=dropout, + layer_norm_first=layer_norm_first, + feed_forward=feed_forward, + ) + ) + transformer = Transformer( + pos_conv_embed=pos_conv, + dropout=dropout, + layers=encoder_layers, + layer_norm_first=not layer_norm_first, + layer_drop=layer_drop, + ) + return Encoder(feature_projection, transformer) diff --git a/torchaudio/models/wav2vec2/model.py b/torchaudio/models/wav2vec2/model.py new file mode 100644 index 0000000000000000000000000000000000000000..a28742a23956b6336f4ca262202d5ab4ce092eed --- /dev/null +++ b/torchaudio/models/wav2vec2/model.py @@ -0,0 +1,590 @@ +from typing import Optional, Tuple, List + +import torch +from torch import Tensor +from torch.nn import Module + +from . import components + + +class Wav2Vec2Model(Module): + """torchaudio.models.Wav2Vec2Model(feature_extractor: torch.nn.Module, encoder: torch.nn.Module, aux: Optional[torch.nn.Module] = None) + + Encoder model used in *wav2vec 2.0* [:footcite:`baevski2020wav2vec`]. + + Note: + To build the model, please use one of the factory functions. + + Args: + feature_extractor (torch.nn.Module): + Feature extractor that extracts feature vectors from raw audio Tensor. + + encoder (torch.nn.Module): + Encoder that converts the audio features into the sequence of probability + distribution (in negative log-likelihood) over labels. + + aux (torch.nn.Module or None, optional): + Auxiliary module. If provided, the output from encoder is passed to this module. + """ # noqa: E501 + def __init__( + self, + feature_extractor: Module, + encoder: Module, + aux: Optional[Module] = None, + ): + super().__init__() + self.feature_extractor = feature_extractor + self.encoder = encoder + self.aux = aux + + @torch.jit.export + def extract_features( + self, + waveforms: Tensor, + lengths: Optional[Tensor] = None, + num_layers: Optional[int] = None, + ) -> Tuple[List[Tensor], Optional[Tensor]]: + """Extract feature vectors from raw waveforms + + This returns the list of outputs from the intermediate layers of + transformer block in encoder. + + Args: + waveforms (Tensor): Audio tensor of shape `(batch, frames)`. + lengths (Tensor or None, optional): + Indicates the valid length of each audio in the batch. + Shape: `(batch, )`. + When the ``waveforms`` contains audios with different durations, + by providing ``lengths`` argument, the model will compute + the corresponding valid output lengths and apply proper mask in + transformer attention layer. + If ``None``, it is assumed that the entire audio waveform + length is valid. + num_layers (int or None, optional): + If given, limit the number of intermediate layers to go through. + Providing `1` will stop the computation after going through one + intermediate layers. If not given, the outputs from all the + intermediate layers are returned. + + Returns: + (List[Tensor], Optional[Tensor]): + List of Tensors + Features from requested layers. 
+ Each Tensor is of shape: `(batch, time frame, feature dimension)` + Tensor or None + If ``lengths`` argument was provided, a Tensor of shape `(batch, )` + is returned. + It indicates the valid length in time axis of each feature Tensor. + """ + x, lengths = self.feature_extractor(waveforms, lengths) + x = self.encoder.extract_features(x, lengths, num_layers) + return x, lengths + + def forward( + self, + waveforms: Tensor, + lengths: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Compute the sequence of probability distribution over labels. + + Args: + waveforms (Tensor): Audio tensor of shape `(batch, frames)`. + lengths (Tensor or None, optional): + Indicates the valid length of each audio in the batch. + Shape: `(batch, )`. + When the ``waveforms`` contains audios with different durations, + by providing ``lengths`` argument, the model will compute + the corresponding valid output lengths and apply proper mask in + transformer attention layer. + If ``None``, it is assumed that all the audio in ``waveforms`` + have valid length. Default: ``None``. + + Returns: + (Tensor, Optional[Tensor]): + Tensor + The sequences of probability distribution (in logit) over labels. + Shape: `(batch, frames, num labels)`. + Tensor or None + If ``lengths`` argument was provided, a Tensor of shape `(batch, )` + is returned. + It indicates the valid length in time axis of the output Tensor. + """ + x, lengths = self.feature_extractor(waveforms, lengths) + x = self.encoder(x, lengths) + if self.aux is not None: + x = self.aux(x) + return x, lengths + + +def wav2vec2_model( + extractor_mode: str, + extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], + extractor_conv_bias: bool, + encoder_embed_dim: int, + encoder_projection_dropout: float, + encoder_pos_conv_kernel: int, + encoder_pos_conv_groups: int, + encoder_num_layers: int, + encoder_num_heads: int, + encoder_attention_dropout: float, + encoder_ff_interm_features: int, + encoder_ff_interm_dropout: float, + encoder_dropout: float, + encoder_layer_norm_first: bool, + encoder_layer_drop: float, + aux_num_out: Optional[int], +) -> Wav2Vec2Model: + # Overriding the signature so that the return type is correct on Sphinx + """wav2vec2_model(extractor_mode: str, extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], extractor_conv_bias: bool, encoder_embed_dim: int, encoder_projection_dropout: float, encoder_pos_conv_kernel: int, encoder_pos_conv_groups: int, encoder_num_layers: int, encoder_num_heads: int, encoder_attention_dropout: float, encoder_ff_interm_features: int, encoder_ff_interm_dropout: float, encoder_dropout: float, encoder_layer_norm_first: bool, encoder_layer_drop: float, aux_num_out: Optional[int]) -> torchaudio.models.Wav2Vec2Model + + Build a custom Wav2Vec2Model + + Note: + The "feature extractor" below corresponds to + `ConvFeatureExtractionModel `__ + in the original ``fairseq`` implementation. + This is referred as "(convolutional) feature encoder" in the *wav2vec 2.0* + [:footcite:`baevski2020wav2vec`] paper. + + The "encoder" below corresponds to `TransformerEncoder `__, + and this is referred as "Transformer" in the paper. + + Args: + extractor_mode (str): Operation mode of feature extractor. + Valid values are ``"group_norm"`` or ``"layer_norm"``. + If ``"group_norm"``, then a single normalization is applied + in the first convolution block. Otherwise, all the convolution + blocks will have layer normalization. + + This option corresponds to ``extractor_mode`` from ``fairseq``. 
+ extractor_conv_layer_config (list of integer tuples or None): + Configuration of convolution layers in feature extractor. + List of convolution configuration, + i.e. ``[(output_channel, kernel_size, stride), ...]`` + + If ``None`` is provided, then the following default value is used. + + .. code-block:: python + + [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ] + + This option corresponds to ``conv_feature_layers`` from ``fairseq``. + + extractor_conv_bias (bool): + Whether to include bias term to each convolution operation. + + This option corresponds to ``conv_bias`` from ``fairseq``. + + encoder_embed_dim (int): + The dimension of embedding in encoder. + + This option corresponds to ``encoder_embed_dim`` from ``fairseq``. + + encoder_projection_dropout (float): + The dropout probability applied after the input feature is projected + to ``encoder_embed_dim``. + + This option corresponds to ``dropout_input`` from ``fairseq``. + + encoder_pos_conv_kernel (int): + The kernel size of convolutional positional embeddings. + + This option corresponds to ``conv_pos`` from ``fairseq``. + + encoder_pos_conv_groups (int): + The number of groups of convolutional positional embeddings. + + This option corresponds to ``conv_pos_groups`` from ``fairseq``. + + encoder_num_layers (int): + The number of self attention layers in transformer block. + + This option corresponds to ``encoder_layers`` from ``fairseq``. + + encoder_num_heads (int): + The number of heads in self attention layers. + + This option corresponds to ``encoder_attention_heads`` from ``fairseq``. + + encoder_attention_dropout (float): + The dropout probability applied after softmax in self-attention layer. + + This option corresponds to ``attention_dropout`` from ``fairseq``. + + encoder_ff_interm_features (int): + The dimension of hidden features in feed forward layer. + + This option corresponds to ``encoder_ffn_embed_dim`` from ``fairseq``. + + encoder_ff_interm_dropout (float): + The dropout probability applied in feedforward layer. + + This option correspinds to ``activation_dropout`` from ``fairseq``. + + encoder_dropout (float): + The dropout probability applied at the end of feed forward layer. + + This option corresponds to ``dropout`` from ``fairseq``. + + encoder_layer_norm_first (bool): + Control the order of layer norm in transformer layer and each encoder layer. + If True, in transformer layer, layer norm is applied before features are fed + to encoder layers. In encoder layer, two layer norms are applied before and after + self attention. + If False, in transformer layer, layer norm is applied after features are fed + to encoder layers. In encoder layer, two layer norms are applied after self + attention, before and after feed forward. + + This option corresponds to ``layer_norm_first`` from ``fairseq``. + + encoder_layer_drop (float): + Probability to drop each encoder layer during training. + + This option corresponds to ``layerdrop`` from ``fairseq``. + + aux_num_out (int or None): + When provided, attach an extra linear layer on top of encoder, which can be + used for fine-tuning. + + Returns: + Wav2Vec2Model: + The resulting model. 
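+
+    Example:
+        A sketch that reproduces the "base" configuration using the values quoted in the
+        parameter descriptions above (this is essentially what :py:func:`wav2vec2_base`
+        does with its default dropout values).
+
+        >>> model = wav2vec2_model(
+        ...     extractor_mode="group_norm",
+        ...     extractor_conv_layer_config=None,  # falls back to the default 7-layer config
+        ...     extractor_conv_bias=False,
+        ...     encoder_embed_dim=768,
+        ...     encoder_projection_dropout=0.1,
+        ...     encoder_pos_conv_kernel=128,
+        ...     encoder_pos_conv_groups=16,
+        ...     encoder_num_layers=12,
+        ...     encoder_num_heads=12,
+        ...     encoder_attention_dropout=0.1,
+        ...     encoder_ff_interm_features=3072,
+        ...     encoder_ff_interm_dropout=0.1,
+        ...     encoder_dropout=0.1,
+        ...     encoder_layer_norm_first=False,
+        ...     encoder_layer_drop=0.1,
+        ...     aux_num_out=None,
+        ... )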
+ """ # noqa: E501 + if extractor_conv_layer_config is None: + extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 + + feature_extractor = components._get_feature_extractor( + extractor_mode, extractor_conv_layer_config, extractor_conv_bias) + encoder = components._get_encoder( + in_features=extractor_conv_layer_config[-1][0], + embed_dim=encoder_embed_dim, + dropout_input=encoder_projection_dropout, + pos_conv_kernel=encoder_pos_conv_kernel, + pos_conv_groups=encoder_pos_conv_groups, + num_layers=encoder_num_layers, + num_heads=encoder_num_heads, + attention_dropout=encoder_attention_dropout, + ff_interm_features=encoder_ff_interm_features, + ff_interm_dropout=encoder_ff_interm_dropout, + dropout=encoder_dropout, + layer_norm_first=encoder_layer_norm_first, + layer_drop=encoder_layer_drop, + ) + aux = None + if aux_num_out is not None: + aux = torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out) + return Wav2Vec2Model(feature_extractor, encoder, aux) + + +def wav2vec2_base( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.1, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + # Overriding the signature so that the return type is correct on Sphinx + """wav2vec2_base(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.1, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.1, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model + + Build Wav2Vec2Model with "base" architecture from *wav2vec 2.0* [:footcite:`baevski2020wav2vec`] + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. 
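+
+    Example:
+        A sketch with random data for shape illustration only; ``32`` is an arbitrary
+        number of output labels.
+
+        >>> model = wav2vec2_base(aux_num_out=32)
+        >>> waveform = torch.randn(1, 16000)  # one second of 16 kHz audio
+        >>> logits, lengths = model(waveform)  # logits: (1, num_frames, 32), lengths: None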
+ """ # noqa: E501 + return wav2vec2_model( + extractor_mode="group_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=768, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=12, + encoder_num_heads=12, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=3072, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=False, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) + + +def wav2vec2_large( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.1, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + # Overriding the signature so that the return type is correct on Sphinx + """wav2vec2_large(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.1, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.1, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model + + Build Wav2Vec2Model with "large" architecture from *wav2vec 2.0* [:footcite:`baevski2020wav2vec`] + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + return wav2vec2_model( + extractor_mode="group_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=False, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) + + +def wav2vec2_large_lv60k( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.0, + encoder_ff_interm_dropout: float = 0.1, + encoder_dropout: float = 0.0, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + # Overriding the signature so that the return type is correct on Sphinx + """wav2vec2_large_lv60k( encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.1, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.1, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model + + Build Wav2Vec2Model with "large lv-60k" architecture from *wav2vec 2.0* [:footcite:`baevski2020wav2vec`] + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. 
+ aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + return wav2vec2_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=True, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) + + +def hubert_base( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.05, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + # Overriding the signature so that the return type is correct on Sphinx + """hubert_base(encoder_projection_dropout: float = 0.1, encoder_attention_dropout: float = 0.1, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.1, encoder_layer_drop: float = 0.05, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model + + Build HuBERT model with "base" architecture from *HuBERT* [:footcite:`hsu2021hubert`] + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + return wav2vec2_model( + extractor_mode='group_norm', + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=768, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=12, + encoder_num_heads=12, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=3072, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=False, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) + + +def hubert_large( + encoder_projection_dropout: float = 0.0, + encoder_attention_dropout: float = 0.0, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.0, + encoder_layer_drop: float = 0.0, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + # Overriding the signature so that the return type is correct on Sphinx + """hubert_large(encoder_projection_dropout: float = 0.0, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.0, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model + + Build HuBERT model with "large" architecture from *HuBERT* [:footcite:`hsu2021hubert`] + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. 
+ encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + return wav2vec2_model( + extractor_mode='layer_norm', + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) + + +def hubert_xlarge( + encoder_projection_dropout: float = 0.0, + encoder_attention_dropout: float = 0.0, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.0, + encoder_layer_drop: float = 0.0, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + # Overriding the signature so that the return type is correct on Sphinx + """hubert_xlarge(encoder_projection_dropout: float = 0.0, encoder_attention_dropout: float = 0.0, encoder_ff_interm_dropout: float = 0.0, encoder_dropout: float = 0.0, encoder_layer_drop: float = 0.0, aux_num_out: Optional[int] = None) -> torchaudio.models.Wav2Vec2Model + + Build HuBERT model with "extra large" architecture from *HuBERT* [:footcite:`hsu2021hubert`] + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + return wav2vec2_model( + extractor_mode='layer_norm', + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1280, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=48, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=5120, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) diff --git a/torchaudio/models/wav2vec2/utils/__init__.py b/torchaudio/models/wav2vec2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d9fb218bb88e3e12f0e81b9e1782dea0edaa925 --- /dev/null +++ b/torchaudio/models/wav2vec2/utils/__init__.py @@ -0,0 +1,7 @@ +from .import_huggingface import import_huggingface_model +from .import_fairseq import import_fairseq_model + +__all__ = [ + 'import_huggingface_model', + 'import_fairseq_model', +] diff --git a/torchaudio/models/wav2vec2/utils/import_fairseq.py b/torchaudio/models/wav2vec2/utils/import_fairseq.py new file mode 100644 index 0000000000000000000000000000000000000000..e285d3d1e1fe3ee4d3a8678289bf16aa55d57c53 --- /dev/null +++ b/torchaudio/models/wav2vec2/utils/import_fairseq.py @@ -0,0 +1,219 @@ +"""Import fariseq's wav2vec2.0 pretrained weights to torchaudios's format. + +For this module to work, you need `fairseq`. 
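+
+A typical flow is to load the original checkpoint with ``fairseq`` and then convert it
+(a sketch; the checkpoint file name is hypothetical)::
+
+    from fairseq import checkpoint_utils
+
+    models, _, _ = checkpoint_utils.load_model_ensemble_and_task(["wav2vec_small.pt"])
+    imported = import_fairseq_model(models[0])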
+""" +import re + +from torch.nn import Module + +from ..model import Wav2Vec2Model, wav2vec2_model + + +def _parse_config(w2v_model): + encoder = w2v_model.encoder + conv_layers = w2v_model.feature_extractor.conv_layers + + extractor_mode = 'layer_norm' + if 'GroupNorm' in conv_layers[0][2].__class__.__name__: + extractor_mode = 'group_norm' + else: + extractor_mode = 'layer_norm' + + conv_layer_config = [(l[0].out_channels, l[0].kernel_size[0], l[0].stride[0]) for l in conv_layers] + + if all(l[0].bias is None for l in conv_layers): + conv_bias = False + elif all(l[0].bias is not None for l in conv_layers): + conv_bias = True + else: + raise ValueError( + 'Either all the convolutions layers have bias term or none of them should.') + + config = { + 'extractor_mode': extractor_mode, + 'extractor_conv_layer_config': conv_layer_config, + 'extractor_conv_bias': conv_bias, + 'encoder_embed_dim': w2v_model.post_extract_proj.out_features, + 'encoder_projection_dropout': w2v_model.dropout_input.p, + 'encoder_pos_conv_kernel': encoder.pos_conv[0].kernel_size[0], + 'encoder_pos_conv_groups': encoder.pos_conv[0].groups, + 'encoder_num_layers': len(encoder.layers), + 'encoder_num_heads': encoder.layers[0].self_attn.num_heads, + 'encoder_attention_dropout': encoder.layers[0].self_attn.dropout_module.p, + 'encoder_ff_interm_features': encoder.layers[0].fc1.out_features, + 'encoder_ff_interm_dropout': encoder.layers[0].dropout2.p, + 'encoder_dropout': encoder.layers[0].dropout3.p, + 'encoder_layer_norm_first': encoder.layer_norm_first, + 'encoder_layer_drop': encoder.layerdrop, + } + return config + + +def _map_key(key): + key_ = key + if key.startswith('w2v_model.'): + key = key.replace('w2v_model.', '') + if re.match(r'(mask_emb|quantizer|project_q|final_proj|mask_emb)', key): + return None + # Feature Extractor + # Group norm when "extractor_mode" is "default". + # (Only the first layer) + # "conv_layers.0.2.weight" -> "conv_layers.0.layer_norm.weight" + # "conv_layers.0.2.bias" -> "conv_layers.0.layer_norm.bias" + match = re.match(r'feature_extractor\.conv_layers\.0\.2\.(weight|bias)', key) + if match: + return f"feature_extractor.conv_layers.0.layer_norm.{match.group(1)}" + # Convolutions + # "conv_layers.X.0.weight" -> "conv_layers.X.conv.weight" + # "conv_layers.X.0.bias" -> "conv_layers.X.conv.bias" + match = re.match(r'feature_extractor\.conv_layers\.(\d+)\.0\.(weight|bias)', key) + if match: + return f"feature_extractor.conv_layers.{match.group(1)}.conv.{match.group(2)}" + # Layer norm when "extractor_mode" is "layer_norm". 
+ # "conv_layers.X.2.1.weight" -> "conv_layers.X.layer_norm.weight" + # "conv_layers.X.2.1.bias" -> "conv_layers.X.layer_norm.bias" + match = re.match(r'feature_extractor\.conv_layers\.(\d+)\.2\.1\.(weight|bias)', key) + if match: + return f"feature_extractor.conv_layers.{match.group(1)}.layer_norm.{match.group(2)}" + match = re.match(r"post_extract_proj\.(weight|bias)", key) + # Encoder - Feature projection + if match: + return f"encoder.feature_projection.projection.{match.group(1)}" + match = re.match(r"layer_norm\.(weight|bias)", key) + if match: + return f"encoder.feature_projection.layer_norm.{match.group(1)}" + # Encoder - Transformer - Convolutional positional embedding + match = re.match(r"encoder\.pos_conv\.0\.(bias|weight_g|weight_v)", key) + if match: + return f"encoder.transformer.pos_conv_embed.conv.{match.group(1)}" + match = re.match(r"encoder\.layer_norm\.(weight|bias)", key) + if match: + return f"encoder.transformer.layer_norm.{match.group(1)}" + # Encoder - Transformer - Self attention layers + match = re.match(r"encoder\.layers\.(\d+)\.self_attn\.((k_|v_|q_|out_)proj\.(weight|bias))", key) + if match: + return f"encoder.transformer.layers.{match.group(1)}.attention.{match.group(2)}" + match = re.match(r"encoder\.layers\.(\d+)\.self_attn_layer_norm\.(weight|bias)", key) + if match: + return f"encoder.transformer.layers.{match.group(1)}.layer_norm.{match.group(2)}" + match = re.match(r"encoder\.layers\.(\d+)\.fc1\.(weight|bias)", key) + if match: + return f"encoder.transformer.layers.{match.group(1)}.feed_forward.intermediate_dense.{match.group(2)}" + match = re.match(r"encoder\.layers\.(\d+)\.fc2\.(weight|bias)", key) + if match: + return f"encoder.transformer.layers.{match.group(1)}.feed_forward.output_dense.{match.group(2)}" + match = re.match(r"encoder\.layers\.(\d+)\.final_layer_norm\.(weight|bias)", key) + if match: + return f"encoder.transformer.layers.{match.group(1)}.final_layer_norm.{match.group(2)}" + match = re.match(r"proj\.(weight|bias)", key) + # Auxiliary Module + # Only relevant when loading fine-tuned models + if match: + return f"aux.{match.group(1)}" + # HuBERT Extension + if key in ['label_embs_concat']: + return key + raise ValueError(f'Unexpected key: {key_}') + + +def _convert_state_dict(state_dict): + converted = {} + for k, v in state_dict.items(): + k = _map_key(k) + if k is not None: + converted[k] = v + return converted + + +def import_fairseq_model(original: Module) -> Wav2Vec2Model: + # Overriding the signature so that the types are correct on Sphinx + """import_fairseq_model(original: torch.nn.Module) -> torchaudio.models.Wav2Vec2Model + + Build Wav2Vec2Model from the corresponding model object of `fairseq`_. + + Args: + original (torch.nn.Module): + An instance of fairseq's Wav2Vec2.0 or HuBERT model. + One of ``fairseq.models.wav2vec.wav2vec2_asr.Wav2VecEncoder``, + ``fairseq.models.wav2vec.wav2vec2.Wav2Vec2Model`` or + ``fairseq.models.hubert.hubert_asr.HubertEncoder``. + + Returns: + Wav2Vec2Model: Imported model. 
+ + Example - Loading pretrain-only model + >>> from torchaudio.models.wav2vec2.utils import import_fairseq_model + >>> + >>> # Load model using fairseq + >>> model_file = 'wav2vec_small.pt' + >>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file]) + >>> original = model[0] + >>> imported = import_fairseq_model(original) + >>> + >>> # Perform feature extraction + >>> waveform, _ = torchaudio.load('audio.wav') + >>> features, _ = imported.extract_features(waveform) + >>> + >>> # Compare result with the original model from fairseq + >>> reference = original.feature_extractor(waveform).transpose(1, 2) + >>> torch.testing.assert_allclose(features, reference) + + Example - Fine-tuned model + >>> from torchaudio.models.wav2vec2.utils import import_fairseq_model + >>> + >>> # Load model using fairseq + >>> model_file = 'wav2vec_small_960h.pt' + >>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file]) + >>> original = model[0] + >>> imported = import_fairseq_model(original.w2v_encoder) + >>> + >>> # Perform encoding + >>> waveform, _ = torchaudio.load('audio.wav') + >>> emission, _ = imported(waveform) + >>> + >>> # Compare result with the original model from fairseq + >>> mask = torch.zeros_like(waveform) + >>> reference = original(waveform, mask)['encoder_out'].transpose(0, 1) + >>> torch.testing.assert_allclose(emission, reference) + + .. _fairseq: https://github.com/pytorch/fairseq + """ + class_ = original.__class__.__name__ + if class_ == 'Wav2Vec2Model': + return _import_wav2vec2_pretraining(original) + if class_ == 'Wav2VecEncoder': + return _import_wav2vec2_finetuning(original) + if class_ == 'HubertModel': + return _import_hubert_pretraining(original) + if class_ == 'HubertEncoder': + return _import_hubert_finetuning(original) + raise ValueError( + f'Expected an instance of `Wav2Vec2Model` or `Wav2VecEncoder`. Found: {class_}') + + +def _import_wav2vec2_finetuning(original: Module) -> Wav2Vec2Model: + config = _parse_config(original.w2v_model) + model = wav2vec2_model(**config, aux_num_out=original.proj.out_features) + model.load_state_dict(_convert_state_dict(original.state_dict())) + return model + + +def _import_wav2vec2_pretraining(original: Module) -> Wav2Vec2Model: + config = _parse_config(original) + model = wav2vec2_model(**config, aux_num_out=None) + model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False) + return model + + +def _import_hubert_finetuning(original: Module) -> Wav2Vec2Model: + config = _parse_config(original.w2v_model) + model = wav2vec2_model(**config, aux_num_out=original.proj.out_features) + model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False) + return model + + +def _import_hubert_pretraining(original: Module) -> Wav2Vec2Model: + config = _parse_config(original) + model = wav2vec2_model(**config, aux_num_out=None) + model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False) + return model diff --git a/torchaudio/models/wav2vec2/utils/import_huggingface.py b/torchaudio/models/wav2vec2/utils/import_huggingface.py new file mode 100644 index 0000000000000000000000000000000000000000..c1a5c4133c10c14c966fd40bb288a855d29266ec --- /dev/null +++ b/torchaudio/models/wav2vec2/utils/import_huggingface.py @@ -0,0 +1,80 @@ +"""Import Hugging Face transformers's wav2vec2.0 pretrained weights to torchaudios's format. 
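+
+For this module to work, you need `transformers`.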
+""" +import logging + +from torch.nn import Module + +from ..model import Wav2Vec2Model, wav2vec2_model + +_LG = logging.getLogger(__name__) + + +def _get_config(cfg): + config = { + 'extractor_mode': f'{cfg.feat_extract_norm}_norm', + 'extractor_conv_layer_config': list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)), + 'extractor_conv_bias': cfg.conv_bias, + 'encoder_embed_dim': cfg.hidden_size, + 'encoder_projection_dropout': cfg.feat_proj_dropout, + 'encoder_pos_conv_kernel': cfg.num_conv_pos_embeddings, + 'encoder_pos_conv_groups': cfg.num_conv_pos_embedding_groups, + 'encoder_num_layers': cfg.num_hidden_layers, + 'encoder_num_heads': cfg.num_attention_heads, + 'encoder_attention_dropout': cfg.attention_dropout, + 'encoder_ff_interm_features': cfg.intermediate_size, + 'encoder_ff_interm_dropout': cfg.activation_dropout, + 'encoder_dropout': cfg.hidden_dropout, + 'encoder_layer_norm_first': cfg.do_stable_layer_norm, + 'encoder_layer_drop': cfg.layerdrop, + } + return config + + +def _build(config, original): + if original.__class__.__name__ == 'Wav2Vec2ForCTC': + aux_num_out = original.config.vocab_size + wav2vec2 = original.wav2vec2 + else: + _LG.warning( + 'The model is not an instance of Wav2Vec2ForCTC. ' + '"lm_head" module is not imported.') + aux_num_out = None + wav2vec2 = original + imported = wav2vec2_model(**config, aux_num_out=aux_num_out) + imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict()) + imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict()) + imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict()) + if original.__class__.__name__ == 'Wav2Vec2ForCTC': + imported.aux.load_state_dict(original.lm_head.state_dict()) + return imported + + +def import_huggingface_model(original: Module) -> Wav2Vec2Model: + """import_huggingface_model(original: torch.nn.Module) -> torchaudio.models.Wav2Vec2Model + + Build Wav2Vec2Model from the corresponding model object of Hugging Face's `Transformers`_. + + Args: + original (torch.nn.Module): An instance of ``Wav2Vec2ForCTC`` from ``transformers``. + + Returns: + Wav2Vec2Model: Imported model. + + Example + >>> from torchaudio.models.wav2vec2.utils import import_huggingface_model + >>> + >>> original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") + >>> model = import_huggingface_model(original) + >>> + >>> waveforms, _ = torchaudio.load("audio.wav") + >>> logits, _ = model(waveforms) + + .. _Transformers: https://huggingface.co/transformers/ + """ + _LG.info('Importing model.') + _LG.info('Loading model configuration.') + config = _get_config(original.config) + _LG.debug(' - config: %s', config) + _LG.info('Building model.') + imported = _build(config, original) + return imported diff --git a/torchaudio/models/wavernn.py b/torchaudio/models/wavernn.py new file mode 100644 index 0000000000000000000000000000000000000000..9f92061906b19556343466b623006c742cc6ef40 --- /dev/null +++ b/torchaudio/models/wavernn.py @@ -0,0 +1,411 @@ +from typing import List, Tuple, Optional +import math + +import torch +from torch import Tensor +from torch import nn +import torch.nn.functional as F + +__all__ = [ + "ResBlock", + "MelResNet", + "Stretch2d", + "UpsampleNetwork", + "WaveRNN", +] + + +class ResBlock(nn.Module): + r"""ResNet block based on *Efficient Neural Audio Synthesis* [:footcite:`kalchbrenner2018efficient`]. + + Args: + n_freq: the number of bins in a spectrogram. 
(Default: ``128``) + + Examples + >>> resblock = ResBlock() + >>> input = torch.rand(10, 128, 512) # a random spectrogram + >>> output = resblock(input) # shape: (10, 128, 512) + """ + + def __init__(self, n_freq: int = 128) -> None: + super().__init__() + + self.resblock_model = nn.Sequential( + nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False), + nn.BatchNorm1d(n_freq), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False), + nn.BatchNorm1d(n_freq) + ) + + def forward(self, specgram: Tensor) -> Tensor: + r"""Pass the input through the ResBlock layer. + Args: + specgram (Tensor): the input sequence to the ResBlock layer (n_batch, n_freq, n_time). + + Return: + Tensor shape: (n_batch, n_freq, n_time) + """ + + return self.resblock_model(specgram) + specgram + + +class MelResNet(nn.Module): + r"""MelResNet layer uses a stack of ResBlocks on spectrogram. + + Args: + n_res_block: the number of ResBlock in stack. (Default: ``10``) + n_freq: the number of bins in a spectrogram. (Default: ``128``) + n_hidden: the number of hidden dimensions of resblock. (Default: ``128``) + n_output: the number of output dimensions of melresnet. (Default: ``128``) + kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``) + + Examples + >>> melresnet = MelResNet() + >>> input = torch.rand(10, 128, 512) # a random spectrogram + >>> output = melresnet(input) # shape: (10, 128, 508) + """ + + def __init__(self, + n_res_block: int = 10, + n_freq: int = 128, + n_hidden: int = 128, + n_output: int = 128, + kernel_size: int = 5) -> None: + super().__init__() + + ResBlocks = [ResBlock(n_hidden) for _ in range(n_res_block)] + + self.melresnet_model = nn.Sequential( + nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, kernel_size=kernel_size, bias=False), + nn.BatchNorm1d(n_hidden), + nn.ReLU(inplace=True), + *ResBlocks, + nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1) + ) + + def forward(self, specgram: Tensor) -> Tensor: + r"""Pass the input through the MelResNet layer. + Args: + specgram (Tensor): the input sequence to the MelResNet layer (n_batch, n_freq, n_time). + + Return: + Tensor shape: (n_batch, n_output, n_time - kernel_size + 1) + """ + + return self.melresnet_model(specgram) + + +class Stretch2d(nn.Module): + r"""Upscale the frequency and time dimensions of a spectrogram. + + Args: + time_scale: the scale factor in time dimension + freq_scale: the scale factor in frequency dimension + + Examples + >>> stretch2d = Stretch2d(time_scale=10, freq_scale=5) + + >>> input = torch.rand(10, 100, 512) # a random spectrogram + >>> output = stretch2d(input) # shape: (10, 500, 5120) + """ + + def __init__(self, + time_scale: int, + freq_scale: int) -> None: + super().__init__() + + self.freq_scale = freq_scale + self.time_scale = time_scale + + def forward(self, specgram: Tensor) -> Tensor: + r"""Pass the input through the Stretch2d layer. + + Args: + specgram (Tensor): the input sequence to the Stretch2d layer (..., n_freq, n_time). + + Return: + Tensor shape: (..., n_freq * freq_scale, n_time * time_scale) + """ + + return specgram.repeat_interleave(self.freq_scale, -2).repeat_interleave(self.time_scale, -1) + + +class UpsampleNetwork(nn.Module): + r"""Upscale the dimensions of a spectrogram. + + Args: + upsample_scales: the list of upsample scales. + n_res_block: the number of ResBlock in stack. (Default: ``10``) + n_freq: the number of bins in a spectrogram. 
(Default: ``128``) + n_hidden: the number of hidden dimensions of resblock. (Default: ``128``) + n_output: the number of output dimensions of melresnet. (Default: ``128``) + kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``) + + Examples + >>> upsamplenetwork = UpsampleNetwork(upsample_scales=[4, 4, 16]) + >>> input = torch.rand(10, 128, 10) # a random spectrogram + >>> output = upsamplenetwork(input) # shape: (10, 1536, 128), (10, 1536, 128) + """ + + def __init__(self, + upsample_scales: List[int], + n_res_block: int = 10, + n_freq: int = 128, + n_hidden: int = 128, + n_output: int = 128, + kernel_size: int = 5) -> None: + super().__init__() + + total_scale = 1 + for upsample_scale in upsample_scales: + total_scale *= upsample_scale + self.total_scale: int = total_scale + + self.indent = (kernel_size - 1) // 2 * total_scale + self.resnet = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size) + self.resnet_stretch = Stretch2d(total_scale, 1) + + up_layers = [] + for scale in upsample_scales: + stretch = Stretch2d(scale, 1) + conv = nn.Conv2d(in_channels=1, + out_channels=1, + kernel_size=(1, scale * 2 + 1), + padding=(0, scale), + bias=False) + conv.weight.data.fill_(1. / (scale * 2 + 1)) + up_layers.append(stretch) + up_layers.append(conv) + self.upsample_layers = nn.Sequential(*up_layers) + + def forward(self, specgram: Tensor) -> Tuple[Tensor, Tensor]: + r"""Pass the input through the UpsampleNetwork layer. + + Args: + specgram (Tensor): the input sequence to the UpsampleNetwork layer (n_batch, n_freq, n_time) + + Return: + Tensor shape: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale), + (n_batch, n_output, (n_time - kernel_size + 1) * total_scale) + where total_scale is the product of all elements in upsample_scales. + """ + + resnet_output = self.resnet(specgram).unsqueeze(1) + resnet_output = self.resnet_stretch(resnet_output) + resnet_output = resnet_output.squeeze(1) + + specgram = specgram.unsqueeze(1) + upsampling_output = self.upsample_layers(specgram) + upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent:-self.indent] + + return upsampling_output, resnet_output + + +class WaveRNN(nn.Module): + r"""WaveRNN model based on the implementation from `fatchord `_. + + The original implementation was introduced in *Efficient Neural Audio Synthesis* + [:footcite:`kalchbrenner2018efficient`]. The input channels of waveform and spectrogram have to be 1. + The product of `upsample_scales` must equal `hop_length`. + + Args: + upsample_scales: the list of upsample scales. + n_classes: the number of output classes. + hop_length: the number of samples between the starts of consecutive frames. + n_res_block: the number of ResBlock in stack. (Default: ``10``) + n_rnn: the dimension of RNN layer. (Default: ``512``) + n_fc: the dimension of fully connected layer. (Default: ``512``) + kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``) + n_freq: the number of bins in a spectrogram. (Default: ``128``) + n_hidden: the number of hidden dimensions of resblock. (Default: ``128``) + n_output: the number of output dimensions of melresnet. 
(Default: ``128``) + + Example + >>> wavernn = WaveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200) + >>> waveform, sample_rate = torchaudio.load(file) + >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length) + >>> specgram = MelSpectrogram(sample_rate)(waveform) # shape: (n_batch, n_channel, n_freq, n_time) + >>> output = wavernn(waveform, specgram) + >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, n_classes) + """ + + def __init__(self, + upsample_scales: List[int], + n_classes: int, + hop_length: int, + n_res_block: int = 10, + n_rnn: int = 512, + n_fc: int = 512, + kernel_size: int = 5, + n_freq: int = 128, + n_hidden: int = 128, + n_output: int = 128) -> None: + super().__init__() + + self.kernel_size = kernel_size + self._pad = (kernel_size - 1 if kernel_size % 2 else kernel_size) // 2 + self.n_rnn = n_rnn + self.n_aux = n_output // 4 + self.hop_length = hop_length + self.n_classes = n_classes + self.n_bits: int = int(math.log2(self.n_classes)) + + total_scale = 1 + for upsample_scale in upsample_scales: + total_scale *= upsample_scale + if total_scale != self.hop_length: + raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}") + + self.upsample = UpsampleNetwork(upsample_scales, + n_res_block, + n_freq, + n_hidden, + n_output, + kernel_size) + self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn) + + self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True) + self.rnn2 = nn.GRU(n_rnn + self.n_aux, n_rnn, batch_first=True) + + self.relu1 = nn.ReLU(inplace=True) + self.relu2 = nn.ReLU(inplace=True) + + self.fc1 = nn.Linear(n_rnn + self.n_aux, n_fc) + self.fc2 = nn.Linear(n_fc + self.n_aux, n_fc) + self.fc3 = nn.Linear(n_fc, self.n_classes) + + def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor: + r"""Pass the input through the WaveRNN model. 
+ + Args: + waveform: the input waveform to the WaveRNN layer (n_batch, 1, (n_time - kernel_size + 1) * hop_length) + specgram: the input spectrogram to the WaveRNN layer (n_batch, 1, n_freq, n_time) + + Return: + Tensor: shape (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes) + """ + + assert waveform.size(1) == 1, 'Require the input channel of waveform is 1' + assert specgram.size(1) == 1, 'Require the input channel of specgram is 1' + # remove channel dimension until the end + waveform, specgram = waveform.squeeze(1), specgram.squeeze(1) + + batch_size = waveform.size(0) + h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device) + h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device) + # output of upsample: + # specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale) + # aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale) + specgram, aux = self.upsample(specgram) + specgram = specgram.transpose(1, 2) + aux = aux.transpose(1, 2) + + aux_idx = [self.n_aux * i for i in range(5)] + a1 = aux[:, :, aux_idx[0]:aux_idx[1]] + a2 = aux[:, :, aux_idx[1]:aux_idx[2]] + a3 = aux[:, :, aux_idx[2]:aux_idx[3]] + a4 = aux[:, :, aux_idx[3]:aux_idx[4]] + + x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1) + x = self.fc(x) + res = x + x, _ = self.rnn1(x, h1) + + x = x + res + res = x + x = torch.cat([x, a2], dim=-1) + x, _ = self.rnn2(x, h2) + + x = x + res + x = torch.cat([x, a3], dim=-1) + x = self.fc1(x) + x = self.relu1(x) + + x = torch.cat([x, a4], dim=-1) + x = self.fc2(x) + x = self.relu2(x) + x = self.fc3(x) + + # bring back channel dimension + return x.unsqueeze(1) + + @torch.jit.export + def infer(self, specgram: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: + r"""Inference method of WaveRNN. + + This function currently only supports multinomial sampling, which assumes the + network is trained on cross entropy loss. + + Args: + specgram (Tensor): + Batch of spectrograms. Shape: `(n_batch, n_freq, n_time)`. + lengths (Tensor or None, optional): + Indicates the valid length of each audio in the batch. + Shape: `(batch, )`. + When the ``specgram`` contains spectrograms with different durations, + by providing ``lengths`` argument, the model will compute + the corresponding valid output lengths. + If ``None``, it is assumed that all the audio in ``waveforms`` + have valid length. Default: ``None``. + + Returns: + (Tensor, Optional[Tensor]): + Tensor + The inferred waveform of size `(n_batch, 1, n_time)`. + 1 stands for a single channel. + Tensor or None + If ``lengths`` argument was provided, a Tensor of shape `(batch, )` + is returned. + It indicates the valid length in time axis of the output Tensor. 
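+
+        Example (illustrative sketch; the input spectrogram here is random noise)
+            >>> wavernn = WaveRNN(upsample_scales=[5, 5, 8], n_classes=512, hop_length=200)
+            >>> specgram = torch.rand(1, 128, 100)  # (n_batch, n_freq, n_time)
+            >>> lengths = torch.tensor([100])
+            >>> waveform, out_lengths = wavernn.infer(specgram, lengths)
+            >>> # waveform shape: (n_batch, 1, n_time * hop_length) = (1, 1, 20000)
+            >>> # out_lengths: tensor([20000])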
+ """ + + device = specgram.device + dtype = specgram.dtype + + specgram = torch.nn.functional.pad(specgram, (self._pad, self._pad)) + specgram, aux = self.upsample(specgram) + if lengths is not None: + lengths = lengths * self.upsample.total_scale + + output: List[Tensor] = [] + b_size, _, seq_len = specgram.size() + + h1 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype) + h2 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype) + x = torch.zeros((b_size, 1), device=device, dtype=dtype) + + aux_split = [aux[:, self.n_aux * i: self.n_aux * (i + 1), :] for i in range(4)] + + for i in range(seq_len): + + m_t = specgram[:, :, i] + + a1_t, a2_t, a3_t, a4_t = [a[:, :, i] for a in aux_split] + + x = torch.cat([x, m_t, a1_t], dim=1) + x = self.fc(x) + _, h1 = self.rnn1(x.unsqueeze(1), h1) + + x = x + h1[0] + inp = torch.cat([x, a2_t], dim=1) + _, h2 = self.rnn2(inp.unsqueeze(1), h2) + + x = x + h2[0] + x = torch.cat([x, a3_t], dim=1) + x = F.relu(self.fc1(x)) + + x = torch.cat([x, a4_t], dim=1) + x = F.relu(self.fc2(x)) + + logits = self.fc3(x) + + posterior = F.softmax(logits, dim=1) + + x = torch.multinomial(posterior, 1).float() + # Transform label [0, 2 ** n_bits - 1] to waveform [-1, 1] + x = 2 * x / (2 ** self.n_bits - 1.0) - 1.0 + + output.append(x) + + return torch.stack(output).permute(1, 2, 0), lengths diff --git a/torchaudio/pipelines/__init__.py b/torchaudio/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40f251d79a2b5e442fcb277da7fe11b9b29c5f6b --- /dev/null +++ b/torchaudio/pipelines/__init__.py @@ -0,0 +1,57 @@ +from ._wav2vec2 import ( + Wav2Vec2Bundle, + Wav2Vec2ASRBundle, + WAV2VEC2_BASE, + WAV2VEC2_LARGE, + WAV2VEC2_LARGE_LV60K, + WAV2VEC2_ASR_BASE_10M, + WAV2VEC2_ASR_BASE_100H, + WAV2VEC2_ASR_BASE_960H, + WAV2VEC2_ASR_LARGE_10M, + WAV2VEC2_ASR_LARGE_100H, + WAV2VEC2_ASR_LARGE_960H, + WAV2VEC2_ASR_LARGE_LV60K_10M, + WAV2VEC2_ASR_LARGE_LV60K_100H, + WAV2VEC2_ASR_LARGE_LV60K_960H, + WAV2VEC2_XLSR53, + HUBERT_BASE, + HUBERT_LARGE, + HUBERT_XLARGE, + HUBERT_ASR_LARGE, + HUBERT_ASR_XLARGE, +) +from ._tts import ( + Tacotron2TTSBundle, + TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH, + TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH, + TACOTRON2_WAVERNN_CHAR_LJSPEECH, + TACOTRON2_WAVERNN_PHONE_LJSPEECH, +) + +__all__ = [ + 'Wav2Vec2Bundle', + 'Wav2Vec2ASRBundle', + 'WAV2VEC2_BASE', + 'WAV2VEC2_LARGE', + 'WAV2VEC2_LARGE_LV60K', + 'WAV2VEC2_ASR_BASE_10M', + 'WAV2VEC2_ASR_BASE_100H', + 'WAV2VEC2_ASR_BASE_960H', + 'WAV2VEC2_ASR_LARGE_10M', + 'WAV2VEC2_ASR_LARGE_100H', + 'WAV2VEC2_ASR_LARGE_960H', + 'WAV2VEC2_ASR_LARGE_LV60K_10M', + 'WAV2VEC2_ASR_LARGE_LV60K_100H', + 'WAV2VEC2_ASR_LARGE_LV60K_960H', + 'WAV2VEC2_XLSR53', + 'HUBERT_BASE', + 'HUBERT_LARGE', + 'HUBERT_XLARGE', + 'HUBERT_ASR_LARGE', + 'HUBERT_ASR_XLARGE', + 'Tacotron2TTSBundle', + 'TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH', + 'TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH', + 'TACOTRON2_WAVERNN_CHAR_LJSPEECH', + 'TACOTRON2_WAVERNN_PHONE_LJSPEECH', +] diff --git a/torchaudio/pipelines/_tts/__init__.py b/torchaudio/pipelines/_tts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c361707c5be5a86f113c99aaa67f7491d2c717a6 --- /dev/null +++ b/torchaudio/pipelines/_tts/__init__.py @@ -0,0 +1,16 @@ +from .interface import Tacotron2TTSBundle +from .impl import ( + TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH, + TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH, + TACOTRON2_WAVERNN_CHAR_LJSPEECH, + TACOTRON2_WAVERNN_PHONE_LJSPEECH, +) + + +__all__ = [ + 'Tacotron2TTSBundle', + 
'TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH', + 'TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH', + 'TACOTRON2_WAVERNN_CHAR_LJSPEECH', + 'TACOTRON2_WAVERNN_PHONE_LJSPEECH', +] diff --git a/torchaudio/pipelines/_tts/impl.py b/torchaudio/pipelines/_tts/impl.py new file mode 100644 index 0000000000000000000000000000000000000000..c73bf7a73acec6ec117480db0bfda0520a41486b --- /dev/null +++ b/torchaudio/pipelines/_tts/impl.py @@ -0,0 +1,356 @@ +from dataclasses import dataclass +import re +from typing import Union, Optional, Dict, Any, Tuple, List + +import torch +from torch import Tensor +from torch.hub import load_state_dict_from_url + +from torchaudio.models import Tacotron2, WaveRNN +from torchaudio.functional import mu_law_decoding +from torchaudio.transforms import InverseMelScale, GriffinLim +from . import utils +from .interface import Tacotron2TTSBundle + +__all__ = [] + +_BASE_URL = 'https://download.pytorch.org/torchaudio/models' + + +################################################################################ +# Pipeline implementation - Text Processor +################################################################################ + + +class _EnglishCharProcessor(Tacotron2TTSBundle.TextProcessor): + def __init__(self): + super().__init__() + self._tokens = utils._get_chars() + self._mapping = {s: i for i, s in enumerate(self._tokens)} + + @property + def tokens(self): + return self._tokens + + def __call__(self, texts: Union[str, List[str]]) -> Tuple[Tensor, Tensor]: + if isinstance(texts, str): + texts = [texts] + indices = [[self._mapping[c] for c in t.lower() if c in self._mapping] for t in texts] + return utils._to_tensor(indices) + + +class _EnglishPhoneProcessor(Tacotron2TTSBundle.TextProcessor): + def __init__(self, *, dl_kwargs=None): + super().__init__() + self._tokens = utils._get_phones() + self._mapping = {p: i for i, p in enumerate(self._tokens)} + self._phonemizer = utils._load_phonemizer( + 'en_us_cmudict_forward.pt', dl_kwargs=dl_kwargs) + self._pattern = r"(\[[A-Z]+?\]|[_!'(),.:;? -])" + + @property + def tokens(self): + return self._tokens + + def __call__(self, texts: Union[str, List[str]]) -> Tuple[Tensor, Tensor]: + if isinstance(texts, str): + texts = [texts] + + indices = [] + for phones in self._phonemizer(texts, lang='en_us'): + # '[F][UW][B][AA][R]!' 
-> ['F', 'UW', 'B', 'AA', 'R', '!'] + ret = [re.sub(r'[\[\]]', '', r) for r in re.findall(self._pattern, phones)] + indices.append([self._mapping[p] for p in ret]) + return utils._to_tensor(indices) + + +################################################################################ +# Pipeline implementation - Vocoder +################################################################################ + +class _WaveRNNVocoder(torch.nn.Module, Tacotron2TTSBundle.Vocoder): + def __init__( + self, + model: WaveRNN, + min_level_db: Optional[float] = -100 + ): + super().__init__() + self._sample_rate = 22050 + self._model = model + self._min_level_db = min_level_db + + @property + def sample_rate(self): + return self._sample_rate + + def forward(self, mel_spec, lengths=None): + mel_spec = torch.exp(mel_spec) + mel_spec = 20 * torch.log10(torch.clamp(mel_spec, min=1e-5)) + if self._min_level_db is not None: + mel_spec = (self._min_level_db - mel_spec) / self._min_level_db + mel_spec = torch.clamp(mel_spec, min=0, max=1) + waveform, lengths = self._model.infer(mel_spec, lengths) + waveform = utils._unnormalize_waveform(waveform, self._model.n_bits) + waveform = mu_law_decoding(waveform, self._model.n_classes) + waveform = waveform.squeeze(1) + return waveform, lengths + + +class _GriffinLimVocoder(torch.nn.Module, Tacotron2TTSBundle.Vocoder): + def __init__(self): + super().__init__() + self._sample_rate = 22050 + self._inv_mel = InverseMelScale( + n_stft=(1024 // 2 + 1), + n_mels=80, + sample_rate=self.sample_rate, + f_min=0., + f_max=8000., + mel_scale="slaney", + norm='slaney', + ) + self._griffin_lim = GriffinLim( + n_fft=1024, + power=1, + hop_length=256, + win_length=1024, + ) + + @property + def sample_rate(self): + return self._sample_rate + + def forward(self, mel_spec, lengths=None): + mel_spec = torch.exp(mel_spec) + mel_spec = mel_spec.clone().detach().requires_grad_(True) + spec = self._inv_mel(mel_spec) + spec = spec.detach().requires_grad_(False) + waveforms = self._griffin_lim(spec) + return waveforms, lengths + + +################################################################################ +# Bundle classes mixins +################################################################################ + + +class _CharMixin: + def get_text_processor(self) -> Tacotron2TTSBundle.TextProcessor: + return _EnglishCharProcessor() + + +class _PhoneMixin: + def get_text_processor(self, *, dl_kwargs=None) -> Tacotron2TTSBundle.TextProcessor: + return _EnglishPhoneProcessor(dl_kwargs=dl_kwargs) + + +@dataclass +class _Tacotron2Mixin: + _tacotron2_path: str + _tacotron2_params: Dict[str, Any] + + def get_tacotron2(self, *, dl_kwargs=None) -> Tacotron2: + model = Tacotron2(**self._tacotron2_params) + url = f'{_BASE_URL}/{self._tacotron2_path}' + dl_kwargs = {} if dl_kwargs is None else dl_kwargs + state_dict = load_state_dict_from_url(url, **dl_kwargs) + model.load_state_dict(state_dict) + model.eval() + return model + + +@dataclass +class _WaveRNNMixin: + _wavernn_path: Optional[str] + _wavernn_params: Optional[Dict[str, Any]] + + def get_vocoder(self, *, dl_kwargs=None): + wavernn = self._get_wavernn(dl_kwargs=dl_kwargs) + return _WaveRNNVocoder(wavernn) + + def _get_wavernn(self, *, dl_kwargs=None): + model = WaveRNN(**self._wavernn_params) + url = f'{_BASE_URL}/{self._wavernn_path}' + dl_kwargs = {} if dl_kwargs is None else dl_kwargs + state_dict = load_state_dict_from_url(url, **dl_kwargs) + model.load_state_dict(state_dict) + model.eval() + return model + + +class _GriffinLimMixin: + def 
get_vocoder(self, **_): + return _GriffinLimVocoder() + + +################################################################################ +# Bundle classes +################################################################################ + + +@dataclass +class _Tacotron2WaveRNNCharBundle(_WaveRNNMixin, _Tacotron2Mixin, _CharMixin, Tacotron2TTSBundle): + pass + + +@dataclass +class _Tacotron2WaveRNNPhoneBundle(_WaveRNNMixin, _Tacotron2Mixin, _PhoneMixin, Tacotron2TTSBundle): + pass + + +@dataclass +class _Tacotron2GriffinLimCharBundle(_GriffinLimMixin, _Tacotron2Mixin, _CharMixin, Tacotron2TTSBundle): + pass + + +@dataclass +class _Tacotron2GriffinLimPhoneBundle(_GriffinLimMixin, _Tacotron2Mixin, _PhoneMixin, Tacotron2TTSBundle): + pass + + +################################################################################ +# Instantiate bundle objects +################################################################################ + + +TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH = _Tacotron2GriffinLimCharBundle( + _tacotron2_path='tacotron2_english_characters_1500_epochs_ljspeech.pth', + _tacotron2_params=utils._get_taco_params(n_symbols=38), +) +TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH.__doc__ = ( + '''Character-based TTS pipeline with :py:class:`torchaudio.models.Tacotron2` and +:py:class:`torchaudio.transforms.GriffinLim`. + +The text processor encodes the input texts character-by-character. + +Tacotron2 was trained on *LJSpeech* [:footcite:`ljspeech17`] for 1,500 epochs. +You can find the training script `here `__. +The default parameters were used. + +The vocoder is based on :py:class:`torchaudio.transforms.GriffinLim`. + +Please refer to :func:`torchaudio.pipelines.Tacotron2TTSBundle` for the usage. + +Example - "Hello world! T T S stands for Text to Speech!" + + .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH.png + :alt: Spectrogram generated by Tacotron2 + + .. raw:: html + + +''') # noqa: E501 + +TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH = _Tacotron2GriffinLimPhoneBundle( + _tacotron2_path='tacotron2_english_phonemes_1500_epochs_ljspeech.pth', + _tacotron2_params=utils._get_taco_params(n_symbols=96), +) +TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH.__doc__ = ( + '''Phoneme-based TTS pipeline with :py:class:`torchaudio.models.Tacotron2` and +:py:class:`torchaudio.transforms.GriffinLim`. + +The text processor encodes the input texts based on phoneme. +It uses `DeepPhonemizer `__ to convert +graphemes to phonemes. +The model (*en_us_cmudict_forward*) was trained on +`CMUDict `__. + +Tacotron2 was trained on *LJSpeech* [:footcite:`ljspeech17`] for 1,500 epochs. +You can find the training script `here `__. +The text processor is set to the *"english_phonemes"*. + +The vocoder is based on :py:class:`torchaudio.transforms.GriffinLim`. + +Please refer to :func:`torchaudio.pipelines.Tacotron2TTSBundle` for the usage. + +Example - "Hello world! T T S stands for Text to Speech!" + + .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH.png + :alt: Spectrogram generated by Tacotron2 + + .. 
raw:: html
+
+
+''') # noqa: E501
+
+TACOTRON2_WAVERNN_CHAR_LJSPEECH = _Tacotron2WaveRNNCharBundle(
+    _tacotron2_path='tacotron2_english_characters_1500_epochs_wavernn_ljspeech.pth',
+    _tacotron2_params=utils._get_taco_params(n_symbols=38),
+    _wavernn_path='wavernn_10k_epochs_8bits_ljspeech.pth',
+    _wavernn_params=utils._get_wrnn_params(),
+)
+TACOTRON2_WAVERNN_CHAR_LJSPEECH.__doc__ = (
+    '''Character-based TTS pipeline with :py:class:`torchaudio.models.Tacotron2` and
+:py:class:`torchaudio.models.WaveRNN`.
+
+The text processor encodes the input texts character-by-character.
+
+Tacotron2 was trained on *LJSpeech* [:footcite:`ljspeech17`] for 1,500 epochs.
+You can find the training script `here `__.
+The following parameters were used: ``win_length=1100``, ``hop_length=275``, ``n_fft=2048``,
+``mel_fmin=40``, and ``mel_fmax=11025``.
+
+The vocoder is based on :py:class:`torchaudio.models.WaveRNN`.
+It was trained on 8-bit depth waveforms of *LJSpeech* [:footcite:`ljspeech17`] for 10,000 epochs.
+You can find the training script `here `__.
+
+Please refer to :func:`torchaudio.pipelines.Tacotron2TTSBundle` for the usage.
+
+Example - "Hello world! T T S stands for Text to Speech!"
+
+    .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_CHAR_LJSPEECH.png
+       :alt: Spectrogram generated by Tacotron2
+
+    .. raw:: html
+
+
+''') # noqa: E501
+
+TACOTRON2_WAVERNN_PHONE_LJSPEECH = _Tacotron2WaveRNNPhoneBundle(
+    _tacotron2_path='tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.pth',
+    _tacotron2_params=utils._get_taco_params(n_symbols=96),
+    _wavernn_path='wavernn_10k_epochs_8bits_ljspeech.pth',
+    _wavernn_params=utils._get_wrnn_params(),
+)
+TACOTRON2_WAVERNN_PHONE_LJSPEECH.__doc__ = (
+    '''Phoneme-based TTS pipeline with :py:class:`torchaudio.models.Tacotron2` and
+:py:class:`torchaudio.models.WaveRNN`.
+
+The text processor encodes the input texts based on phonemes.
+It uses `DeepPhonemizer `__ to convert
+graphemes to phonemes.
+The model (*en_us_cmudict_forward*) was trained on
+`CMUDict `__.
+
+Tacotron2 was trained on *LJSpeech* [:footcite:`ljspeech17`] for 1,500 epochs.
+You can find the training script `here `__.
+The following parameters were used: ``win_length=1100``, ``hop_length=275``, ``n_fft=2048``,
+``mel_fmin=40``, and ``mel_fmax=11025``.
+
+The vocoder is based on :py:class:`torchaudio.models.WaveRNN`.
+It was trained on 8-bit depth waveforms of *LJSpeech* [:footcite:`ljspeech17`] for 10,000 epochs.
+You can find the training script `here `__.
+
+Please refer to :func:`torchaudio.pipelines.Tacotron2TTSBundle` for the usage.
+
+Example - "Hello world! T T S stands for Text to Speech!"
+
+    .. image:: https://download.pytorch.org/torchaudio/doc-assets/TACOTRON2_WAVERNN_PHONE_LJSPEECH.png
+       :alt: Spectrogram generated by Tacotron2
+
+    .. raw:: html
+
+
+''') # noqa: E501
diff --git a/torchaudio/pipelines/_tts/interface.py b/torchaudio/pipelines/_tts/interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..e78e78331dbc74fe53209e786d0dba1a268cad84
--- /dev/null
+++ b/torchaudio/pipelines/_tts/interface.py
@@ -0,0 +1,272 @@
+from abc import ABC, abstractmethod
+from typing import Union, List, Tuple, Optional
+
+from torch import Tensor
+
+from torchaudio.models import Tacotron2
+
+
+class _TextProcessor(ABC):
+    @property
+    @abstractmethod
+    def tokens(self):
+        """The tokens that each value in the processed tensor represents.
+
+        See :func:`torchaudio.pipelines.Tacotron2TTSBundle.get_text_processor` for the usage.
+ + :type: List[str] + """ + + @abstractmethod + def __call__(self, texts: Union[str, List[str]]) -> Tuple[Tensor, Tensor]: + """Encode the given (batch of) texts into numerical tensors + + See :func:`torchaudio.pipelines.Tacotron2TTSBundle.get_text_processor` for the usage. + + Args: + text (str or list of str): The input texts. + + Returns: + (Tensor, Tensor): + Tensor: + The encoded texts. Shape: `(batch, max length)` + Tensor: + The valid length of each sample in the batch. Shape: `(batch, )`. + """ + + +class _Vocoder(ABC): + @property + @abstractmethod + def sample_rate(self): + """The sample rate of the resulting waveform + + See :func:`torchaudio.pipelines.Tacotron2TTSBundle.get_vocoder` for the usage. + + :type: float + """ + + @abstractmethod + def __call__(self, specgrams: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: + """Generate waveform from the given input, such as spectrogram + + See :func:`torchaudio.pipelines.Tacotron2TTSBundle.get_vocoder` for the usage. + + Args: + specgrams (Tensor): + The input spectrogram. Shape: `(batch, frequency bins, time)`. + The expected shape depends on the implementation. + lengths (Tensor, or None, optional): + The valid length of each sample in the batch. Shape: `(batch, )`. + (Default: `None`) + + Returns: + (Tensor, Optional[Tensor]): + Tensor: + The generated waveform. Shape: `(batch, max length)` + Tensor or None: + The valid length of each sample in the batch. Shape: `(batch, )`. + """ + + +class Tacotron2TTSBundle(ABC): + """Data class that bundles associated information to use pretrained Tacotron2 and vocoder. + + This class provides interfaces for instantiating the pretrained model along with + the information necessary to retrieve pretrained weights and additional data + to be used with the model. + + Torchaudio library instantiates objects of this class, each of which represents + a different pretrained model. Client code should access pretrained models via these + instances. + + Please see below for the usage and the available values. + + Example - Character-based TTS pipeline with Tacotron2 and WaveRNN + >>> import torchaudio + >>> + >>> text = "Hello, T T S !" + >>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH + >>> + >>> # Build processor, Tacotron2 and WaveRNN model + >>> processor = bundle.get_text_processor() + >>> tacotron2 = bundle.get_tacotron2() + Downloading: + 100%|███████████████████████████████| 107M/107M [00:01<00:00, 87.9MB/s] + >>> vocoder = bundle.get_vocoder() + Downloading: + 100%|███████████████████████████████| 16.7M/16.7M [00:00<00:00, 78.1MB/s] + >>> + >>> # Encode text + >>> input, lengths = processor(text) + >>> + >>> # Generate (mel-scale) spectrogram + >>> specgram, lengths, _ = tacotron2.infer(input, lengths) + >>> + >>> # Convert spectrogram to waveform + >>> waveforms, lengths = vocoder(specgram, lengths) + >>> + >>> torchaudio.save('hello-tts.wav', waveforms[0], vocoder.sample_rate) + + Example - Phoneme-based TTS pipeline with Tacotron2 and WaveRNN + >>> + >>> # Note: + >>> # This bundle uses pre-trained DeepPhonemizer as + >>> # the text pre-processor. + >>> # Please install deep-phonemizer. + >>> # See https://github.com/as-ideas/DeepPhonemizer + >>> # The pretrained weight is automatically downloaded. + >>> + >>> import torchaudio + >>> + >>> text = "Hello, TTS!" 
+        >>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
+        >>>
+        >>> # Build processor, Tacotron2 and WaveRNN model
+        >>> processor = bundle.get_text_processor()
+        Downloading:
+        100%|███████████████████████████████| 63.6M/63.6M [00:04<00:00, 15.3MB/s]
+        >>> tacotron2 = bundle.get_tacotron2()
+        Downloading:
+        100%|███████████████████████████████| 107M/107M [00:01<00:00, 87.9MB/s]
+        >>> vocoder = bundle.get_vocoder()
+        Downloading:
+        100%|███████████████████████████████| 16.7M/16.7M [00:00<00:00, 78.1MB/s]
+        >>>
+        >>> # Encode text
+        >>> input, lengths = processor(text)
+        >>>
+        >>> # Generate (mel-scale) spectrogram
+        >>> specgram, lengths, _ = tacotron2.infer(input, lengths)
+        >>>
+        >>> # Convert spectrogram to waveform
+        >>> waveforms, lengths = vocoder(specgram, lengths)
+        >>>
+        >>> torchaudio.save('hello-tts.wav', waveforms[0], vocoder.sample_rate)
+    """
+
+    # Using the inner class so that these interfaces are not directly exposed on
+    # `torchaudio.pipelines`, but still listed in documentation.
+    # The thing is, text processing and vocoder are generic and we do not know what kind of
+    # new text processing and vocoder will be added in the future, so we want to make these
+    # interfaces specific to this Tacotron2TTS pipeline.
+    class TextProcessor(_TextProcessor):
+        """Interface of the text processing part of Tacotron2TTS pipeline
+
+        See :func:`torchaudio.pipelines.Tacotron2TTSBundle.get_text_processor` for the usage.
+        """
+
+    class Vocoder(_Vocoder):
+        """Interface of the vocoder part of Tacotron2TTS pipeline
+
+        See :func:`torchaudio.pipelines.Tacotron2TTSBundle.get_vocoder` for the usage.
+        """
+
+    @abstractmethod
+    def get_text_processor(self, *, dl_kwargs=None) -> TextProcessor:
+        # Overriding the signature so that the return type is correct on Sphinx
+        """get_text_processor(self, *, dl_kwargs=None) -> torchaudio.pipelines.Tacotron2TTSBundle.TextProcessor
+
+        Create a text processor.
+
+        For the character-based pipeline, this processor splits the input text by character.
+        For the phoneme-based pipeline, this processor converts the input text (graphemes) to
+        phonemes.
+
+        If a pre-trained weight file is necessary,
+        :func:`torch.hub.download_url_to_file` is used to download it.
+
+        Args:
+            dl_kwargs (dictionary of keyword arguments):
+                Passed to :func:`torch.hub.download_url_to_file`.
+
+        Returns:
+            Tacotron2TTSBundle.TextProcessor:
+                A callable which takes a string or a list of strings as input and
+                returns Tensor of encoded texts and Tensor of valid lengths.
+                The object also has ``tokens`` property, which allows recovering the
+                tokenized form.
+ + Example - Character-based + >>> text = [ + >>> "Hello World!", + >>> "Text-to-speech!", + >>> ] + >>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH + >>> processor = bundle.get_text_processor() + >>> input, lengths = processor(text) + >>> + >>> print(input) + tensor([[19, 16, 23, 23, 26, 11, 34, 26, 29, 23, 15, 2, 0, 0, 0], + [31, 16, 35, 31, 1, 31, 26, 1, 30, 27, 16, 16, 14, 19, 2]], + dtype=torch.int32) + >>> + >>> print(lengths) + tensor([12, 15], dtype=torch.int32) + >>> + >>> print([processor.tokens[i] for i in input[0, :lengths[0]]]) + ['h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!'] + >>> print([processor.tokens[i] for i in input[1, :lengths[1]]]) + ['t', 'e', 'x', 't', '-', 't', 'o', '-', 's', 'p', 'e', 'e', 'c', 'h', '!'] + + Example - Phoneme-based + >>> text = [ + >>> "Hello, T T S !", + >>> "Text-to-speech!", + >>> ] + >>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH + >>> processor = bundle.get_text_processor() + Downloading: + 100%|███████████████████████████████| 63.6M/63.6M [00:04<00:00, 15.3MB/s] + >>> input, lengths = processor(text) + >>> + >>> print(input) + tensor([[54, 20, 65, 69, 11, 92, 44, 65, 38, 2, 0, 0, 0, 0], + [81, 40, 64, 79, 81, 1, 81, 20, 1, 79, 77, 59, 37, 2]], + dtype=torch.int32) + >>> + >>> print(lengths) + tensor([10, 14], dtype=torch.int32) + >>> + >>> print([processor.tokens[i] for i in input[0]]) + ['HH', 'AH', 'L', 'OW', ' ', 'W', 'ER', 'L', 'D', '!', '_', '_', '_', '_'] + >>> print([processor.tokens[i] for i in input[1]]) + ['T', 'EH', 'K', 'S', 'T', '-', 'T', 'AH', '-', 'S', 'P', 'IY', 'CH', '!'] + """ + + @abstractmethod + def get_vocoder(self, *, dl_kwargs=None) -> Vocoder: + # Overriding the signature so that the return type is correct on Sphinx + """get_vocoder(self, *, dl_kwargs=None) -> torchaudio.pipelines.Tacotron2TTSBundle.Vocoder + + Create a vocoder module, based off of either WaveRNN or GriffinLim. + + If a pre-trained weight file is necessary, + :func:`torch.hub.load_state_dict_from_url` is used to downloaded it. + + Args: + dl_kwargs (dictionary of keyword arguments): + Passed to :func:`torch.hub.load_state_dict_from_url`. + + Returns: + Callable[[Tensor, Optional[Tensor]], Tuple[Tensor, Optional[Tensor]]]: + A vocoder module, which takes spectrogram Tensor and an optional + length Tensor, then returns resulting waveform Tensor and an optional + length Tensor. + """ + + @abstractmethod + def get_tacotron2(self, *, dl_kwargs=None) -> Tacotron2: + # Overriding the signature so that the return type is correct on Sphinx + """get_tacotron2(self, *, dl_kwargs=None) -> torchaudio.models.Tacotron2 + + Create a Tacotron2 model with pre-trained weight. + + Args: + dl_kwargs (dictionary of keyword arguments): + Passed to :func:`torch.hub.load_state_dict_from_url`. + + Returns: + Tacotron2: + The resulting model. 
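+
+        Example (an illustrative sketch; weights are downloaded on first use)
+            >>> import torchaudio
+            >>> bundle = torchaudio.pipelines.TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH
+            >>> processor = bundle.get_text_processor()
+            >>> tacotron2 = bundle.get_tacotron2()
+            >>> input, lengths = processor("Hello world!")
+            >>> specgram, spec_lengths, _ = tacotron2.infer(input, lengths)
+            >>> # specgram is a mel-scale spectrogram, ready to be passed to a vocoder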
+ """ diff --git a/torchaudio/pipelines/_tts/utils.py b/torchaudio/pipelines/_tts/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..54a5c9a42d601fc2aa81d09f2b11b5c59f0c2f03 --- /dev/null +++ b/torchaudio/pipelines/_tts/utils.py @@ -0,0 +1,229 @@ +import os +import logging + +import torch + +from torchaudio._internal import module_utils as _mod_utils + + +def _get_chars(): + return ( + '_', + '-', + '!', + "'", + '(', + ')', + ',', + '.', + ':', + ';', + '?', + ' ', + 'a', + 'b', + 'c', + 'd', + 'e', + 'f', + 'g', + 'h', + 'i', + 'j', + 'k', + 'l', + 'm', + 'n', + 'o', + 'p', + 'q', + 'r', + 's', + 't', + 'u', + 'v', + 'w', + 'x', + 'y', + 'z', + ) + + +def _get_phones(): + return ( + "_", + "-", + "!", + "'", + "(", + ")", + ",", + ".", + ":", + ";", + "?", + " ", + "AA", + "AA0", + "AA1", + "AA2", + "AE", + "AE0", + "AE1", + "AE2", + "AH", + "AH0", + "AH1", + "AH2", + "AO", + "AO0", + "AO1", + "AO2", + "AW", + "AW0", + "AW1", + "AW2", + "AY", + "AY0", + "AY1", + "AY2", + "B", + "CH", + "D", + "DH", + "EH", + "EH0", + "EH1", + "EH2", + "ER", + "ER0", + "ER1", + "ER2", + "EY", + "EY0", + "EY1", + "EY2", + "F", + "G", + "HH", + "IH", + "IH0", + "IH1", + "IH2", + "IY", + "IY0", + "IY1", + "IY2", + "JH", + "K", + "L", + "M", + "N", + "NG", + "OW", + "OW0", + "OW1", + "OW2", + "OY", + "OY0", + "OY1", + "OY2", + "P", + "R", + "S", + "SH", + "T", + "TH", + "UH", + "UH0", + "UH1", + "UH2", + "UW", + "UW0", + "UW1", + "UW2", + "V", + "W", + "Y", + "Z", + "ZH" + ) + + +def _to_tensor(indices): + lengths = torch.tensor([len(i) for i in indices], dtype=torch.int32) + values = [torch.tensor(i) for i in indices] + values = torch.nn.utils.rnn.pad_sequence(values, batch_first=True) + return values, lengths + + +def _load_phonemizer(file, dl_kwargs): + if not _mod_utils.is_module_available('dp'): + raise RuntimeError('DeepPhonemizer is not installed. Please install it.') + + from dp.phonemizer import Phonemizer + + # By default, dp issues DEBUG level log. 
+ logger = logging.getLogger('dp') + orig_level = logger.level + logger.setLevel(logging.INFO) + try: + url = f'https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/{file}' + directory = os.path.join(torch.hub.get_dir(), 'checkpoints') + os.makedirs(directory, exist_ok=True) + path = os.path.join(directory, file) + if not os.path.exists(path): + dl_kwargs = {} if dl_kwargs is None else dl_kwargs + torch.hub.download_url_to_file(url, path, **dl_kwargs) + return Phonemizer.from_checkpoint(path) + finally: + logger.setLevel(orig_level) + + +def _unnormalize_waveform(waveform: torch.Tensor, bits: int) -> torch.Tensor: + r"""Transform waveform [-1, 1] to label [0, 2 ** bits - 1]""" + waveform = torch.clamp(waveform, -1, 1) + waveform = (waveform + 1.0) * (2 ** bits - 1) / 2 + return torch.clamp(waveform, 0, 2 ** bits - 1).int() + + +def _get_taco_params(n_symbols): + return { + 'mask_padding': False, + 'n_mels': 80, + 'n_frames_per_step': 1, + 'symbol_embedding_dim': 512, + 'encoder_embedding_dim': 512, + 'encoder_n_convolution': 3, + 'encoder_kernel_size': 5, + 'decoder_rnn_dim': 1024, + 'decoder_max_step': 2000, + 'decoder_dropout': 0.1, + 'decoder_early_stopping': True, + 'attention_rnn_dim': 1024, + 'attention_hidden_dim': 128, + 'attention_location_n_filter': 32, + 'attention_location_kernel_size': 31, + 'attention_dropout': 0.1, + 'prenet_dim': 256, + 'postnet_n_convolution': 5, + 'postnet_kernel_size': 5, + 'postnet_embedding_dim': 512, + 'gate_threshold': 0.5, + 'n_symbol': n_symbols, + } + + +def _get_wrnn_params(): + return { + 'upsample_scales': [5, 5, 11], + 'n_classes': 2 ** 8, # n_bits = 8 + 'hop_length': 275, + 'n_res_block': 10, + 'n_rnn': 512, + 'n_fc': 512, + 'kernel_size': 5, + 'n_freq': 80, + 'n_hidden': 128, + 'n_output': 128 + } diff --git a/torchaudio/pipelines/_wav2vec2.py b/torchaudio/pipelines/_wav2vec2.py new file mode 100644 index 0000000000000000000000000000000000000000..69586084552dda610887662e918d7182141ec5dc --- /dev/null +++ b/torchaudio/pipelines/_wav2vec2.py @@ -0,0 +1,1001 @@ +from dataclasses import dataclass +from typing import Dict, Tuple, Any + +from torch.hub import load_state_dict_from_url + +from torchaudio.models import wav2vec2_model, Wav2Vec2Model + +__all__ = [] + + +@dataclass +class Wav2Vec2Bundle: + """torchaudio.pipelines.Wav2Vec2Bundle() + + Data class that bundles associated information to use pretrained Wav2Vec2Model. + + This class provides interfaces for instantiating the pretrained model along with + the information necessary to retrieve pretrained weights and additional data + to be used with the model. + + Torchaudio library instantiates objects of this class, each of which represents + a different pretrained model. Client code should access pretrained models via these + instances. + + Please see below for the usage and the available values. + + Example - Feature Extraction + >>> import torchaudio + >>> + >>> bundle = torchaudio.pipelines.HUBERT_BASE + >>> + >>> # Build the model and load pretrained weight. 
+ >>> model = bundle.get_model() + Downloading: + 100%|███████████████████████████████| 360M/360M [00:06<00:00, 60.6MB/s] + >>> + >>> # Resample audio to the expected sampling rate + >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate) + >>> + >>> # Extract acoustic features + >>> features, _ = model.extract_features(waveform) + """ # noqa: E501 + _path: str + _params: Dict[str, Any] + _sample_rate: float + + @property + def sample_rate(self) -> float: + """Sample rate of the audio that the model is trained on. + + :type: float + """ + return self._sample_rate + + def get_model(self, *, dl_kwargs=None) -> Wav2Vec2Model: + # Overriding the signature so that the return type is correct on Sphinx + """get_model(self, *, dl_kwargs=None) -> torchaudio.models.Wav2Vec2Model + + Construct the model and load the pretrained weight. + + The weight file is downloaded from the internet and cached with + :func:`torch.hub.load_state_dict_from_url` + + Args: + dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. + """ + model = wav2vec2_model(**self._params) + url = f'https://download.pytorch.org/torchaudio/models/{self._path}' + dl_kwargs = {} if dl_kwargs is None else dl_kwargs + state_dict = load_state_dict_from_url(url, **dl_kwargs) + model.load_state_dict(state_dict) + model.eval() + return model + + +@dataclass +class Wav2Vec2ASRBundle(Wav2Vec2Bundle): + """torchaudio.pipelines.Wav2Vec2ASRBundle() + + Data class that bundles associated information to use pretrained Wav2Vec2Model. + + This class provides interfaces for instantiating the pretrained model along with + the information necessary to retrieve pretrained weights and additional data + to be used with the model. + + Torchaudio library instantiates objects of this class, each of which represents + a different pretrained model. Client code should access pretrained models via these + instances. + + Please see below for the usage and the available values. + + Example - ASR + >>> import torchaudio + >>> + >>> bundle = torchaudio.pipelines.HUBERT_ASR_LARGE + >>> + >>> # Build the model and load pretrained weight. + >>> model = bundle.get_model() + Downloading: + 100%|███████████████████████████████| 1.18G/1.18G [00:17<00:00, 73.8MB/s] + >>> + >>> # Check the corresponding labels of the output. + >>> labels = bundle.get_labels() + >>> print(labels) + ('', '', '', '', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z') + >>> + >>> # Resample audio to the expected sampling rate + >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate) + >>> + >>> # Infer the label probability distribution + >>> emissions, _ = model(waveform) + >>> + >>> # Pass emission to decoder + >>> # `ctc_decode` is for illustration purpose only + >>> transcripts = ctc_decode(emissions, labels) + """ # noqa: E501 + _labels: Tuple[str] + + def get_labels( + self, + *, + bos: str = '', + pad: str = '', + eos: str = '', + unk: str = '', + ) -> Tuple[str]: + """The output class labels (only applicable to fine-tuned bundles) + + The first four tokens are BOS, padding, EOS and UNK tokens and they can be customized. + + Args: + bos (str, optional): Beginning of sentence token. (default: ``''``) + pad (str, optional): Padding token. (default: ``''``) + eos (str, optional): End of sentence token. (default: ``''``) + unk (str, optional): Token for unknown class. 
(default: ``''``) + + Returns: + Tuple[str]: + For models fine-tuned on ASR, returns the tuple of strings representing + the output class labels. + + Example + >>> import torchaudio + >>> torchaudio.models.HUBERT_ASR_LARGE.get_labels() + ('', '', '', '', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z') + """ # noqa: E501 + if self._labels is None: + raise ValueError('Pre-trained models do not have labels.') + return (bos, pad, eos, unk, *self._labels) + + +def _get_labels(): + return ( + '|', + 'E', + 'T', + 'A', + 'O', + 'N', + 'I', + 'H', + 'S', + 'R', + 'D', + 'L', + 'U', + 'M', + 'W', + 'C', + 'F', + 'G', + 'Y', + 'P', + 'B', + 'V', + 'K', + "'", + 'X', + 'J', + 'Q', + 'Z', + ) + + +WAV2VEC2_BASE = Wav2Vec2Bundle( + _path='wav2vec2_fairseq_base_ls960.pth', + _params={ + 'extractor_mode': 'group_norm', + 'extractor_conv_layer_config': [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ], + 'extractor_conv_bias': False, + 'encoder_embed_dim': 768, + 'encoder_projection_dropout': 0.1, + 'encoder_pos_conv_kernel': 128, + 'encoder_pos_conv_groups': 16, + 'encoder_num_layers': 12, + 'encoder_num_heads': 12, + 'encoder_attention_dropout': 0.1, + 'encoder_ff_interm_features': 3072, + 'encoder_ff_interm_dropout': 0.0, + 'encoder_dropout': 0.1, + 'encoder_layer_norm_first': False, + 'encoder_layer_drop': 0.05, + "aux_num_out": None, + }, + _sample_rate=16000, +) +WAV2VEC2_BASE.__doc__ = """wav2vec 2.0 model with "Base" configuration. + +Pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset [:footcite:`7178964`] +(the combination of "train-clean-100", "train-clean-360", and "train-other-500"). +Not fine-tuned. + +Originally published by the authors of *wav2vec 2.0* [:footcite:`baevski2020wav2vec`] under MIT License and +redistributed with the same license. +[`License `__, +`Source `__] + +Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage. +""" # noqa: E501 + +WAV2VEC2_ASR_BASE_10M = Wav2Vec2ASRBundle( + _path='wav2vec2_fairseq_base_ls960_asr_ll10m.pth', + _params={ + 'extractor_mode': 'group_norm', + 'extractor_conv_layer_config': [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ], + 'extractor_conv_bias': False, + 'encoder_embed_dim': 768, + 'encoder_projection_dropout': 0.1, + 'encoder_pos_conv_kernel': 128, + 'encoder_pos_conv_groups': 16, + 'encoder_num_layers': 12, + 'encoder_num_heads': 12, + 'encoder_attention_dropout': 0.1, + 'encoder_ff_interm_features': 3072, + 'encoder_ff_interm_dropout': 0.0, + 'encoder_dropout': 0.1, + 'encoder_layer_norm_first': False, + 'encoder_layer_drop': 0.05, + "aux_num_out": 32, + }, + _labels=_get_labels(), + _sample_rate=16000, +) +WAV2VEC2_ASR_BASE_10M.__doc__ = """Build "base" wav2vec2 model with an extra linear module + +Pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset [:footcite:`7178964`] +(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), and +fine-tuned for ASR on 10 minutes of transcribed audio from *Libri-Light* dataset +[:footcite:`librilight`] ("train-10min" subset). + +Originally published by the authors of *wav2vec 2.0* [:footcite:`baevski2020wav2vec`] under MIT License and +redistributed with the same license. +[`License `__, +`Source `__] + +Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage. 
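+
+Example - Greedy CTC decoding of the emission (for illustration purpose only; this sketch
+assumes that index ``0`` acts as the CTC blank token and that ``'|'`` marks word boundaries)
+    >>> import torch
+    >>> # ``emissions`` has shape (batch, frame, num_labels) and is obtained from ``model(waveform)``
+    >>> best = torch.argmax(emissions[0], dim=-1)
+    >>> best = torch.unique_consecutive(best)  # collapse repeated frames
+    >>> labels = bundle.get_labels()
+    >>> transcript = ''.join(labels[i] for i in best.tolist() if i != 0)
+    >>> print(transcript.replace('|', ' '))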
+""" # noqa: E501 + +WAV2VEC2_ASR_BASE_100H = Wav2Vec2ASRBundle( + 'wav2vec2_fairseq_base_ls960_asr_ls100.pth', + { + 'extractor_mode': 'group_norm', + 'extractor_conv_layer_config': [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ], + 'extractor_conv_bias': False, + 'encoder_embed_dim': 768, + 'encoder_projection_dropout': 0.1, + 'encoder_pos_conv_kernel': 128, + 'encoder_pos_conv_groups': 16, + 'encoder_num_layers': 12, + 'encoder_num_heads': 12, + 'encoder_attention_dropout': 0.1, + 'encoder_ff_interm_features': 3072, + 'encoder_ff_interm_dropout': 0.0, + 'encoder_dropout': 0.1, + 'encoder_layer_norm_first': False, + 'encoder_layer_drop': 0.05, + "aux_num_out": 32, + }, + _labels=_get_labels(), + _sample_rate=16000, +) + +WAV2VEC2_ASR_BASE_100H.__doc__ = """Build "base" wav2vec2 model with an extra linear module + +Pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset [:footcite:`7178964`] +(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), and +fine-tuned for ASR on 100 hours of transcribed audio from "train-clean-100" subset. + +Originally published by the authors of *wav2vec 2.0* [:footcite:`baevski2020wav2vec`] under MIT License and +redistributed with the same license. +[`License `__, +`Source `__] + +Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage. +""" # noqa: E501 + +WAV2VEC2_ASR_BASE_960H = Wav2Vec2ASRBundle( + 'wav2vec2_fairseq_base_ls960_asr_ls960.pth', + { + "extractor_mode": "group_norm", + "extractor_conv_layer_config": [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ], + "extractor_conv_bias": False, + "encoder_embed_dim": 768, + "encoder_projection_dropout": 0.1, + "encoder_pos_conv_kernel": 128, + "encoder_pos_conv_groups": 16, + "encoder_num_layers": 12, + "encoder_num_heads": 12, + "encoder_attention_dropout": 0.1, + "encoder_ff_interm_features": 3072, + "encoder_ff_interm_dropout": 0.0, + "encoder_dropout": 0.1, + "encoder_layer_norm_first": False, + "encoder_layer_drop": 0.05, + "aux_num_out": 32, + }, + _labels=_get_labels(), + _sample_rate=16000, +) +WAV2VEC2_ASR_BASE_960H.__doc__ = """Build "base" wav2vec2 model with an extra linear module + +Pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset [:footcite:`7178964`] +(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), and +fine-tuned for ASR on the same audio with the corresponding transcripts. + +Originally published by the authors of *wav2vec 2.0* [:footcite:`baevski2020wav2vec`] under MIT License and +redistributed with the same license. +[`License `__, +`Source `__] + +Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage. 
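+
+Example - Controlling the weight download (illustrative; ``dl_kwargs`` is forwarded to
+:func:`torch.hub.load_state_dict_from_url`, so any of its keyword arguments can be used)
+    >>> bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
+    >>> model = bundle.get_model(dl_kwargs={'progress': False})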
+""" # noqa: E501 + +WAV2VEC2_LARGE = Wav2Vec2Bundle( + 'wav2vec2_fairseq_large_ls960.pth', + { + "extractor_mode": "group_norm", + "extractor_conv_layer_config": [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ], + "extractor_conv_bias": False, + "encoder_embed_dim": 1024, + "encoder_projection_dropout": 0.1, + "encoder_pos_conv_kernel": 128, + "encoder_pos_conv_groups": 16, + "encoder_num_layers": 24, + "encoder_num_heads": 16, + "encoder_attention_dropout": 0.1, + "encoder_ff_interm_features": 4096, + "encoder_ff_interm_dropout": 0.0, + "encoder_dropout": 0.0, + "encoder_layer_norm_first": False, + "encoder_layer_drop": 0.2, + "aux_num_out": None, + }, + _sample_rate=16000, +) +WAV2VEC2_LARGE.__doc__ = """Build "large" wav2vec2 model. + +Pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset [:footcite:`7178964`] +(the combination of "train-clean-100", "train-clean-360", and "train-other-500"). +Not fine-tuned. + +Originally published by the authors of *wav2vec 2.0* [:footcite:`baevski2020wav2vec`] under MIT License and +redistributed with the same license. +[`License `__, +`Source `__] + +Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage. +""" # noqa: E501 + +WAV2VEC2_ASR_LARGE_10M = Wav2Vec2ASRBundle( + 'wav2vec2_fairseq_large_ls960_asr_ll10m.pth', + { + "extractor_mode": "group_norm", + "extractor_conv_layer_config": [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ], + "extractor_conv_bias": False, + "encoder_embed_dim": 1024, + "encoder_projection_dropout": 0.1, + "encoder_pos_conv_kernel": 128, + "encoder_pos_conv_groups": 16, + "encoder_num_layers": 24, + "encoder_num_heads": 16, + "encoder_attention_dropout": 0.1, + "encoder_ff_interm_features": 4096, + "encoder_ff_interm_dropout": 0.0, + "encoder_dropout": 0.0, + "encoder_layer_norm_first": False, + "encoder_layer_drop": 0.2, + "aux_num_out": 32, + }, + _labels=_get_labels(), + _sample_rate=16000, +) +WAV2VEC2_ASR_LARGE_10M.__doc__ = """Build "large" wav2vec2 model with an extra linear module + +Pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset [:footcite:`7178964`] +(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), and +fine-tuned for ASR on 10 minutes of transcribed audio from *Libri-Light* dataset +[:footcite:`librilight`] ("train-10min" subset). + +Originally published by the authors of *wav2vec 2.0* [:footcite:`baevski2020wav2vec`] under MIT License and +redistributed with the same license. +[`License `__, +`Source `__] + +Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage. 
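+
+Example - Batched inference with padded inputs (an illustrative sketch; it assumes that
+``model`` accepts the valid lengths of a zero-padded batch as its second argument)
+    >>> # ``waveforms``: (batch, time) zero-padded batch, ``lengths``: (batch,) valid lengths
+    >>> emissions, output_lengths = model(waveforms, lengths)
+    >>> # ``output_lengths`` gives the number of valid frames in each row of ``emissions``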
+""" # noqa: E501 + +WAV2VEC2_ASR_LARGE_100H = Wav2Vec2ASRBundle( + 'wav2vec2_fairseq_large_ls960_asr_ls100.pth', + { + "extractor_mode": "group_norm", + "extractor_conv_layer_config": [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ], + "extractor_conv_bias": False, + "encoder_embed_dim": 1024, + "encoder_projection_dropout": 0.1, + "encoder_pos_conv_kernel": 128, + "encoder_pos_conv_groups": 16, + "encoder_num_layers": 24, + "encoder_num_heads": 16, + "encoder_attention_dropout": 0.1, + "encoder_ff_interm_features": 4096, + "encoder_ff_interm_dropout": 0.0, + "encoder_dropout": 0.0, + "encoder_layer_norm_first": False, + "encoder_layer_drop": 0.2, + "aux_num_out": 32, + }, + _labels=_get_labels(), + _sample_rate=16000, +) +WAV2VEC2_ASR_LARGE_100H.__doc__ = """Build "large" wav2vec2 model with an extra linear module + +Pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset [:footcite:`7178964`] +(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), and +fine-tuned for ASR on 100 hours of transcribed audio from +the same dataset ("train-clean-100" subset). + +Originally published by the authors of *wav2vec 2.0* [:footcite:`baevski2020wav2vec`] under MIT License and +redistributed with the same license. +[`License `__, +`Source `__] + +Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage. +""" # noqa: E501 + +WAV2VEC2_ASR_LARGE_960H = Wav2Vec2ASRBundle( + 'wav2vec2_fairseq_large_ls960_asr_ls960.pth', + { + "extractor_mode": "group_norm", + "extractor_conv_layer_config": [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ], + "extractor_conv_bias": False, + "encoder_embed_dim": 1024, + "encoder_projection_dropout": 0.1, + "encoder_pos_conv_kernel": 128, + "encoder_pos_conv_groups": 16, + "encoder_num_layers": 24, + "encoder_num_heads": 16, + "encoder_attention_dropout": 0.1, + "encoder_ff_interm_features": 4096, + "encoder_ff_interm_dropout": 0.0, + "encoder_dropout": 0.0, + "encoder_layer_norm_first": False, + "encoder_layer_drop": 0.2, + "aux_num_out": 32, + }, + _labels=_get_labels(), + _sample_rate=16000, +) +WAV2VEC2_ASR_LARGE_960H.__doc__ = """Build "large" wav2vec2 model with an extra linear module + +Pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset [:footcite:`7178964`] +(the combination of "train-clean-100", "train-clean-360", and "train-other-500"), and +fine-tuned for ASR on the same audio with the corresponding transcripts. + +Originally published by the authors of *wav2vec 2.0* [:footcite:`baevski2020wav2vec`] under MIT License and +redistributed with the same license. +[`License `__, +`Source `__] + +Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage. 
+""" # noqa: E501 + +WAV2VEC2_LARGE_LV60K = Wav2Vec2Bundle( + 'wav2vec2_fairseq_large_lv60k.pth', + { + "extractor_mode": "layer_norm", + "extractor_conv_layer_config": [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ], + "extractor_conv_bias": True, + "encoder_embed_dim": 1024, + "encoder_projection_dropout": 0.1, + "encoder_pos_conv_kernel": 128, + "encoder_pos_conv_groups": 16, + "encoder_num_layers": 24, + "encoder_num_heads": 16, + "encoder_attention_dropout": 0.1, + "encoder_ff_interm_features": 4096, + "encoder_ff_interm_dropout": 0.0, + "encoder_dropout": 0.0, + "encoder_layer_norm_first": True, + "encoder_layer_drop": 0.0, + "aux_num_out": None, + }, + _sample_rate=16000, +) +WAV2VEC2_LARGE_LV60K.__doc__ = """Build "large-lv60k" wav2vec2 model. + +Pre-trained on 60,000 hours of unlabeled audio from +*Libri-Light* dataset [:footcite:`librilight`]. +Not fine-tuned. + +Originally published by the authors of *wav2vec 2.0* [:footcite:`baevski2020wav2vec`] under MIT License and +redistributed with the same license. +[`License `__, +`Source `__] + +Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage. +""" # noqa: E501 + +WAV2VEC2_ASR_LARGE_LV60K_10M = Wav2Vec2ASRBundle( + 'wav2vec2_fairseq_large_lv60k_asr_ll10m.pth', + { + "extractor_mode": "layer_norm", + "extractor_conv_layer_config": [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ], + "extractor_conv_bias": True, + "encoder_embed_dim": 1024, + "encoder_projection_dropout": 0.1, + "encoder_pos_conv_kernel": 128, + "encoder_pos_conv_groups": 16, + "encoder_num_layers": 24, + "encoder_num_heads": 16, + "encoder_attention_dropout": 0.1, + "encoder_ff_interm_features": 4096, + "encoder_ff_interm_dropout": 0.0, + "encoder_dropout": 0.0, + "encoder_layer_norm_first": True, + "encoder_layer_drop": 0.0, + "aux_num_out": 32, + }, + _labels=_get_labels(), + _sample_rate=16000, +) +WAV2VEC2_ASR_LARGE_LV60K_10M.__doc__ = """Build "large-lv60k" wav2vec2 model with an extra linear module + +Pre-trained on 60,000 hours of unlabeled audio from +*Libri-Light* dataset [:footcite:`librilight`], and +fine-tuned for ASR on 10 minutes of transcribed audio from +the same dataset ("train-10min" subset). + +Originally published by the authors of *wav2vec 2.0* [:footcite:`baevski2020wav2vec`] under MIT License and +redistributed with the same license. +[`License `__, +`Source `__] + +Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage. 
+""" # noqa: E501 + +WAV2VEC2_ASR_LARGE_LV60K_100H = Wav2Vec2ASRBundle( + 'wav2vec2_fairseq_large_lv60k_asr_ls100.pth', + { + "extractor_mode": "layer_norm", + "extractor_conv_layer_config": [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ], + "extractor_conv_bias": True, + "encoder_embed_dim": 1024, + "encoder_projection_dropout": 0.1, + "encoder_pos_conv_kernel": 128, + "encoder_pos_conv_groups": 16, + "encoder_num_layers": 24, + "encoder_num_heads": 16, + "encoder_attention_dropout": 0.1, + "encoder_ff_interm_features": 4096, + "encoder_ff_interm_dropout": 0.0, + "encoder_dropout": 0.0, + "encoder_layer_norm_first": True, + "encoder_layer_drop": 0.0, + "aux_num_out": 32, + }, + _labels=_get_labels(), + _sample_rate=16000, +) +WAV2VEC2_ASR_LARGE_LV60K_100H.__doc__ = """Build "large-lv60k" wav2vec2 model with an extra linear module + +Pre-trained on 60,000 hours of unlabeled audio from +*Libri-Light* dataset [:footcite:`librilight`], and +fine-tuned for ASR on 100 hours of transcribed audio from +*LibriSpeech* dataset [:footcite:`7178964`] ("train-clean-100" subset). + +Originally published by the authors of *wav2vec 2.0* [:footcite:`baevski2020wav2vec`] under MIT License and +redistributed with the same license. +[`License `__, +`Source `__] + +Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage. +""" # noqa: E501 + +WAV2VEC2_ASR_LARGE_LV60K_960H = Wav2Vec2ASRBundle( + 'wav2vec2_fairseq_large_lv60k_asr_ls960.pth', + { + "extractor_mode": "layer_norm", + "extractor_conv_layer_config": [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ], + "extractor_conv_bias": True, + "encoder_embed_dim": 1024, + "encoder_projection_dropout": 0.1, + "encoder_pos_conv_kernel": 128, + "encoder_pos_conv_groups": 16, + "encoder_num_layers": 24, + "encoder_num_heads": 16, + "encoder_attention_dropout": 0.1, + "encoder_ff_interm_features": 4096, + "encoder_ff_interm_dropout": 0.0, + "encoder_dropout": 0.0, + "encoder_layer_norm_first": True, + "encoder_layer_drop": 0.0, + "aux_num_out": 32, + }, + _labels=_get_labels(), + _sample_rate=16000, +) +WAV2VEC2_ASR_LARGE_LV60K_960H.__doc__ = """Build "large-lv60k" wav2vec2 model with an extra linear module + +Pre-trained on 60,000 hours of unlabeled audio from *Libri-Light* +[:footcite:`librilight`] dataset, and +fine-tuned for ASR on 960 hours of transcribed audio from +*LibriSpeech* dataset [:footcite:`7178964`] +(the combination of "train-clean-100", "train-clean-360", and "train-other-500"). + +Originally published by the authors of *wav2vec 2.0* [:footcite:`baevski2020wav2vec`] under MIT License and +redistributed with the same license. +[`License `__, +`Source `__] + +Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage. 
+""" # noqa: E501 + +WAV2VEC2_XLSR53 = Wav2Vec2Bundle( + 'wav2vec2_fairseq_large_xlsr53.pth', + { + "extractor_mode": "layer_norm", + "extractor_conv_layer_config": [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ], + "extractor_conv_bias": True, + "encoder_embed_dim": 1024, + "encoder_projection_dropout": 0.0, + "encoder_pos_conv_kernel": 128, + "encoder_pos_conv_groups": 16, + "encoder_num_layers": 24, + "encoder_num_heads": 16, + "encoder_attention_dropout": 0.0, + "encoder_ff_interm_features": 4096, + "encoder_ff_interm_dropout": 0.0, + "encoder_dropout": 0.0, + "encoder_layer_norm_first": True, + "encoder_layer_drop": 0.0, + "aux_num_out": None, + }, + _sample_rate=16000, +) +WAV2VEC2_XLSR53.__doc__ = """wav2vec 2.0 model with "Base" configuration. + +Trained on 56,000 hours of unlabeled audio from multiple datasets ( +*Multilingual LibriSpeech* [:footcite:`Pratap_2020`], +*CommonVoice* [:footcite:`ardila2020common`] and +*BABEL* [:footcite:`Gales2014SpeechRA`]). +Not fine-tuned. + +Originally published by the authors of +*Unsupervised Cross-lingual Representation Learning for Speech Recognition* +[:footcite:`conneau2020unsupervised`] under MIT License and redistributed with the same license. +[`License `__, +`Source `__] + +Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage. +""" # noqa: E501 + +HUBERT_BASE = Wav2Vec2Bundle( + 'hubert_fairseq_base_ls960.pth', + { + 'extractor_mode': 'group_norm', + 'extractor_conv_layer_config': [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ], + 'extractor_conv_bias': False, + 'encoder_embed_dim': 768, + 'encoder_projection_dropout': 0.1, + 'encoder_pos_conv_kernel': 128, + 'encoder_pos_conv_groups': 16, + 'encoder_num_layers': 12, + 'encoder_num_heads': 12, + 'encoder_attention_dropout': 0.1, + 'encoder_ff_interm_features': 3072, + 'encoder_ff_interm_dropout': 0.0, + 'encoder_dropout': 0.1, + 'encoder_layer_norm_first': False, + 'encoder_layer_drop': 0.05, + 'aux_num_out': None, + }, + _sample_rate=16000, +) +HUBERT_BASE.__doc__ = """HuBERT model with "Base" configuration. + +Pre-trained on 960 hours of unlabeled audio from *LibriSpeech* dataset [:footcite:`7178964`] +(the combination of "train-clean-100", "train-clean-360", and "train-other-500"). +Not fine-tuned. + +Originally published by the authors of *HuBERT* [:footcite:`hsu2021hubert`] under MIT License and +redistributed with the same license. +[`License `__, +`Source `__] + +Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage. +""" # noqa: E501 + +HUBERT_LARGE = Wav2Vec2Bundle( + 'hubert_fairseq_large_ll60k.pth', + { + 'extractor_mode': 'layer_norm', + 'extractor_conv_layer_config': [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ], + 'extractor_conv_bias': False, + 'encoder_embed_dim': 1024, + 'encoder_projection_dropout': 0.0, + 'encoder_pos_conv_kernel': 128, + 'encoder_pos_conv_groups': 16, + 'encoder_num_layers': 24, + 'encoder_num_heads': 16, + 'encoder_attention_dropout': 0.0, + 'encoder_ff_interm_features': 4096, + 'encoder_ff_interm_dropout': 0.0, + 'encoder_dropout': 0.0, + 'encoder_layer_norm_first': True, + 'encoder_layer_drop': 0.0, + 'aux_num_out': None, + }, + _sample_rate=16000, +) +HUBERT_LARGE.__doc__ = """HuBERT model with "Large" configuration. 
+ +Pre-trained on 60,000 hours of unlabeled audio from +*Libri-Light* dataset [:footcite:`librilight`]. +Not fine-tuned. + +Originally published by the authors of *HuBERT* [:footcite:`hsu2021hubert`] under MIT License and +redistributed with the same license. +[`License `__, +`Source `__] + +Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage. +""" # noqa: E501 + +HUBERT_XLARGE = Wav2Vec2Bundle( + 'hubert_fairseq_xlarge_ll60k.pth', + { + 'extractor_mode': 'layer_norm', + 'extractor_conv_layer_config': [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ], + 'extractor_conv_bias': False, + 'encoder_embed_dim': 1280, + 'encoder_projection_dropout': 0.0, + 'encoder_pos_conv_kernel': 128, + 'encoder_pos_conv_groups': 16, + 'encoder_num_layers': 48, + 'encoder_num_heads': 16, + 'encoder_attention_dropout': 0.0, + 'encoder_ff_interm_features': 5120, + 'encoder_ff_interm_dropout': 0.0, + 'encoder_dropout': 0.0, + 'encoder_layer_norm_first': True, + 'encoder_layer_drop': 0.0, + 'aux_num_out': None, + }, + _sample_rate=16000, +) +HUBERT_XLARGE.__doc__ = """HuBERT model with "Extra Large" configuration. + +Pre-trained on 60,000 hours of unlabeled audio from +*Libri-Light* dataset [:footcite:`librilight`]. +Not fine-tuned. + +Originally published by the authors of *HuBERT* [:footcite:`hsu2021hubert`] under MIT License and +redistributed with the same license. +[`License `__, +`Source `__] + +Please refer to :func:`torchaudio.pipelines.Wav2Vec2Bundle` for the usage. +""" # noqa: E501 + +HUBERT_ASR_LARGE = Wav2Vec2ASRBundle( + 'hubert_fairseq_large_ll60k_asr_ls960.pth', + { + 'extractor_mode': 'layer_norm', + 'extractor_conv_layer_config': [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ], + 'extractor_conv_bias': False, + 'encoder_embed_dim': 1024, + 'encoder_projection_dropout': 0.0, + 'encoder_pos_conv_kernel': 128, + 'encoder_pos_conv_groups': 16, + 'encoder_num_layers': 24, + 'encoder_num_heads': 16, + 'encoder_attention_dropout': 0.0, + 'encoder_ff_interm_features': 4096, + 'encoder_ff_interm_dropout': 0.1, + 'encoder_dropout': 0.0, + 'encoder_layer_norm_first': True, + 'encoder_layer_drop': 0.1, + 'aux_num_out': 32, + }, + _labels=_get_labels(), + _sample_rate=16000, +) +HUBERT_ASR_LARGE.__doc__ = """HuBERT model with "Large" configuration. + +Pre-trained on 60,000 hours of unlabeled audio from +*Libri-Light* dataset [:footcite:`librilight`], and +fine-tuned for ASR on 960 hours of transcribed audio from +*LibriSpeech* dataset [:footcite:`7178964`] +(the combination of "train-clean-100", "train-clean-360", and "train-other-500"). + +Originally published by the authors of *HuBERT* [:footcite:`hsu2021hubert`] under MIT License and +redistributed with the same license. +[`License `__, +`Source `__] + +Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage. 
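+
+Example - Accessing intermediate transformer outputs (an illustrative sketch; it assumes
+``extract_features`` accepts a ``num_layers`` argument that stops computation after the
+given number of transformer layers)
+    >>> features, _ = model.extract_features(waveform, num_layers=6)
+    >>> # ``features`` is a list with one Tensor per processed transformer layer
+    >>> len(features)
+    6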
+""" # noqa: E501 + +HUBERT_ASR_XLARGE = Wav2Vec2ASRBundle( + 'hubert_fairseq_xlarge_ll60k_asr_ls960.pth', + { + 'extractor_mode': 'layer_norm', + 'extractor_conv_layer_config': [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ], + 'extractor_conv_bias': False, + 'encoder_embed_dim': 1280, + 'encoder_projection_dropout': 0.0, + 'encoder_pos_conv_kernel': 128, + 'encoder_pos_conv_groups': 16, + 'encoder_num_layers': 48, + 'encoder_num_heads': 16, + 'encoder_attention_dropout': 0.0, + 'encoder_ff_interm_features': 5120, + 'encoder_ff_interm_dropout': 0.1, + 'encoder_dropout': 0.0, + 'encoder_layer_norm_first': True, + 'encoder_layer_drop': 0.1, + 'aux_num_out': 32, + }, + _labels=_get_labels(), + _sample_rate=16000, +) +HUBERT_ASR_XLARGE.__doc__ = """HuBERT model with "Extra Large" configuration. + +Pre-trained on 60,000 hours of unlabeled audio from +*Libri-Light* dataset [:footcite:`librilight`], and +fine-tuned for ASR on 960 hours of transcribed audio from +*LibriSpeech* dataset [:footcite:`7178964`] +(the combination of "train-clean-100", "train-clean-360", and "train-other-500"). + +Originally published by the authors of *HuBERT* [:footcite:`hsu2021hubert`] under MIT License and +redistributed with the same license. +[`License `__, +`Source `__] + +Please refer to :func:`torchaudio.pipelines.Wav2Vec2ASRBundle` for the usage. +""" # noqa: E501 diff --git a/torchaudio/prototype/__init__.py b/torchaudio/prototype/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/torchaudio/sox_effects/__init__.py b/torchaudio/sox_effects/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3c46cebdd2a6ea5cd0ffc53f0b5d15054c5658e5 --- /dev/null +++ b/torchaudio/sox_effects/__init__.py @@ -0,0 +1,22 @@ +from torchaudio._internal import module_utils as _mod_utils +from .sox_effects import ( + init_sox_effects, + shutdown_sox_effects, + effect_names, + apply_effects_tensor, + apply_effects_file, +) + + +if _mod_utils.is_sox_available(): + import atexit + init_sox_effects() + atexit.register(shutdown_sox_effects) + +__all__ = [ + 'init_sox_effects', + 'shutdown_sox_effects', + 'effect_names', + 'apply_effects_tensor', + 'apply_effects_file', +] diff --git a/torchaudio/sox_effects/sox_effects.py b/torchaudio/sox_effects/sox_effects.py new file mode 100644 index 0000000000000000000000000000000000000000..c17c3568594fae6af5bd49b31ba3d0798a327c8a --- /dev/null +++ b/torchaudio/sox_effects/sox_effects.py @@ -0,0 +1,273 @@ +import os +from typing import List, Tuple, Optional + +import torch + +import torchaudio +from torchaudio._internal import module_utils as _mod_utils +from torchaudio.utils.sox_utils import list_effects + + +@_mod_utils.requires_sox() +def init_sox_effects(): + """Initialize resources required to use sox effects. + + Note: + You do not need to call this function manually. It is called automatically. + + Once initialized, you do not need to call this function again across the multiple uses of + sox effects though it is safe to do so as long as :func:`shutdown_sox_effects` is not called yet. + Once :func:`shutdown_sox_effects` is called, you can no longer use SoX effects and initializing + again will result in error. + """ + torch.ops.torchaudio.sox_effects_initialize_sox_effects() + + +@_mod_utils.requires_sox() +def shutdown_sox_effects(): + """Clean up resources required to use sox effects. 
+ + Note: + You do not need to call this function manually. It is called automatically. + + It is safe to call this function multiple times. + Once :py:func:`shutdown_sox_effects` is called, you can no longer use SoX effects and + initializing again will result in error. + """ + torch.ops.torchaudio.sox_effects_shutdown_sox_effects() + + +@_mod_utils.requires_sox() +def effect_names() -> List[str]: + """Gets list of valid sox effect names + + Returns: + List[str]: list of available effect names. + + Example + >>> torchaudio.sox_effects.effect_names() + ['allpass', 'band', 'bandpass', ... ] + """ + return list(list_effects().keys()) + + +@_mod_utils.requires_sox() +def apply_effects_tensor( + tensor: torch.Tensor, + sample_rate: int, + effects: List[List[str]], + channels_first: bool = True, +) -> Tuple[torch.Tensor, int]: + """Apply sox effects to given Tensor + + Note: + This function only works on CPU Tensors. + This function works in the way very similar to ``sox`` command, however there are slight + differences. For example, ``sox`` command adds certain effects automatically (such as + ``rate`` effect after ``speed`` and ``pitch`` and other effects), but this function does + only applies the given effects. (Therefore, to actually apply ``speed`` effect, you also + need to give ``rate`` effect with desired sampling rate.). + + Args: + tensor (torch.Tensor): Input 2D CPU Tensor. + sample_rate (int): Sample rate + effects (List[List[str]]): List of effects. + channels_first (bool, optional): Indicates if the input Tensor's dimension is + `[channels, time]` or `[time, channels]` + + Returns: + (Tensor, int): Resulting Tensor and sample rate. + The resulting Tensor has the same ``dtype`` as the input Tensor, and + the same channels order. The shape of the Tensor can be different based on the + effects applied. Sample rate can also be different based on the effects applied. + + Example - Basic usage + >>> + >>> # Defines the effects to apply + >>> effects = [ + ... ['gain', '-n'], # normalises to 0dB + ... ['pitch', '5'], # 5 cent pitch shift + ... ['rate', '8000'], # resample to 8000 Hz + ... ] + >>> + >>> # Generate pseudo wave: + >>> # normalized, channels first, 2ch, sampling rate 16000, 1 second + >>> sample_rate = 16000 + >>> waveform = 2 * torch.rand([2, sample_rate * 1]) - 1 + >>> waveform.shape + torch.Size([2, 16000]) + >>> waveform + tensor([[ 0.3138, 0.7620, -0.9019, ..., -0.7495, -0.4935, 0.5442], + [-0.0832, 0.0061, 0.8233, ..., -0.5176, -0.9140, -0.2434]]) + >>> + >>> # Apply effects + >>> waveform, sample_rate = apply_effects_tensor( + ... wave_form, sample_rate, effects, channels_first=True) + >>> + >>> # Check the result + >>> # The new waveform is sampling rate 8000, 1 second. + >>> # normalization and channel order are preserved + >>> waveform.shape + torch.Size([2, 8000]) + >>> waveform + tensor([[ 0.5054, -0.5518, -0.4800, ..., -0.0076, 0.0096, -0.0110], + [ 0.1331, 0.0436, -0.3783, ..., -0.0035, 0.0012, 0.0008]]) + >>> sample_rate + 8000 + + Example - Torchscript-able transform + >>> + >>> # Use `apply_effects_tensor` in `torch.nn.Module` and dump it to file, + >>> # then run sox effect via Torchscript runtime. + >>> + >>> class SoxEffectTransform(torch.nn.Module): + ... effects: List[List[str]] + ... + ... def __init__(self, effects: List[List[str]]): + ... super().__init__() + ... self.effects = effects + ... + ... def forward(self, tensor: torch.Tensor, sample_rate: int): + ... return sox_effects.apply_effects_tensor( + ... tensor, sample_rate, self.effects) + ... 
+ ... + >>> # Create transform object + >>> effects = [ + ... ["lowpass", "-1", "300"], # apply single-pole lowpass filter + ... ["rate", "8000"], # change sample rate to 8000 + ... ] + >>> transform = SoxEffectTensorTransform(effects, input_sample_rate) + >>> + >>> # Dump it to file and load + >>> path = 'sox_effect.zip' + >>> torch.jit.script(trans).save(path) + >>> transform = torch.jit.load(path) + >>> + >>>> # Run transform + >>> waveform, input_sample_rate = torchaudio.load("input.wav") + >>> waveform, sample_rate = transform(waveform, input_sample_rate) + >>> assert sample_rate == 8000 + """ + return torch.ops.torchaudio.sox_effects_apply_effects_tensor( + tensor, sample_rate, effects, channels_first) + + +@_mod_utils.requires_sox() +def apply_effects_file( + path: str, + effects: List[List[str]], + normalize: bool = True, + channels_first: bool = True, + format: Optional[str] = None, +) -> Tuple[torch.Tensor, int]: + """Apply sox effects to the audio file and load the resulting data as Tensor + + Note: + This function works in the way very similar to ``sox`` command, however there are slight + differences. For example, ``sox`` commnad adds certain effects automatically (such as + ``rate`` effect after ``speed``, ``pitch`` etc), but this function only applies the given + effects. Therefore, to actually apply ``speed`` effect, you also need to give ``rate`` + effect with desired sampling rate, because internally, ``speed`` effects only alter sampling + rate and leave samples untouched. + + Args: + path (path-like object or file-like object): + Source of audio data. When the function is not compiled by TorchScript, + (e.g. ``torch.jit.script``), the following types are accepted: + + * ``path-like``: file path + * ``file-like``: Object with ``read(size: int) -> bytes`` method, + which returns byte string of at most ``size`` length. + + When the function is compiled by TorchScript, only ``str`` type is allowed. + + Note: This argument is intentionally annotated as ``str`` only for + TorchScript compiler compatibility. + effects (List[List[str]]): List of effects. + normalize (bool, optional): + When ``True``, this function always return ``float32``, and sample values are + normalized to ``[-1.0, 1.0]``. + If input file is integer WAV, giving ``False`` will change the resulting Tensor type to + integer type. This argument has no effect for formats other + than integer WAV type. + channels_first (bool, optional): When True, the returned Tensor has dimension `[channel, time]`. + Otherwise, the returned Tensor's dimension is `[time, channel]`. + format (str or None, optional): + Override the format detection with the given format. + Providing the argument might help when libsox can not infer the format + from header or extension, + + Returns: + (Tensor, int): Resulting Tensor and sample rate. + If ``normalize=True``, the resulting Tensor is always ``float32`` type. + If ``normalize=False`` and the input audio file is of integer WAV file, then the + resulting Tensor has corresponding integer type. (Note 24 bit integer type is not supported) + If ``channels_first=True``, the resulting Tensor has dimension `[channel, time]`, + otherwise `[time, channel]`. + + Example - Basic usage + >>> + >>> # Defines the effects to apply + >>> effects = [ + ... ['gain', '-n'], # normalises to 0dB + ... ['pitch', '5'], # 5 cent pitch shift + ... ['rate', '8000'], # resample to 8000 Hz + ... 
] + >>> + >>> # Apply effects and load data with channels_first=True + >>> waveform, sample_rate = apply_effects_file("data.wav", effects, channels_first=True) + >>> + >>> # Check the result + >>> waveform.shape + torch.Size([2, 8000]) + >>> waveform + tensor([[ 5.1151e-03, 1.8073e-02, 2.2188e-02, ..., 1.0431e-07, + -1.4761e-07, 1.8114e-07], + [-2.6924e-03, 2.1860e-03, 1.0650e-02, ..., 6.4122e-07, + -5.6159e-07, 4.8103e-07]]) + >>> sample_rate + 8000 + + Example - Apply random speed perturbation to dataset + >>> + >>> # Load data from file, apply random speed perturbation + >>> class RandomPerturbationFile(torch.utils.data.Dataset): + ... \"\"\"Given flist, apply random speed perturbation + ... + ... Suppose all the input files are at least one second long. + ... \"\"\" + ... def __init__(self, flist: List[str], sample_rate: int): + ... super().__init__() + ... self.flist = flist + ... self.sample_rate = sample_rate + ... + ... def __getitem__(self, index): + ... speed = 0.5 + 1.5 * random.randn() + ... effects = [ + ... ['gain', '-n', '-10'], # apply 10 db attenuation + ... ['remix', '-'], # merge all the channels + ... ['speed', f'{speed:.5f}'], # duration is now 0.5 ~ 2.0 seconds. + ... ['rate', f'{self.sample_rate}'], + ... ['pad', '0', '1.5'], # add 1.5 seconds silence at the end + ... ['trim', '0', '2'], # get the first 2 seconds + ... ] + ... waveform, _ = torchaudio.sox_effects.apply_effects_file( + ... self.flist[index], effects) + ... return waveform + ... + ... def __len__(self): + ... return len(self.flist) + ... + >>> dataset = RandomPerturbationFile(file_list, sample_rate=8000) + >>> loader = torch.utils.data.DataLoader(dataset, batch_size=32) + >>> for batch in loader: + >>> pass + """ + if not torch.jit.is_scripting(): + if hasattr(path, 'read'): + return torchaudio._torchaudio.apply_effects_fileobj( + path, effects, normalize, channels_first, format) + path = os.fspath(path) + return torch.ops.torchaudio.sox_effects_apply_effects_file( + path, effects, normalize, channels_first, format) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..eae3a50e736167620b0761fb655fad2a59d309d8 --- /dev/null +++ b/torchaudio/transforms.py @@ -0,0 +1,2050 @@ +# -*- coding: utf-8 -*- + +import math +import warnings +from typing import Callable, Optional + +import torch +from torch import Tensor +from torchaudio import functional as F + +from .functional.functional import ( + _get_sinc_resample_kernel, + _apply_sinc_resample_kernel, +) + +__all__ = [ + 'Spectrogram', + 'InverseSpectrogram', + 'GriffinLim', + 'AmplitudeToDB', + 'MelScale', + 'InverseMelScale', + 'MelSpectrogram', + 'MFCC', + 'LFCC', + 'MuLawEncoding', + 'MuLawDecoding', + 'Resample', + 'ComplexNorm', + 'TimeStretch', + 'Fade', + 'FrequencyMasking', + 'TimeMasking', + 'SlidingWindowCmn', + 'Vad', + 'SpectralCentroid', + 'Vol', + 'ComputeDeltas', + 'PitchShift', + 'RNNTLoss', + 'PSD', + 'MVDR', +] + + +class Spectrogram(torch.nn.Module): + r"""Create a spectrogram from a audio signal. + + Args: + n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``) + win_length (int or None, optional): Window size. (Default: ``n_fft``) + hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``) + pad (int, optional): Two sided padding of signal. 
(Default: ``0``) + window_fn (Callable[..., Tensor], optional): A function to create a window tensor + that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) + power (float or None, optional): Exponent for the magnitude spectrogram, + (must be > 0) e.g., 1 for energy, 2 for power, etc. + If None, then the complex spectrum is returned instead. (Default: ``2``) + normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``) + wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``) + center (bool, optional): whether to pad :attr:`waveform` on both sides so + that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. + (Default: ``True``) + pad_mode (string, optional): controls the padding method used when + :attr:`center` is ``True``. (Default: ``"reflect"``) + onesided (bool, optional): controls whether to return half of results to + avoid redundancy (Default: ``True``) + return_complex (bool, optional): + Indicates whether the resulting complex-valued Tensor should be represented with + native complex dtype, such as `torch.cfloat` and `torch.cdouble`, or real dtype + mimicking complex value with an extra dimension for real and imaginary parts. + (See also ``torch.view_as_real``.) + This argument is only effective when ``power=None``. It is ignored for + cases where ``power`` is a number as in those cases, the returned tensor is + power spectrogram, which is a real-valued tensor. + + Example + >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True) + >>> transform = torchaudio.transforms.Spectrogram(n_fft=800) + >>> spectrogram = transform(waveform) + + """ + __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized'] + + def __init__(self, + n_fft: int = 400, + win_length: Optional[int] = None, + hop_length: Optional[int] = None, + pad: int = 0, + window_fn: Callable[..., Tensor] = torch.hann_window, + power: Optional[float] = 2., + normalized: bool = False, + wkwargs: Optional[dict] = None, + center: bool = True, + pad_mode: str = "reflect", + onesided: bool = True, + return_complex: bool = True) -> None: + super(Spectrogram, self).__init__() + self.n_fft = n_fft + # number of FFT bins. the returned STFT result will have n_fft // 2 + 1 + # number of frequencies due to onesided=True in torch.stft + self.win_length = win_length if win_length is not None else n_fft + self.hop_length = hop_length if hop_length is not None else self.win_length // 2 + window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) + self.register_buffer('window', window) + self.pad = pad + self.power = power + self.normalized = normalized + self.center = center + self.pad_mode = pad_mode + self.onesided = onesided + self.return_complex = return_complex + + def forward(self, waveform: Tensor) -> Tensor: + r""" + Args: + waveform (Tensor): Tensor of audio of dimension (..., time). + + Returns: + Tensor: Dimension (..., freq, time), where freq is + ``n_fft // 2 + 1`` where ``n_fft`` is the number of + Fourier bins, and time is the number of window hops (n_frame). + """ + return F.spectrogram( + waveform, + self.pad, + self.window, + self.n_fft, + self.hop_length, + self.win_length, + self.power, + self.normalized, + self.center, + self.pad_mode, + self.onesided, + self.return_complex, + ) + + +class InverseSpectrogram(torch.nn.Module): + r"""Create an inverse spectrogram to recover an audio signal from a spectrogram. 
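+
+    It expects a complex-valued spectrogram, for example the output of :class:`Spectrogram`
+    created with ``power=None``.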
+ + Args: + n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``) + win_length (int or None, optional): Window size. (Default: ``n_fft``) + hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``) + pad (int, optional): Two sided padding of signal. (Default: ``0``) + window_fn (Callable[..., Tensor], optional): A function to create a window tensor + that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) + normalized (bool, optional): Whether the spectrogram was normalized by magnitude after stft. + (Default: ``False``) + wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``) + center (bool, optional): whether the signal in spectrogram was padded on both sides so + that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. + (Default: ``True``) + pad_mode (string, optional): controls the padding method used when + :attr:`center` is ``True``. (Default: ``"reflect"``) + onesided (bool, optional): controls whether spectrogram was used to return half of results to + avoid redundancy (Default: ``True``) + + Example + >>> batch, freq, time = 2, 257, 100 + >>> length = 25344 + >>> spectrogram = torch.randn(batch, freq, time, dtype=torch.cdouble) + >>> transform = transforms.InverseSpectrogram(n_fft=512) + >>> waveform = transform(spectrogram, length) + """ + __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized'] + + def __init__(self, + n_fft: int = 400, + win_length: Optional[int] = None, + hop_length: Optional[int] = None, + pad: int = 0, + window_fn: Callable[..., Tensor] = torch.hann_window, + normalized: bool = False, + wkwargs: Optional[dict] = None, + center: bool = True, + pad_mode: str = "reflect", + onesided: bool = True) -> None: + super(InverseSpectrogram, self).__init__() + self.n_fft = n_fft + # number of FFT bins. the returned STFT result will have n_fft // 2 + 1 + # number of frequencies due to onesided=True in torch.stft + self.win_length = win_length if win_length is not None else n_fft + self.hop_length = hop_length if hop_length is not None else self.win_length // 2 + window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) + self.register_buffer('window', window) + self.pad = pad + self.normalized = normalized + self.center = center + self.pad_mode = pad_mode + self.onesided = onesided + + def forward(self, spectrogram: Tensor, length: Optional[int] = None) -> Tensor: + r""" + Args: + spectrogram (Tensor): Complex tensor of audio of dimension (..., freq, time). + length (int or None, optional): The output length of the waveform. + + Returns: + Tensor: Dimension (..., time), Least squares estimation of the original signal. + """ + return F.inverse_spectrogram( + spectrogram, + length, + self.pad, + self.window, + self.n_fft, + self.hop_length, + self.win_length, + self.normalized, + self.center, + self.pad_mode, + self.onesided, + ) + + +class GriffinLim(torch.nn.Module): + r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation. + + Implementation ported from + *librosa* [:footcite:`brian_mcfee-proc-scipy-2015`], *A fast Griffin-Lim algorithm* [:footcite:`6701851`] + and *Signal estimation from modified short-time Fourier transform* [:footcite:`1172092`]. + + Args: + n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. 
(Default: ``400``) + n_iter (int, optional): Number of iteration for phase recovery process. (Default: ``32``) + win_length (int or None, optional): Window size. (Default: ``n_fft``) + hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``) + window_fn (Callable[..., Tensor], optional): A function to create a window tensor + that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) + power (float, optional): Exponent for the magnitude spectrogram, + (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``) + wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``) + momentum (float, optional): The momentum parameter for fast Griffin-Lim. + Setting this to 0 recovers the original Griffin-Lim method. + Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: ``0.99``) + length (int, optional): Array length of the expected output. (Default: ``None``) + rand_init (bool, optional): Initializes phase randomly if True and to zero otherwise. (Default: ``True``) + + Example + >>> batch, freq, time = 2, 257, 100 + >>> spectrogram = torch.randn(batch, freq, time) + >>> transform = transforms.GriffinLim(n_fft=512) + >>> waveform = transform(spectrogram) + """ + __constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power', + 'length', 'momentum', 'rand_init'] + + def __init__(self, + n_fft: int = 400, + n_iter: int = 32, + win_length: Optional[int] = None, + hop_length: Optional[int] = None, + window_fn: Callable[..., Tensor] = torch.hann_window, + power: float = 2., + wkwargs: Optional[dict] = None, + momentum: float = 0.99, + length: Optional[int] = None, + rand_init: bool = True) -> None: + super(GriffinLim, self).__init__() + + assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum) + assert momentum >= 0, 'momentum={} < 0'.format(momentum) + + self.n_fft = n_fft + self.n_iter = n_iter + self.win_length = win_length if win_length is not None else n_fft + self.hop_length = hop_length if hop_length is not None else self.win_length // 2 + window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) + self.register_buffer('window', window) + self.length = length + self.power = power + self.momentum = momentum / (1 + momentum) + self.rand_init = rand_init + + def forward(self, specgram: Tensor) -> Tensor: + r""" + Args: + specgram (Tensor): + A magnitude-only STFT spectrogram of dimension (..., freq, frames) + where freq is ``n_fft // 2 + 1``. + + Returns: + Tensor: waveform of (..., time), where time equals the ``length`` parameter if given. + """ + return F.griffinlim(specgram, self.window, self.n_fft, self.hop_length, self.win_length, self.power, + self.n_iter, self.momentum, self.length, self.rand_init) + + +class AmplitudeToDB(torch.nn.Module): + r"""Turn a tensor from the power/amplitude scale to the decibel scale. + + This output depends on the maximum value in the input tensor, and so + may return different values for an audio clip split into snippets vs. a + a full clip. + + Args: + stype (str, optional): scale of input tensor ('power' or 'magnitude'). The + power being the elementwise square of the magnitude. (Default: ``'power'``) + top_db (float or None, optional): minimum negative cut-off in decibels. A reasonable + number is 80. 
(Default: ``None``) + """ + __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier'] + + def __init__(self, stype: str = 'power', top_db: Optional[float] = None) -> None: + super(AmplitudeToDB, self).__init__() + self.stype = stype + if top_db is not None and top_db < 0: + raise ValueError('top_db must be positive value') + self.top_db = top_db + self.multiplier = 10.0 if stype == 'power' else 20.0 + self.amin = 1e-10 + self.ref_value = 1.0 + self.db_multiplier = math.log10(max(self.amin, self.ref_value)) + + def forward(self, x: Tensor) -> Tensor: + r"""Numerically stable implementation from Librosa. + + https://librosa.org/doc/latest/generated/librosa.amplitude_to_db.html + + Args: + x (Tensor): Input tensor before being converted to decibel scale. + + Returns: + Tensor: Output tensor in decibel scale. + """ + return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db) + + +class MelScale(torch.nn.Module): + r"""Turn a normal STFT into a mel frequency STFT, using a conversion + matrix. This uses triangular filter banks. + + User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)). + + Args: + n_mels (int, optional): Number of mel filterbanks. (Default: ``128``) + sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``) + f_min (float, optional): Minimum frequency. (Default: ``0.``) + f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``) + n_stft (int, optional): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. (Default: ``201``) + norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band + (area normalization). (Default: ``None``) + mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``) + + See also: + :py:func:`torchaudio.functional.melscale_fbanks` - The function used to + generate the filter banks. + """ + __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max'] + + def __init__(self, + n_mels: int = 128, + sample_rate: int = 16000, + f_min: float = 0., + f_max: Optional[float] = None, + n_stft: int = 201, + norm: Optional[str] = None, + mel_scale: str = "htk") -> None: + super(MelScale, self).__init__() + self.n_mels = n_mels + self.sample_rate = sample_rate + self.f_max = f_max if f_max is not None else float(sample_rate // 2) + self.f_min = f_min + self.norm = norm + self.mel_scale = mel_scale + + assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) + fb = F.melscale_fbanks( + n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, + self.mel_scale) + self.register_buffer('fb', fb) + + def forward(self, specgram: Tensor) -> Tensor: + r""" + Args: + specgram (Tensor): A spectrogram STFT of dimension (..., freq, time). + + Returns: + Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time). + """ + + # (..., time, freq) dot (freq, n_mels) -> (..., n_mels, time) + mel_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose(-1, -2) + + return mel_specgram + + +class InverseMelScale(torch.nn.Module): + r"""Solve for a normal STFT from a mel frequency STFT, using a conversion + matrix. This uses triangular filter banks. + + It minimizes the euclidian norm between the input mel-spectrogram and the product between + the estimated spectrogram and the filter banks using SGD. + + Args: + n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. 
+ n_mels (int, optional): Number of mel filterbanks. (Default: ``128``) + sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``) + f_min (float, optional): Minimum frequency. (Default: ``0.``) + f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``) + max_iter (int, optional): Maximum number of optimization iterations. (Default: ``100000``) + tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``) + tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``) + sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``) + norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band + (area normalization). (Default: ``None``) + mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``) + """ + __constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss', + 'tolerance_change', 'sgdargs'] + + def __init__(self, + n_stft: int, + n_mels: int = 128, + sample_rate: int = 16000, + f_min: float = 0., + f_max: Optional[float] = None, + max_iter: int = 100000, + tolerance_loss: float = 1e-5, + tolerance_change: float = 1e-8, + sgdargs: Optional[dict] = None, + norm: Optional[str] = None, + mel_scale: str = "htk") -> None: + super(InverseMelScale, self).__init__() + self.n_mels = n_mels + self.sample_rate = sample_rate + self.f_max = f_max or float(sample_rate // 2) + self.f_min = f_min + self.max_iter = max_iter + self.tolerance_loss = tolerance_loss + self.tolerance_change = tolerance_change + self.sgdargs = sgdargs or {'lr': 0.1, 'momentum': 0.9} + + assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max) + + fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, + norm, mel_scale) + self.register_buffer('fb', fb) + + def forward(self, melspec: Tensor) -> Tensor: + r""" + Args: + melspec (Tensor): A Mel frequency spectrogram of dimension (..., ``n_mels``, time) + + Returns: + Tensor: Linear scale spectrogram of size (..., freq, time) + """ + # pack batch + shape = melspec.size() + melspec = melspec.view(-1, shape[-2], shape[-1]) + + n_mels, time = shape[-2], shape[-1] + freq, _ = self.fb.size() # (freq, n_mels) + melspec = melspec.transpose(-1, -2) + assert self.n_mels == n_mels + + specgram = torch.rand(melspec.size()[0], time, freq, requires_grad=True, + dtype=melspec.dtype, device=melspec.device) + + optim = torch.optim.SGD([specgram], **self.sgdargs) + + loss = float('inf') + for _ in range(self.max_iter): + optim.zero_grad() + diff = melspec - specgram.matmul(self.fb) + new_loss = diff.pow(2).sum(axis=-1).mean() + # take sum over mel-frequency then average over other dimensions + # so that loss threshold is applied par unit timeframe + new_loss.backward() + optim.step() + specgram.data = specgram.data.clamp(min=0) + + new_loss = new_loss.item() + if new_loss < self.tolerance_loss or abs(loss - new_loss) < self.tolerance_change: + break + loss = new_loss + + specgram.requires_grad_(False) + specgram = specgram.clamp(min=0).transpose(-1, -2) + + # unpack batch + specgram = specgram.view(shape[:-2] + (freq, time)) + return specgram + + +class MelSpectrogram(torch.nn.Module): + r"""Create MelSpectrogram for a raw audio signal. + + This is a composition of :py:func:`torchaudio.transforms.Spectrogram` and + and :py:func:`torchaudio.transforms.MelScale`. 
+ + Sources + * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe + * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html + * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html + + Args: + sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``) + n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``) + win_length (int or None, optional): Window size. (Default: ``n_fft``) + hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``) + f_min (float, optional): Minimum frequency. (Default: ``0.``) + f_max (float or None, optional): Maximum frequency. (Default: ``None``) + pad (int, optional): Two sided padding of signal. (Default: ``0``) + n_mels (int, optional): Number of mel filterbanks. (Default: ``128``) + window_fn (Callable[..., Tensor], optional): A function to create a window tensor + that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) + power (float, optional): Exponent for the magnitude spectrogram, + (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``) + normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``) + wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``) + center (bool, optional): whether to pad :attr:`waveform` on both sides so + that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. + (Default: ``True``) + pad_mode (string, optional): controls the padding method used when + :attr:`center` is ``True``. (Default: ``"reflect"``) + onesided (bool, optional): controls whether to return half of results to + avoid redundancy. (Default: ``True``) + norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band + (area normalization). (Default: ``None``) + mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``) + + Example + >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True) + >>> transform = transforms.MelSpectrogram(sample_rate) + >>> mel_specgram = transform(waveform) # (channel, n_mels, time) + + See also: + :py:func:`torchaudio.functional.melscale_fbanks` - The function used to + generate the filter banks. 
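+
+    Example - Log-mel spectrogram (an illustrative composition with :py:class:`AmplitudeToDB`)
+        >>> transform = torch.nn.Sequential(
+        ...     transforms.MelSpectrogram(sample_rate=sample_rate),
+        ...     transforms.AmplitudeToDB(stype='power', top_db=80.0),
+        ... )
+        >>> log_mel_specgram = transform(waveform)  # (channel, n_mels, time), in dB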
+ """ + __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min'] + + def __init__(self, + sample_rate: int = 16000, + n_fft: int = 400, + win_length: Optional[int] = None, + hop_length: Optional[int] = None, + f_min: float = 0., + f_max: Optional[float] = None, + pad: int = 0, + n_mels: int = 128, + window_fn: Callable[..., Tensor] = torch.hann_window, + power: float = 2., + normalized: bool = False, + wkwargs: Optional[dict] = None, + center: bool = True, + pad_mode: str = "reflect", + onesided: bool = True, + norm: Optional[str] = None, + mel_scale: str = "htk") -> None: + super(MelSpectrogram, self).__init__() + self.sample_rate = sample_rate + self.n_fft = n_fft + self.win_length = win_length if win_length is not None else n_fft + self.hop_length = hop_length if hop_length is not None else self.win_length // 2 + self.pad = pad + self.power = power + self.normalized = normalized + self.n_mels = n_mels # number of mel frequency bins + self.f_max = f_max + self.f_min = f_min + self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length, + hop_length=self.hop_length, + pad=self.pad, window_fn=window_fn, power=self.power, + normalized=self.normalized, wkwargs=wkwargs, + center=center, pad_mode=pad_mode, onesided=onesided) + self.mel_scale = MelScale( + self.n_mels, + self.sample_rate, + self.f_min, + self.f_max, + self.n_fft // 2 + 1, + norm, + mel_scale + ) + + def forward(self, waveform: Tensor) -> Tensor: + r""" + Args: + waveform (Tensor): Tensor of audio of dimension (..., time). + + Returns: + Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time). + """ + specgram = self.spectrogram(waveform) + mel_specgram = self.mel_scale(specgram) + return mel_specgram + + +class MFCC(torch.nn.Module): + r"""Create the Mel-frequency cepstrum coefficients from an audio signal. + + By default, this calculates the MFCC on the DB-scaled Mel spectrogram. + This is not the textbook implementation, but is implemented here to + give consistency with librosa. + + This output depends on the maximum value in the input spectrogram, and so + may return different values for an audio clip split into snippets vs. a + a full clip. + + Args: + sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``) + n_mfcc (int, optional): Number of mfc coefficients to retain. (Default: ``40``) + dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``) + norm (str, optional): norm to use. (Default: ``'ortho'``) + log_mels (bool, optional): whether to use log-mel spectrograms instead of db-scaled. (Default: ``False``) + melkwargs (dict or None, optional): arguments for MelSpectrogram. (Default: ``None``) + + See also: + :py:func:`torchaudio.functional.melscale_fbanks` - The function used to + generate the filter banks. 
+ """ + __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels'] + + def __init__(self, + sample_rate: int = 16000, + n_mfcc: int = 40, + dct_type: int = 2, + norm: str = 'ortho', + log_mels: bool = False, + melkwargs: Optional[dict] = None) -> None: + super(MFCC, self).__init__() + supported_dct_types = [2] + if dct_type not in supported_dct_types: + raise ValueError('DCT type not supported: {}'.format(dct_type)) + self.sample_rate = sample_rate + self.n_mfcc = n_mfcc + self.dct_type = dct_type + self.norm = norm + self.top_db = 80.0 + self.amplitude_to_DB = AmplitudeToDB('power', self.top_db) + + melkwargs = melkwargs or {} + self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs) + + if self.n_mfcc > self.MelSpectrogram.n_mels: + raise ValueError('Cannot select more MFCC coefficients than # mel bins') + dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm) + self.register_buffer('dct_mat', dct_mat) + self.log_mels = log_mels + + def forward(self, waveform: Tensor) -> Tensor: + r""" + Args: + waveform (Tensor): Tensor of audio of dimension (..., time). + + Returns: + Tensor: specgram_mel_db of size (..., ``n_mfcc``, time). + """ + mel_specgram = self.MelSpectrogram(waveform) + if self.log_mels: + log_offset = 1e-6 + mel_specgram = torch.log(mel_specgram + log_offset) + else: + mel_specgram = self.amplitude_to_DB(mel_specgram) + + # (..., time, n_mels) dot (n_mels, n_mfcc) -> (..., n_nfcc, time) + mfcc = torch.matmul(mel_specgram.transpose(-1, -2), self.dct_mat).transpose(-1, -2) + return mfcc + + +class LFCC(torch.nn.Module): + r"""Create the linear-frequency cepstrum coefficients from an audio signal. + + By default, this calculates the LFCC on the DB-scaled linear filtered spectrogram. + This is not the textbook implementation, but is implemented here to + give consistency with librosa. + + This output depends on the maximum value in the input spectrogram, and so + may return different values for an audio clip split into snippets vs. a + a full clip. + + Args: + sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``) + n_filter (int, optional): Number of linear filters to apply. (Default: ``128``) + n_lfcc (int, optional): Number of lfc coefficients to retain. (Default: ``40``) + f_min (float, optional): Minimum frequency. (Default: ``0.``) + f_max (float or None, optional): Maximum frequency. (Default: ``None``) + dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``) + norm (str, optional): norm to use. (Default: ``'ortho'``) + log_lf (bool, optional): whether to use log-lf spectrograms instead of db-scaled. (Default: ``False``) + speckwargs (dict or None, optional): arguments for Spectrogram. (Default: ``None``) + + + See also: + :py:func:`torchaudio.functional.linear_fbanks` - The function used to + generate the filter banks. 
+ """ + __constants__ = ['sample_rate', 'n_filter', 'n_lfcc', 'dct_type', 'top_db', 'log_lf'] + + def __init__(self, + sample_rate: int = 16000, + n_filter: int = 128, + f_min: float = 0., + f_max: Optional[float] = None, + n_lfcc: int = 40, + dct_type: int = 2, + norm: str = 'ortho', + log_lf: bool = False, + speckwargs: Optional[dict] = None) -> None: + super(LFCC, self).__init__() + supported_dct_types = [2] + if dct_type not in supported_dct_types: + raise ValueError('DCT type not supported: {}'.format(dct_type)) + self.sample_rate = sample_rate + self.f_min = f_min + self.f_max = f_max if f_max is not None else float(sample_rate // 2) + self.n_filter = n_filter + self.n_lfcc = n_lfcc + self.dct_type = dct_type + self.norm = norm + self.top_db = 80.0 + self.amplitude_to_DB = AmplitudeToDB('power', self.top_db) + + speckwargs = speckwargs or {} + self.Spectrogram = Spectrogram(**speckwargs) + + if self.n_lfcc > self.Spectrogram.n_fft: + raise ValueError('Cannot select more LFCC coefficients than # fft bins') + + filter_mat = F.linear_fbanks( + n_freqs=self.Spectrogram.n_fft // 2 + 1, + f_min=self.f_min, + f_max=self.f_max, + n_filter=self.n_filter, + sample_rate=self.sample_rate, + ) + self.register_buffer("filter_mat", filter_mat) + + dct_mat = F.create_dct(self.n_lfcc, self.n_filter, self.norm) + self.register_buffer('dct_mat', dct_mat) + self.log_lf = log_lf + + def forward(self, waveform: Tensor) -> Tensor: + r""" + Args: + waveform (Tensor): Tensor of audio of dimension (..., time). + + Returns: + Tensor: Linear Frequency Cepstral Coefficients of size (..., ``n_lfcc``, time). + """ + specgram = self.Spectrogram(waveform) + + # (..., time, freq) dot (freq, n_filter) -> (..., n_filter, time) + specgram = torch.matmul(specgram.transpose(-1, -2), self.filter_mat).transpose(-1, -2) + + if self.log_lf: + log_offset = 1e-6 + specgram = torch.log(specgram + log_offset) + else: + specgram = self.amplitude_to_DB(specgram) + + # (..., time, n_filter) dot (n_filter, n_lfcc) -> (..., n_lfcc, time) + lfcc = torch.matmul(specgram.transpose(-1, -2), self.dct_mat).transpose(-1, -2) + return lfcc + + +class MuLawEncoding(torch.nn.Module): + r"""Encode signal based on mu-law companding. For more info see the + `Wikipedia Entry `_ + + This algorithm assumes the signal has been scaled to between -1 and 1 and + returns a signal encoded with values from 0 to quantization_channels - 1 + + Args: + quantization_channels (int, optional): Number of channels. (Default: ``256``) + + Example + >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True) + >>> transform = torchaudio.transforms.MuLawEncoding(quantization_channels=512) + >>> mulawtrans = transform(waveform) + + """ + __constants__ = ['quantization_channels'] + + def __init__(self, quantization_channels: int = 256) -> None: + super(MuLawEncoding, self).__init__() + self.quantization_channels = quantization_channels + + def forward(self, x: Tensor) -> Tensor: + r""" + Args: + x (Tensor): A signal to be encoded. + + Returns: + Tensor: An encoded signal. + """ + return F.mu_law_encoding(x, self.quantization_channels) + + +class MuLawDecoding(torch.nn.Module): + r"""Decode mu-law encoded signal. For more info see the + `Wikipedia Entry `_ + + This expects an input with values between 0 and quantization_channels - 1 + and returns a signal scaled between -1 and 1. + + Args: + quantization_channels (int, optional): Number of channels. 
(Default: ``256``)
+
+    Example
+        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
+        >>> transform = torchaudio.transforms.MuLawDecoding(quantization_channels=512)
+        >>> mulawtrans = transform(waveform)
+    """
+    __constants__ = ['quantization_channels']
+
+    def __init__(self, quantization_channels: int = 256) -> None:
+        super(MuLawDecoding, self).__init__()
+        self.quantization_channels = quantization_channels
+
+    def forward(self, x_mu: Tensor) -> Tensor:
+        r"""
+        Args:
+            x_mu (Tensor): A mu-law encoded signal which needs to be decoded.
+
+        Returns:
+            Tensor: The signal decoded.
+        """
+        return F.mu_law_decoding(x_mu, self.quantization_channels)
+
+
+class Resample(torch.nn.Module):
+    r"""Resample a signal from one frequency to another. A resampling method can be given.
+
+    Note:
+        If resampling on waveforms of higher precision than float32, there may be a small loss of precision
+        because the kernel is cached once as float32. If high precision resampling is important for your application,
+        the functional form will retain higher precision, but run slower because it does not cache the kernel.
+        Alternatively, you could rewrite a transform that caches a higher precision kernel.
+
+    Args:
+        orig_freq (int, optional): The original frequency of the signal. (Default: ``16000``)
+        new_freq (int, optional): The desired frequency. (Default: ``16000``)
+        resampling_method (str, optional): The resampling method to use.
+            Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
+        lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
+            but less efficient. (Default: ``6``)
+        rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
+            Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
+        beta (float or None, optional): The shape parameter used for kaiser window.
+        dtype (torch.dtype, optional):
+            Determines the precision at which the resampling kernel is pre-computed and cached. If not provided,
+            the kernel is computed with ``torch.float64`` then cached as ``torch.float32``.
+            If you need higher precision, provide ``torch.float64``, and the pre-computed kernel is computed and
+            cached as ``torch.float64``. If you use resampling with lower precision, then instead of providing
+            this argument, please use ``Resample.to(dtype)``, so that the kernel generation is still
+            carried out on ``torch.float64``.
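+
+            An illustrative sketch of the two options above (not an exhaustive recipe):
+
+            >>> # cache the kernel directly in float64 for high-precision inputs
+            >>> resampler = transforms.Resample(48000, 16000, dtype=torch.float64)
+            >>> # for lower-precision inputs, construct with the default dtype and cast
+            >>> # afterwards, so the kernel is still generated in float64 first
+            >>> resampler = transforms.Resample(48000, 16000).to(torch.float16)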
+ + Example + >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True) + >>> transform = transforms.Resample(sample_rate, sample_rate/10) + >>> waveform = transform(waveform) + """ + + def __init__( + self, + orig_freq: int = 16000, + new_freq: int = 16000, + resampling_method: str = 'sinc_interpolation', + lowpass_filter_width: int = 6, + rolloff: float = 0.99, + beta: Optional[float] = None, + *, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__() + + self.orig_freq = orig_freq + self.new_freq = new_freq + self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq)) + self.resampling_method = resampling_method + self.lowpass_filter_width = lowpass_filter_width + self.rolloff = rolloff + self.beta = beta + + if self.orig_freq != self.new_freq: + kernel, self.width = _get_sinc_resample_kernel( + self.orig_freq, self.new_freq, self.gcd, + self.lowpass_filter_width, self.rolloff, + self.resampling_method, beta, dtype=dtype) + self.register_buffer('kernel', kernel) + + def forward(self, waveform: Tensor) -> Tensor: + r""" + Args: + waveform (Tensor): Tensor of audio of dimension (..., time). + + Returns: + Tensor: Output signal of dimension (..., time). + """ + if self.orig_freq == self.new_freq: + return waveform + return _apply_sinc_resample_kernel( + waveform, self.orig_freq, self.new_freq, self.gcd, + self.kernel, self.width) + + +class ComplexNorm(torch.nn.Module): + r"""Compute the norm of complex tensor input. + + Args: + power (float, optional): Power of the norm. (Default: to ``1.0``) + + Example + >>> complex_tensor = ... # Tensor shape of (…, complex=2) + >>> transform = transforms.ComplexNorm(power=2) + >>> complex_norm = transform(complex_tensor) + """ + __constants__ = ['power'] + + def __init__(self, power: float = 1.0) -> None: + warnings.warn( + 'torchaudio.transforms.ComplexNorm has been deprecated ' + 'and will be removed from future release.' + 'Please convert the input Tensor to complex type with `torch.view_as_complex` then ' + 'use `torch.abs` and `torch.angle`. ' + 'Please refer to https://github.com/pytorch/audio/issues/1337 ' + "for more details about torchaudio's plan to migrate to native complex type." + ) + super(ComplexNorm, self).__init__() + self.power = power + + def forward(self, complex_tensor: Tensor) -> Tensor: + r""" + Args: + complex_tensor (Tensor): Tensor shape of `(..., complex=2)`. + + Returns: + Tensor: norm of the input tensor, shape of `(..., )`. + """ + return F.complex_norm(complex_tensor, self.power) + + +class ComputeDeltas(torch.nn.Module): + r"""Compute delta coefficients of a tensor, usually a spectrogram. + + See `torchaudio.functional.compute_deltas` for more details. + + Args: + win_length (int, optional): The window length used for computing delta. (Default: ``5``) + mode (str, optional): Mode parameter passed to padding. (Default: ``'replicate'``) + """ + __constants__ = ['win_length'] + + def __init__(self, win_length: int = 5, mode: str = "replicate") -> None: + super(ComputeDeltas, self).__init__() + self.win_length = win_length + self.mode = mode + + def forward(self, specgram: Tensor) -> Tensor: + r""" + Args: + specgram (Tensor): Tensor of audio of dimension (..., freq, time). + + Returns: + Tensor: Tensor of deltas of dimension (..., freq, time). + """ + return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode) + + +class TimeStretch(torch.nn.Module): + r"""Stretch stft in time without modifying pitch for a given rate. 
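+
+    For example (an illustrative calculation): with ``fixed_rate=1.25`` (a speed-up),
+    a spectrogram with 100 frames comes out with ``ceil(100 / 1.25) = 80`` frames,
+    while the frequency dimension is left unchanged.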
+ + Proposed in *SpecAugment* [:footcite:`specaugment`]. + + Args: + hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``) + n_freq (int, optional): number of filter banks from stft. (Default: ``201``) + fixed_rate (float or None, optional): rate to speed up or slow down by. + If None is provided, rate must be passed to the forward method. (Default: ``None``) + + Example + >>> spectrogram = torchaudio.transforms.Spectrogram() + >>> stretch = torchaudio.transforms.TimeStretch() + >>> + >>> original = spectrogram(waveform) + >>> streched_1_2 = stretch(original, 1.2) + >>> streched_0_9 = stretch(original, 0.9) + + .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_1.png + :width: 600 + :alt: Spectrogram streched by 1.2 + + .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_2.png + :width: 600 + :alt: The original spectrogram + + .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_3.png + :width: 600 + :alt: Spectrogram streched by 0.9 + + """ + __constants__ = ['fixed_rate'] + + def __init__(self, + hop_length: Optional[int] = None, + n_freq: int = 201, + fixed_rate: Optional[float] = None) -> None: + super(TimeStretch, self).__init__() + + self.fixed_rate = fixed_rate + + n_fft = (n_freq - 1) * 2 + hop_length = hop_length if hop_length is not None else n_fft // 2 + self.register_buffer('phase_advance', torch.linspace(0, math.pi * hop_length, n_freq)[..., None]) + + def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = None) -> Tensor: + r""" + Args: + complex_specgrams (Tensor): + Either a real tensor of dimension of `(..., freq, num_frame, complex=2)` + or a tensor of dimension `(..., freq, num_frame)` with complex dtype. + overriding_rate (float or None, optional): speed up to apply to this batch. + If no rate is passed, use ``self.fixed_rate``. (Default: ``None``) + + Returns: + Tensor: + Stretched spectrogram. The resulting tensor is of the same dtype as the input + spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``. + """ + if overriding_rate is None: + if self.fixed_rate is None: + raise ValueError( + "If no fixed_rate is specified, must pass a valid rate to the forward method.") + rate = self.fixed_rate + else: + rate = overriding_rate + return F.phase_vocoder(complex_specgrams, rate, self.phase_advance) + + +class Fade(torch.nn.Module): + r"""Add a fade in and/or fade out to an waveform. + + Args: + fade_in_len (int, optional): Length of fade-in (time frames). (Default: ``0``) + fade_out_len (int, optional): Length of fade-out (time frames). (Default: ``0``) + fade_shape (str, optional): Shape of fade. Must be one of: "quarter_sine", + "half_sine", "linear", "logarithmic", "exponential". (Default: ``"linear"``) + + Example + >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True) + >>> transform = transforms.Fade(fade_in_len=sample_rate, fade_out_len=2 * sample_rate, fade_shape='linear') + >>> faded_waveform = transform(waveform) + """ + + def __init__(self, + fade_in_len: int = 0, + fade_out_len: int = 0, + fade_shape: str = "linear") -> None: + super(Fade, self).__init__() + self.fade_in_len = fade_in_len + self.fade_out_len = fade_out_len + self.fade_shape = fade_shape + + def forward(self, waveform: Tensor) -> Tensor: + r""" + Args: + waveform (Tensor): Tensor of audio of dimension `(..., time)`. 
+ + Returns: + Tensor: Tensor of audio of dimension `(..., time)`. + """ + waveform_length = waveform.size()[-1] + device = waveform.device + return ( + self._fade_in(waveform_length, device) + * self._fade_out(waveform_length, device) + * waveform + ) + + def _fade_in(self, waveform_length: int, device: torch.device) -> Tensor: + fade = torch.linspace(0, 1, self.fade_in_len, device=device) + ones = torch.ones(waveform_length - self.fade_in_len, device=device) + + if self.fade_shape == "linear": + fade = fade + + if self.fade_shape == "exponential": + fade = torch.pow(2, (fade - 1)) * fade + + if self.fade_shape == "logarithmic": + fade = torch.log10(.1 + fade) + 1 + + if self.fade_shape == "quarter_sine": + fade = torch.sin(fade * math.pi / 2) + + if self.fade_shape == "half_sine": + fade = torch.sin(fade * math.pi - math.pi / 2) / 2 + 0.5 + + return torch.cat((fade, ones)).clamp_(0, 1) + + def _fade_out(self, waveform_length: int, device: torch.device) -> Tensor: + fade = torch.linspace(0, 1, self.fade_out_len, device=device) + ones = torch.ones(waveform_length - self.fade_out_len, device=device) + + if self.fade_shape == "linear": + fade = - fade + 1 + + if self.fade_shape == "exponential": + fade = torch.pow(2, - fade) * (1 - fade) + + if self.fade_shape == "logarithmic": + fade = torch.log10(1.1 - fade) + 1 + + if self.fade_shape == "quarter_sine": + fade = torch.sin(fade * math.pi / 2 + math.pi / 2) + + if self.fade_shape == "half_sine": + fade = torch.sin(fade * math.pi + math.pi / 2) / 2 + 0.5 + + return torch.cat((ones, fade)).clamp_(0, 1) + + +class _AxisMasking(torch.nn.Module): + r"""Apply masking to a spectrogram. + + Args: + mask_param (int): Maximum possible length of the mask. + axis (int): What dimension the mask is applied on. + iid_masks (bool): Applies iid masks to each of the examples in the batch dimension. + This option is applicable only when the input tensor is 4D. + """ + __constants__ = ['mask_param', 'axis', 'iid_masks'] + + def __init__(self, mask_param: int, axis: int, iid_masks: bool) -> None: + + super(_AxisMasking, self).__init__() + self.mask_param = mask_param + self.axis = axis + self.iid_masks = iid_masks + + def forward(self, specgram: Tensor, mask_value: float = 0.) -> Tensor: + r""" + Args: + specgram (Tensor): Tensor of dimension `(..., freq, time)`. + mask_value (float): Value to assign to the masked columns. + + Returns: + Tensor: Masked spectrogram of dimensions `(..., freq, time)`. + """ + # if iid_masks flag marked and specgram has a batch dimension + if self.iid_masks and specgram.dim() == 4: + return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1) + else: + return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis) + + +class FrequencyMasking(_AxisMasking): + r"""Apply masking to a spectrogram in the frequency domain. + + Proposed in *SpecAugment* [:footcite:`specaugment`]. + + Args: + freq_mask_param (int): maximum possible length of the mask. + Indices uniformly sampled from [0, freq_mask_param). + iid_masks (bool, optional): whether to apply different masks to each + example/channel in the batch. (Default: ``False``) + This option is applicable only when the input tensor is 4D. + + Example + >>> spectrogram = torchaudio.transforms.Spectrogram() + >>> masking = torchaudio.transforms.FrequencyMasking(freq_mask_param=80) + >>> + >>> original = spectrogram(waveform) + >>> masked = masking(original) + + .. 
image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_freq_masking1.png + :alt: The original spectrogram + + .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_freq_masking2.png + :alt: The spectrogram masked along frequency axis + """ + + def __init__(self, freq_mask_param: int, iid_masks: bool = False) -> None: + super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks) + + +class TimeMasking(_AxisMasking): + r"""Apply masking to a spectrogram in the time domain. + + Proposed in *SpecAugment* [:footcite:`specaugment`]. + + Args: + time_mask_param (int): maximum possible length of the mask. + Indices uniformly sampled from [0, time_mask_param). + iid_masks (bool, optional): whether to apply different masks to each + example/channel in the batch. (Default: ``False``) + This option is applicable only when the input tensor is 4D. + + Example + >>> spectrogram = torchaudio.transforms.Spectrogram() + >>> masking = torchaudio.transforms.TimeMasking(time_mask_param=80) + >>> + >>> original = spectrogram(waveform) + >>> masked = masking(original) + + .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_masking1.png + :alt: The original spectrogram + + .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_masking2.png + :alt: The spectrogram masked along time axis + """ + + def __init__(self, time_mask_param: int, iid_masks: bool = False) -> None: + super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks) + + +class Vol(torch.nn.Module): + r"""Add a volume to an waveform. + + Args: + gain (float): Interpreted according to the given gain_type: + If ``gain_type`` = ``amplitude``, ``gain`` is a positive amplitude ratio. + If ``gain_type`` = ``power``, ``gain`` is a power (voltage squared). + If ``gain_type`` = ``db``, ``gain`` is in decibels. + gain_type (str, optional): Type of gain. One of: ``amplitude``, ``power``, ``db`` (Default: ``amplitude``) + """ + + def __init__(self, gain: float, gain_type: str = 'amplitude'): + super(Vol, self).__init__() + self.gain = gain + self.gain_type = gain_type + + if gain_type in ['amplitude', 'power'] and gain < 0: + raise ValueError("If gain_type = amplitude or power, gain must be positive.") + + def forward(self, waveform: Tensor) -> Tensor: + r""" + Args: + waveform (Tensor): Tensor of audio of dimension `(..., time)`. + + Returns: + Tensor: Tensor of audio of dimension `(..., time)`. + """ + if self.gain_type == "amplitude": + waveform = waveform * self.gain + + if self.gain_type == "db": + waveform = F.gain(waveform, self.gain) + + if self.gain_type == "power": + waveform = F.gain(waveform, 10 * math.log10(self.gain)) + + return torch.clamp(waveform, -1, 1) + + +class SlidingWindowCmn(torch.nn.Module): + r""" + Apply sliding-window cepstral mean (and optionally variance) normalization per utterance. + + Args: + cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600) + min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start). + Only applicable if center == false, ignored if center==true (int, default = 100) + center (bool, optional): If true, use a window centered on the current frame + (to the extent possible, modulo end effects). If false, window is to the left. (bool, default = false) + norm_vars (bool, optional): If true, normalize variance to one. 
(bool, default = false) + """ + + def __init__(self, + cmn_window: int = 600, + min_cmn_window: int = 100, + center: bool = False, + norm_vars: bool = False) -> None: + super().__init__() + self.cmn_window = cmn_window + self.min_cmn_window = min_cmn_window + self.center = center + self.norm_vars = norm_vars + + def forward(self, specgram: Tensor) -> Tensor: + r""" + Args: + specgram (Tensor): Tensor of spectrogram of dimension `(..., time, freq)`. + + Returns: + Tensor: Tensor of spectrogram of dimension `(..., time, freq)`. + """ + cmn_specgram = F.sliding_window_cmn( + specgram, self.cmn_window, self.min_cmn_window, self.center, self.norm_vars) + return cmn_specgram + + +class Vad(torch.nn.Module): + r"""Voice Activity Detector. Similar to SoX implementation. + Attempts to trim silence and quiet background sounds from the ends of recordings of speech. + The algorithm currently uses a simple cepstral power measurement to detect voice, + so may be fooled by other things, especially music. + + The effect can trim only from the front of the audio, + so in order to trim from the back, the reverse effect must also be used. + + Args: + sample_rate (int): Sample rate of audio signal. + trigger_level (float, optional): The measurement level used to trigger activity detection. + This may need to be cahnged depending on the noise level, signal level, + and other characteristics of the input audio. (Default: 7.0) + trigger_time (float, optional): The time constant (in seconds) + used to help ignore short bursts of sound. (Default: 0.25) + search_time (float, optional): The amount of audio (in seconds) + to search for quieter/shorter bursts of audio to include prior + to the detected trigger point. (Default: 1.0) + allowed_gap (float, optional): The allowed gap (in seconds) between + quiteter/shorter bursts of audio to include prior + to the detected trigger point. (Default: 0.25) + pre_trigger_time (float, optional): The amount of audio (in seconds) to preserve + before the trigger point and any found quieter/shorter bursts. (Default: 0.0) + boot_time (float, optional) The algorithm (internally) uses adaptive noise + estimation/reduction in order to detect the start of the wanted audio. + This option sets the time for the initial noise estimate. (Default: 0.35) + noise_up_time (float, optional) Time constant used by the adaptive noise estimator + for when the noise level is increasing. (Default: 0.1) + noise_down_time (float, optional) Time constant used by the adaptive noise estimator + for when the noise level is decreasing. (Default: 0.01) + noise_reduction_amount (float, optional) Amount of noise reduction to use in + the detection algorithm (e.g. 0, 0.5, ...). (Default: 1.35) + measure_freq (float, optional) Frequency of the algorithm’s + processing/measurements. (Default: 20.0) + measure_duration: (float or None, optional) Measurement duration. + (Default: Twice the measurement period; i.e. with overlap.) + measure_smooth_time (float, optional) Time constant used to smooth + spectral measurements. (Default: 0.4) + hp_filter_freq (float, optional) "Brick-wall" frequency of high-pass filter applied + at the input to the detector algorithm. (Default: 50.0) + lp_filter_freq (float, optional) "Brick-wall" frequency of low-pass filter applied + at the input to the detector algorithm. (Default: 6000.0) + hp_lifter_freq (float, optional) "Brick-wall" frequency of high-pass lifter used + in the detector algorithm. 
(Default: 150.0) + lp_lifter_freq (float, optional) "Brick-wall" frequency of low-pass lifter used + in the detector algorithm. (Default: 2000.0) + + Reference: + - http://sox.sourceforge.net/sox.html + """ + + def __init__(self, + sample_rate: int, + trigger_level: float = 7.0, + trigger_time: float = 0.25, + search_time: float = 1.0, + allowed_gap: float = 0.25, + pre_trigger_time: float = 0.0, + boot_time: float = .35, + noise_up_time: float = .1, + noise_down_time: float = .01, + noise_reduction_amount: float = 1.35, + measure_freq: float = 20.0, + measure_duration: Optional[float] = None, + measure_smooth_time: float = .4, + hp_filter_freq: float = 50., + lp_filter_freq: float = 6000., + hp_lifter_freq: float = 150., + lp_lifter_freq: float = 2000.) -> None: + super().__init__() + + self.sample_rate = sample_rate + self.trigger_level = trigger_level + self.trigger_time = trigger_time + self.search_time = search_time + self.allowed_gap = allowed_gap + self.pre_trigger_time = pre_trigger_time + self.boot_time = boot_time + self.noise_up_time = noise_up_time + self.noise_down_time = noise_down_time + self.noise_reduction_amount = noise_reduction_amount + self.measure_freq = measure_freq + self.measure_duration = measure_duration + self.measure_smooth_time = measure_smooth_time + self.hp_filter_freq = hp_filter_freq + self.lp_filter_freq = lp_filter_freq + self.hp_lifter_freq = hp_lifter_freq + self.lp_lifter_freq = lp_lifter_freq + + def forward(self, waveform: Tensor) -> Tensor: + r""" + Args: + waveform (Tensor): Tensor of audio of dimension `(channels, time)` or `(time)` + Tensor of shape `(channels, time)` is treated as a multi-channel recording + of the same event and the resulting output will be trimmed to the earliest + voice activity in any channel. + """ + return F.vad( + waveform=waveform, + sample_rate=self.sample_rate, + trigger_level=self.trigger_level, + trigger_time=self.trigger_time, + search_time=self.search_time, + allowed_gap=self.allowed_gap, + pre_trigger_time=self.pre_trigger_time, + boot_time=self.boot_time, + noise_up_time=self.noise_up_time, + noise_down_time=self.noise_down_time, + noise_reduction_amount=self.noise_reduction_amount, + measure_freq=self.measure_freq, + measure_duration=self.measure_duration, + measure_smooth_time=self.measure_smooth_time, + hp_filter_freq=self.hp_filter_freq, + lp_filter_freq=self.lp_filter_freq, + hp_lifter_freq=self.hp_lifter_freq, + lp_lifter_freq=self.lp_lifter_freq, + ) + + +class SpectralCentroid(torch.nn.Module): + r"""Compute the spectral centroid for each channel along the time axis. + + The spectral centroid is defined as the weighted average of the + frequency values, weighted by their magnitude. + + Args: + sample_rate (int): Sample rate of audio signal. + n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``) + win_length (int or None, optional): Window size. (Default: ``n_fft``) + hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``) + pad (int, optional): Two sided padding of signal. (Default: ``0``) + window_fn (Callable[..., Tensor], optional): A function to create a window tensor + that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) + wkwargs (dict or None, optional): Arguments for window function. 
(Default: ``None``) + + Example + >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True) + >>> transform = transforms.SpectralCentroid(sample_rate) + >>> spectral_centroid = transform(waveform) # (channel, time) + """ + __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad'] + + def __init__(self, + sample_rate: int, + n_fft: int = 400, + win_length: Optional[int] = None, + hop_length: Optional[int] = None, + pad: int = 0, + window_fn: Callable[..., Tensor] = torch.hann_window, + wkwargs: Optional[dict] = None) -> None: + super(SpectralCentroid, self).__init__() + self.sample_rate = sample_rate + self.n_fft = n_fft + self.win_length = win_length if win_length is not None else n_fft + self.hop_length = hop_length if hop_length is not None else self.win_length // 2 + window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) + self.register_buffer('window', window) + self.pad = pad + + def forward(self, waveform: Tensor) -> Tensor: + r""" + Args: + waveform (Tensor): Tensor of audio of dimension `(..., time)`. + + Returns: + Tensor: Spectral Centroid of size `(..., time)`. + """ + + return F.spectral_centroid(waveform, self.sample_rate, self.pad, self.window, self.n_fft, self.hop_length, + self.win_length) + + +class PitchShift(torch.nn.Module): + r"""Shift the pitch of a waveform by ``n_steps`` steps. + + Args: + waveform (Tensor): The input waveform of shape `(..., time)`. + sample_rate (int): Sample rate of `waveform`. + n_steps (int): The (fractional) steps to shift `waveform`. + bins_per_octave (int, optional): The number of steps per octave (Default : ``12``). + n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins (Default: ``512``). + win_length (int or None, optional): Window size. If None, then ``n_fft`` is used. (Default: ``None``). + hop_length (int or None, optional): Length of hop between STFT windows. If None, then ``win_length // 4`` + is used (Default: ``None``). + window (Tensor or None, optional): Window tensor that is applied/multiplied to each frame/window. + If None, then ``torch.hann_window(win_length)`` is used (Default: ``None``). + + Example + >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True) + >>> transform = transforms.PitchShift(sample_rate, 4) + >>> waveform_shift = transform(waveform) # (channel, time) + """ + __constants__ = ['sample_rate', 'n_steps', 'bins_per_octave', 'n_fft', 'win_length', 'hop_length'] + + def __init__(self, + sample_rate: int, + n_steps: int, + bins_per_octave: int = 12, + n_fft: int = 512, + win_length: Optional[int] = None, + hop_length: Optional[int] = None, + window_fn: Callable[..., Tensor] = torch.hann_window, + wkwargs: Optional[dict] = None) -> None: + super(PitchShift, self).__init__() + self.n_steps = n_steps + self.bins_per_octave = bins_per_octave + self.sample_rate = sample_rate + self.n_fft = n_fft + self.win_length = win_length if win_length is not None else n_fft + self.hop_length = hop_length if hop_length is not None else self.win_length // 4 + window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs) + self.register_buffer('window', window) + + def forward(self, waveform: Tensor) -> Tensor: + r""" + Args: + waveform (Tensor): Tensor of audio of dimension `(..., time)`. + + Returns: + Tensor: The pitch-shifted audio of shape `(..., time)`. 
+ """ + + return F.pitch_shift(waveform, self.sample_rate, self.n_steps, self.bins_per_octave, self.n_fft, + self.win_length, self.hop_length, self.window) + + +class RNNTLoss(torch.nn.Module): + """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks* + [:footcite:`graves2012sequence`]. + The RNN Transducer loss extends the CTC loss by defining a distribution over output + sequences of all lengths, and by jointly modelling both input-output and output-output + dependencies. + + Args: + blank (int, optional): blank label (Default: ``-1``) + clamp (float, optional): clamp for gradients (Default: ``-1``) + reduction (string, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``) + + Example + >>> # Hypothetical values + >>> logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1], + >>> [0.1, 0.1, 0.6, 0.1, 0.1], + >>> [0.1, 0.1, 0.2, 0.8, 0.1]], + >>> [[0.1, 0.6, 0.1, 0.1, 0.1], + >>> [0.1, 0.1, 0.2, 0.1, 0.1], + >>> [0.7, 0.1, 0.2, 0.1, 0.1]]]], + >>> dtype=torch.float32, + >>> requires_grad=True) + >>> targets = torch.tensor([[1, 2]], dtype=torch.int) + >>> logit_lengths = torch.tensor([2], dtype=torch.int) + >>> target_lengths = torch.tensor([2], dtype=torch.int) + >>> transform = transforms.RNNTLoss(blank=0) + >>> loss = transform(logits, targets, logit_lengths, target_lengths) + >>> loss.backward() + """ + + def __init__( + self, + blank: int = -1, + clamp: float = -1., + reduction: str = "mean", + ): + super().__init__() + self.blank = blank + self.clamp = clamp + self.reduction = reduction + + def forward( + self, + logits: Tensor, + targets: Tensor, + logit_lengths: Tensor, + target_lengths: Tensor, + ): + """ + Args: + logits (Tensor): Tensor of dimension `(batch, max seq length, max target length + 1, class)` + containing output from joiner + targets (Tensor): Tensor of dimension `(batch, max target length)` containing targets with zero padded + logit_lengths (Tensor): Tensor of dimension `(batch)` containing lengths of each sequence from encoder + target_lengths (Tensor): Tensor of dimension `(batch)` containing lengths of targets for each sequence + Returns: + Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size (batch), + otherwise scalar. + """ + return F.rnnt_loss( + logits, + targets, + logit_lengths, + target_lengths, + self.blank, + self.clamp, + self.reduction + ) + + +def _get_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch.Tensor: + r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions. + + Args: + input (torch.Tensor): Tensor of dimension `(..., channel, channel)` + dim1 (int, optional): the first dimension of the diagonal matrix + (Default: -1) + dim2 (int, optional): the second dimension of the diagonal matrix + (Default: -2) + + Returns: + torch.Tensor: trace of the input Tensor + """ + assert input.ndim >= 2, "The dimension of the tensor must be at least 2." + assert input.shape[dim1] == input.shape[dim2],\ + "The size of ``dim1`` and ``dim2`` must be the same." + input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2) + return input.sum(dim=-1) + + +class PSD(torch.nn.Module): + r"""Compute cross-channel power spectral density (PSD) matrix. + + Args: + multi_mask (bool, optional): whether to use multi-channel Time-Frequency masks. (Default: ``False``) + normalize (bool, optional): whether normalize the mask along the time dimension. 
+ eps (float, optional): a value added to the denominator in mask normalization. (Default: 1e-15) + """ + + def __init__(self, multi_mask: bool = False, normalize: bool = True, eps: float = 1e-15): + super().__init__() + self.multi_mask = multi_mask + self.normalize = normalize + self.eps = eps + + def forward(self, specgram: torch.Tensor, mask: Optional[torch.Tensor] = None): + """ + Args: + specgram (torch.Tensor): multi-channel complex-valued STFT matrix. + Tensor of dimension `(..., channel, freq, time)` + mask (torch.Tensor or None, optional): Time-Frequency mask for normalization. + Tensor of dimension `(..., freq, time)` if multi_mask is ``False`` or + of dimension `(..., channel, freq, time)` if multi_mask is ``True`` + + Returns: + Tensor: PSD matrix of the input STFT matrix. + Tensor of dimension `(..., freq, channel, channel)` + """ + # outer product: + # (..., ch_1, freq, time) x (..., ch_2, freq, time) -> (..., time, ch_1, ch_2) + psd = torch.einsum("...cft,...eft->...ftce", [specgram, specgram.conj()]) + + if mask is not None: + if self.multi_mask: + # Averaging mask along channel dimension + mask = mask.mean(dim=-3) # (..., freq, time) + + # Normalized mask along time dimension: + if self.normalize: + mask = mask / (mask.sum(dim=-1, keepdim=True) + self.eps) + + psd = psd * mask.unsqueeze(-1).unsqueeze(-1) + + psd = psd.sum(dim=-3) + return psd + + +class MVDR(torch.nn.Module): + """Minimum Variance Distortionless Response (MVDR) module that performs MVDR beamforming with Time-Frequency masks. + + Based on https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/beamformer.py + + We provide three solutions of MVDR beamforming. One is based on *reference channel selection* + [:footcite:`souden2009optimal`] (``solution=ref_channel``). + + .. math:: + \\textbf{w}_{\\text{MVDR}}(f) =\ + \\frac{{{\\bf{\\Phi}_{\\textbf{NN}}^{-1}}(f){\\bf{\\Phi}_{\\textbf{SS}}}}(f)}\ + {\\text{Trace}({{{\\bf{\\Phi}_{\\textbf{NN}}^{-1}}(f) \\bf{\\Phi}_{\\textbf{SS}}}(f))}}\\bm{u} + + where :math:`\\bf{\\Phi}_{\\textbf{SS}}` and :math:`\\bf{\\Phi}_{\\textbf{NN}}` are the covariance\ + matrices of speech and noise, respectively. :math:`\\bf{u}` is an one-hot vector to determine the\ + reference channel. + + The other two solutions are based on the steering vector (``solution=stv_evd`` or ``solution=stv_power``). + + .. math:: + \\textbf{w}_{\\text{MVDR}}(f) =\ + \\frac{{{\\bf{\\Phi}_{\\textbf{NN}}^{-1}}(f){\\bm{v}}(f)}}\ + {{\\bm{v}^{\\mathsf{H}}}(f){\\bf{\\Phi}_{\\textbf{NN}}^{-1}}(f){\\bm{v}}(f)} + + where :math:`\\bm{v}` is the acoustic transfer function or the steering vector.\ + :math:`.^{\\mathsf{H}}` denotes the Hermitian Conjugate operation. + + We apply either *eigenvalue decomposition* + [:footcite:`higuchi2016robust`] or the *power method* [:footcite:`mises1929praktische`] to get the + steering vector from the PSD matrix of speech. + + After estimating the beamforming weight, the enhanced Short-time Fourier Transform (STFT) is obtained by + + .. math:: + \\hat{\\bf{S}} = {\\bf{w}^\\mathsf{H}}{\\bf{Y}}, {\\bf{w}} \\in \\mathbb{C}^{M \\times F} + + where :math:`\\bf{Y}` and :math:`\\hat{\\bf{S}}` are the STFT of the multi-channel noisy speech and\ + the single-channel enhanced speech, respectively. + + For online streaming audio, we provide a *recursive method* [:footcite:`higuchi2017online`] to update the + PSD matrices of speech and noise, respectively. + + Args: + ref_channel (int, optional): the reference channel for beamforming. 
(Default: ``0``) + solution (str, optional): the solution to get MVDR weight. + Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``) + multi_mask (bool, optional): whether to use multi-channel Time-Frequency masks. (Default: ``False``) + diag_loading (bool, optional): whether apply diagonal loading on the psd matrix of noise. + (Default: ``True``) + diag_eps (float, optional): the coefficient multipied to the identity matrix for diagonal loading. + (Default: 1e-7) + online (bool, optional): whether to update the mvdr vector based on the previous psd matrices. + (Default: ``False``) + + Note: + The MVDR Module requires the input STFT to be double precision (``torch.complex128`` or ``torch.cdouble``), + to improve the numerical stability. You can downgrade the precision to ``torch.float`` after generating the + enhanced waveform for ASR joint training. + + Note: + If you use ``stv_evd`` solution, the gradient of the same input may not be identical if the + eigenvalues of the PSD matrix are not distinct (i.e. some eigenvalues are close or identical). + """ + + def __init__( + self, + ref_channel: int = 0, + solution: str = "ref_channel", + multi_mask: bool = False, + diag_loading: bool = True, + diag_eps: float = 1e-7, + online: bool = False, + ): + super().__init__() + assert solution in ["ref_channel", "stv_evd", "stv_power"],\ + "Unknown solution provided. Must be one of [``ref_channel``, ``stv_evd``, ``stv_power``]." + self.ref_channel = ref_channel + self.solution = solution + self.multi_mask = multi_mask + self.diag_loading = diag_loading + self.diag_eps = diag_eps + self.online = online + self.psd = PSD(multi_mask) + + psd_s: torch.Tensor = torch.zeros(1) + psd_n: torch.Tensor = torch.zeros(1) + mask_sum_s: torch.Tensor = torch.zeros(1) + mask_sum_n: torch.Tensor = torch.zeros(1) + self.register_buffer('psd_s', psd_s) + self.register_buffer('psd_n', psd_n) + self.register_buffer('mask_sum_s', mask_sum_s) + self.register_buffer('mask_sum_n', mask_sum_n) + + def _get_updated_mvdr_vector( + self, + psd_s: torch.Tensor, + psd_n: torch.Tensor, + mask_s: torch.Tensor, + mask_n: torch.Tensor, + reference_vector: torch.Tensor, + solution: str = 'ref_channel', + diagonal_loading: bool = True, + diag_eps: float = 1e-7, + eps: float = 1e-8, + ) -> torch.Tensor: + r"""Recursively update the MVDR beamforming vector. + + Args: + psd_s (torch.Tensor): psd matrix of target speech + psd_n (torch.Tensor): psd matrix of noise + mask_s (torch.Tensor): T-F mask of target speech + mask_n (torch.Tensor): T-F mask of noise + reference_vector (torch.Tensor): one-hot reference channel matrix + solution (str, optional): the solution to estimate the beamforming weight + (Default: ``ref_channel``) + diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n + (Default: ``True``) + diag_eps (float, optional): The coefficient multipied to the identity matrix for diagonal loading + (Default: 1e-7) + eps (float, optional): a value added to the denominator in mask normalization. 
(Default: 1e-8) + + Returns: + Tensor: the mvdr beamforming weight matrix + """ + if self.multi_mask: + # Averaging mask along channel dimension + mask_s = mask_s.mean(dim=-3) # (..., freq, time) + mask_n = mask_n.mean(dim=-3) # (..., freq, time) + if self.psd_s.ndim == 1: + self.psd_s = psd_s + self.psd_n = psd_n + self.mask_sum_s = mask_s.sum(dim=-1) + self.mask_sum_n = mask_n.sum(dim=-1) + return self._get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps) + else: + psd_s = self._get_updated_psd_speech(psd_s, mask_s) + psd_n = self._get_updated_psd_noise(psd_n, mask_n) + self.psd_s = psd_s + self.psd_n = psd_n + self.mask_sum_s = self.mask_sum_s + mask_s.sum(dim=-1) + self.mask_sum_n = self.mask_sum_n + mask_n.sum(dim=-1) + return self._get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps) + + def _get_updated_psd_speech(self, psd_s: torch.Tensor, mask_s: torch.Tensor) -> torch.Tensor: + r"""Update psd of speech recursively. + + Args: + psd_s (torch.Tensor): psd matrix of target speech + mask_s (torch.Tensor): T-F mask of target speech + + Returns: + torch.Tensor: the updated psd of speech + """ + numerator = self.mask_sum_s / (self.mask_sum_s + mask_s.sum(dim=-1)) + denominator = 1 / (self.mask_sum_s + mask_s.sum(dim=-1)) + psd_s = self.psd_s * numerator[..., None, None] + psd_s * denominator[..., None, None] + return psd_s + + def _get_updated_psd_noise(self, psd_n: torch.Tensor, mask_n: torch.Tensor) -> torch.Tensor: + r"""Update psd of noise recursively. + + Args: + psd_n (torch.Tensor): psd matrix of target noise + mask_n (torch.Tensor): T-F mask of target noise + + Returns: + torch.Tensor: the updated psd of noise + """ + numerator = self.mask_sum_n / (self.mask_sum_n + mask_n.sum(dim=-1)) + denominator = 1 / (self.mask_sum_n + mask_n.sum(dim=-1)) + psd_n = self.psd_n * numerator[..., None, None] + psd_n * denominator[..., None, None] + return psd_n + + def _get_mvdr_vector( + self, + psd_s: torch.Tensor, + psd_n: torch.Tensor, + reference_vector: torch.Tensor, + solution: str = 'ref_channel', + diagonal_loading: bool = True, + diag_eps: float = 1e-7, + eps: float = 1e-8, + ) -> torch.Tensor: + r"""Compute beamforming vector by the reference channel selection method. + + Args: + psd_s (torch.Tensor): psd matrix of target speech + psd_n (torch.Tensor): psd matrix of noise + reference_vector (torch.Tensor): one-hot reference channel matrix + solution (str, optional): the solution to estimate the beamforming weight + (Default: ``ref_channel``) + diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n + (Default: ``True``) + diag_eps (float, optional): The coefficient multipied to the identity matrix for diagonal loading + (Default: 1e-7) + eps (float, optional): a value added to the denominator in mask normalization. 
Default: 1e-8 + + Returns: + torch.Tensor: the mvdr beamforming weight matrix + """ + if diagonal_loading: + psd_n = self._tik_reg(psd_n, reg=diag_eps, eps=eps) + if solution == "ref_channel": + numerator = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s + # ws: (..., C, C) / (...,) -> (..., C, C) + ws = numerator / (_get_mat_trace(numerator)[..., None, None] + eps) + # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1) + beamform_vector = torch.einsum("...fec,...c->...fe", [ws, reference_vector]) + else: + if solution == "stv_evd": + stv = self._get_steering_vector_evd(psd_s) + else: + stv = self._get_steering_vector_power(psd_s, psd_n, reference_vector) + # numerator = psd_n.inv() @ stv + numerator = torch.linalg.solve(psd_n, stv).squeeze(-1) # (..., freq, channel) + # denominator = stv^H @ psd_n.inv() @ stv + denominator = torch.einsum("...d,...d->...", [stv.conj().squeeze(-1), numerator]) + # normalzie the numerator + scale = stv.squeeze(-1)[..., self.ref_channel, None].conj() + beamform_vector = numerator * scale / (denominator.real.unsqueeze(-1) + eps) + + return beamform_vector + + def _get_steering_vector_evd(self, psd_s: torch.Tensor) -> torch.Tensor: + r"""Estimate the steering vector by eigenvalue decomposition. + + Args: + psd_s (torch.tensor): covariance matrix of speech + Tensor of dimension `(..., freq, channel, channel)` + + Returns: + torch.Tensor: the enhanced STFT + Tensor of dimension `(..., freq, channel, 1)` + """ + w, v = torch.linalg.eig(psd_s) # (..., freq, channel, channel) + _, indices = torch.max(w.abs(), dim=-1, keepdim=True) + indices = indices.unsqueeze(-1) + stv = v.gather(-1, indices.expand(psd_s.shape[:-1] + (1,))) # (..., freq, channel, 1) + return stv + + def _get_steering_vector_power( + self, + psd_s: torch.Tensor, + psd_n: torch.Tensor, + reference_vector: torch.Tensor + ) -> torch.Tensor: + r"""Estimate the steering vector by the power method. + + Args: + psd_s (torch.tensor): covariance matrix of speech + Tensor of dimension `(..., freq, channel, channel)` + psd_n (torch.Tensor): covariance matrix of noise + Tensor of dimension `(..., freq, channel, channel)` + reference_vector (torch.Tensor): one-hot reference channel matrix + + Returns: + torch.Tensor: the enhanced STFT + Tensor of dimension `(..., freq, channel, 1)` + """ + phi = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s + stv = torch.einsum("...fec,...c->...fe", [phi, reference_vector]) + stv = stv.unsqueeze(-1) + stv = torch.matmul(phi, stv) + stv = torch.matmul(psd_s, stv) + return stv + + def _apply_beamforming_vector( + self, + specgram: torch.Tensor, + beamform_vector: torch.Tensor + ) -> torch.Tensor: + r"""Apply the beamforming weight to the noisy STFT + Args: + specgram (torch.tensor): multi-channel noisy STFT + Tensor of dimension `(..., channel, freq, time)` + beamform_vector (torch.Tensor): beamforming weight matrix + Tensor of dimension `(..., freq, channel)` + + Returns: + torch.Tensor: the enhanced STFT + Tensor of dimension `(..., freq, time)` + """ + # (..., channel) x (..., channel, freq, time) -> (..., freq, time) + specgram_enhanced = torch.einsum("...fc,...cft->...ft", [beamform_vector.conj(), specgram]) + return specgram_enhanced + + def _tik_reg( + self, + mat: torch.Tensor, + reg: float = 1e-7, + eps: float = 1e-8 + ) -> torch.Tensor: + """Perform Tikhonov regularization (only modifying real part). 
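+
+        In effect (an illustrative sketch of the update performed below):
+        ``mat <- mat + (reg * trace(mat).real + eps) * I``
+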
+ Args: + mat (torch.Tensor): input matrix (..., channel, channel) + reg (float, optional): regularization factor (Default: 1e-8) + eps (float, optional): a value to avoid the correlation matrix is all-zero (Default: 1e-8) + + Returns: + torch.Tensor: regularized matrix (..., channel, channel) + """ + # Add eps + C = mat.size(-1) + eye = torch.eye(C, dtype=mat.dtype, device=mat.device) + with torch.no_grad(): + epsilon = _get_mat_trace(mat).real[..., None, None] * reg + # in case that correlation_matrix is all-zero + epsilon = epsilon + eps + mat = mat + epsilon * eye[..., :, :] + return mat + + def forward( + self, + specgram: torch.Tensor, + mask_s: torch.Tensor, + mask_n: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Perform MVDR beamforming. + + Args: + specgram (torch.Tensor): the multi-channel STF of the noisy speech. + Tensor of dimension `(..., channel, freq, time)` + mask_s (torch.Tensor): Time-Frequency mask of target speech. + Tensor of dimension `(..., freq, time)` if multi_mask is ``False`` + or or dimension `(..., channel, freq, time)` if multi_mask is ``True`` + mask_n (torch.Tensor or None, optional): Time-Frequency mask of noise. + Tensor of dimension `(..., freq, time)` if multi_mask is ``False`` + or or dimension `(..., channel, freq, time)` if multi_mask is ``True`` + (Default: None) + + Returns: + torch.Tensor: The single-channel STFT of the enhanced speech. + Tensor of dimension `(..., freq, time)` + """ + if specgram.ndim < 3: + raise ValueError( + f"Expected at least 3D tensor (..., channel, freq, time). Found: {specgram.shape}" + ) + if specgram.dtype != torch.cdouble: + raise ValueError( + f"The type of ``specgram`` tensor must be ``torch.cdouble``. Found: {specgram.dtype}" + ) + + if mask_n is None: + warnings.warn( + "``mask_n`` is not provided, use ``1 - mask_s`` as ``mask_n``." + ) + mask_n = 1 - mask_s + + shape = specgram.size() + + # pack batch + specgram = specgram.reshape(-1, shape[-3], shape[-2], shape[-1]) + if self.multi_mask: + mask_s = mask_s.reshape(-1, shape[-3], shape[-2], shape[-1]) + mask_n = mask_n.reshape(-1, shape[-3], shape[-2], shape[-1]) + else: + mask_s = mask_s.reshape(-1, shape[-2], shape[-1]) + mask_n = mask_n.reshape(-1, shape[-2], shape[-1]) + + psd_s = self.psd(specgram, mask_s) # (..., freq, time, channel, channel) + psd_n = self.psd(specgram, mask_n) # (..., freq, time, channel, channel) + + u = torch.zeros( + specgram.size()[:-2], + device=specgram.device, + dtype=torch.cdouble + ) # (..., channel) + u[..., self.ref_channel].fill_(1) + + if self.online: + w_mvdr = self._get_updated_mvdr_vector( + psd_s, + psd_n, + mask_s, + mask_n, + u, + self.solution, + self.diag_loading, + self.diag_eps + ) + else: + w_mvdr = self._get_mvdr_vector( + psd_s, + psd_n, + u, + self.solution, + self.diag_loading, + self.diag_eps + ) + + specgram_enhanced = self._apply_beamforming_vector(specgram, w_mvdr) + + # unpack batch + specgram_enhanced = specgram_enhanced.reshape(shape[:-3] + shape[-2:]) + + return specgram_enhanced diff --git a/torchaudio/utils/__init__.py b/torchaudio/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..578a51925a459dfa5ec056d504437f9ee40b4840 --- /dev/null +++ b/torchaudio/utils/__init__.py @@ -0,0 +1,8 @@ +from . 
import ( + sox_utils, +) +from torchaudio._internal import module_utils as _mod_utils + + +if _mod_utils.is_sox_available(): + sox_utils.set_verbosity(1) diff --git a/torchaudio/utils/sox_utils.py b/torchaudio/utils/sox_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..384f370a4d5b0e486faa2473c875c780589ab61a --- /dev/null +++ b/torchaudio/utils/sox_utils.py @@ -0,0 +1,102 @@ +from typing import List, Dict + +import torch +from torchaudio._internal import module_utils as _mod_utils + + +@_mod_utils.requires_sox() +def set_seed(seed: int): + """Set libsox's PRNG + + Args: + seed (int): seed value. valid range is int32. + + See Also: + http://sox.sourceforge.net/sox.html + """ + torch.ops.torchaudio.sox_utils_set_seed(seed) + + +@_mod_utils.requires_sox() +def set_verbosity(verbosity: int): + """Set libsox's verbosity + + Args: + verbosity (int): Set verbosity level of libsox. + + * ``1`` failure messages + * ``2`` warnings + * ``3`` details of processing + * ``4``-``6`` increasing levels of debug messages + + See Also: + http://sox.sourceforge.net/sox.html + """ + torch.ops.torchaudio.sox_utils_set_verbosity(verbosity) + + +@_mod_utils.requires_sox() +def set_buffer_size(buffer_size: int): + """Set buffer size for sox effect chain + + Args: + buffer_size (int): Set the size in bytes of the buffers used for processing audio. + + See Also: + http://sox.sourceforge.net/sox.html + """ + torch.ops.torchaudio.sox_utils_set_buffer_size(buffer_size) + + +@_mod_utils.requires_sox() +def set_use_threads(use_threads: bool): + """Set multithread option for sox effect chain + + Args: + use_threads (bool): When ``True``, enables ``libsox``'s parallel effects channels processing. + To use mutlithread, the underlying ``libsox`` has to be compiled with OpenMP support. + + See Also: + http://sox.sourceforge.net/sox.html + """ + torch.ops.torchaudio.sox_utils_set_use_threads(use_threads) + + +@_mod_utils.requires_sox() +def list_effects() -> Dict[str, str]: + """List the available sox effect names + + Returns: + Dict[str, str]: Mapping from ``effect name`` to ``usage`` + """ + return dict(torch.ops.torchaudio.sox_utils_list_effects()) + + +@_mod_utils.requires_sox() +def list_read_formats() -> List[str]: + """List the supported audio formats for read + + Returns: + List[str]: List of supported audio formats + """ + return torch.ops.torchaudio.sox_utils_list_read_formats() + + +@_mod_utils.requires_sox() +def list_write_formats() -> List[str]: + """List the supported audio formats for write + + Returns: + List[str]: List of supported audio formats + """ + return torch.ops.torchaudio.sox_utils_list_write_formats() + + +@_mod_utils.requires_sox() +def get_buffer_size() -> int: + """Get buffer size for sox effect chain + + Returns: + int: size in bytes of buffers used for processing audio. + """ + return torch.ops.torchaudio.sox_utils_get_buffer_size()
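+
+
+# An illustrative usage sketch (not part of the module API above): querying and
+# tuning libsox settings when the sox extension is available. The values shown
+# are examples, not recommendations.
+#
+#   >>> from torchaudio.utils import sox_utils
+#   >>> sox_utils.set_verbosity(2)       # failure messages and warnings
+#   >>> sox_utils.set_buffer_size(8192)  # bytes used by the effect-chain buffers
+#   >>> sox_utils.get_buffer_size()
+#   8192
+#   >>> 'wav' in sox_utils.list_read_formats()
+#   True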