Unverified Commit 6a4b834f authored by Wei Kang's avatar Wei Kang Committed by GitHub
Browse files

Fix actions (#31)

* Fix actions

* change to py39

* Fix black

* More fixes

* Fixes

* Fix torch version

* Fix cudnn
parent 285ad4dd
[flake8]
show-source=true
statistics=true
max-line-length=80
exclude =
.git,
.github,
setup.py,
build,
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
......@@ -14,8 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
torch=$TORCH_VERSION
cuda=$CUDA_VERSION
echo "torch version: $torch"
echo "cuda version: $cuda"
case ${torch} in
1.5.*)
case ${cuda} in
......
......@@ -49,9 +49,9 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
torch: ["1.13.1"]
torchaudio: ["0.13.1"]
python-version: ["3.11"]
torch: ["1.12.1"]
torchaudio: ["0.12.1"]
python-version: ["3.9"]
build_type: ["Release", "Debug"]
steps:
......@@ -81,7 +81,7 @@ jobs:
run: |
python3 -m pip install -qq --upgrade pip
python3 -m pip install -qq torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
python3 -m pip install -qq torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
python3 -m pip install -qq torchaudio==${{ matrix.torchaudio }} -f https://download.pytorch.org/whl/cpu/torch_stable.html
python3 -c "import torch; print('torch version:', torch.__version__)"
python3 -m torch.utils.collect_env
......@@ -92,7 +92,7 @@ jobs:
run: |
python3 -m pip install -qq --upgrade pip
python3 -m pip install -qq torch==${{ matrix.torch }}
python3 -m pip install -qq torch==${{ matrix.torchaudio }}
python3 -m pip install -qq torchaudio==${{ matrix.torchaudio }}
python3 -c "import torch; print('torch version:', torch.__version__)"
python3 -m torch.utils.collect_env
......
......@@ -47,10 +47,9 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
cuda: ["11.7"]
torch: ["1.13.1"]
torchaudio: ["1.13.0"]
python-version: ["3.11"]
cuda: ["11.6"]
torch: ["1.12.1"]
python-version: ["3.9"]
build_type: ["Release", "Debug"]
steps:
......@@ -103,7 +102,7 @@ jobs:
env:
cuda: ${{ matrix.cuda }}
run: |
./scripts/github_actions/install_cudnn.sh
./.github/scripts/install_cudnn.sh
- name: Configure CMake
shell: bash
......
......@@ -66,4 +66,4 @@ jobs:
shell: bash
working-directory: ${{github.workspace}}
run: |
black --check --diff .
black -l 80 --check --diff .
......@@ -11,6 +11,7 @@ endif()
cmake_minimum_required(VERSION 3.8 FATAL_ERROR)
set(CMAKE_DISABLE_FIND_PACKAGE_MKL TRUE)
set(languages CXX)
set(_FT_WITH_CUDA ON)
......
......@@ -16,11 +16,15 @@ endif()
pybind11_add_module(_fast_rnnt ${fast_rnnt_srcs})
target_link_libraries(_fast_rnnt PRIVATE mutual_information_core)
if(UNIX AND NOT APPLE)
if(APPLE)
target_link_libraries(_fast_rnnt
PRIVATE
${TORCH_DIR}/lib/libtorch_python.dylib
)
elseif(UNIX)
target_link_libraries(_fast_rnnt
PRIVATE
${PYTHON_LIBRARY}
${TORCH_DIR}/lib/libtorch_python.so
)
endif()
......@@ -13,5 +13,3 @@ from .rnnt_loss import rnnt_loss
from .rnnt_loss import rnnt_loss_pruned
from .rnnt_loss import rnnt_loss_simple
from .rnnt_loss import rnnt_loss_smoothed
......@@ -160,7 +160,8 @@ class MutualInformationRecursionFunction(torch.autograd.Function):
if return_grad or px.requires_grad or py.requires_grad:
ans_grad = torch.ones(B, device=px.device, dtype=px.dtype)
(px_grad, py_grad) = _fast_rnnt.mutual_information_backward(
px, py, boundary, p, ans_grad)
px, py, boundary, p, ans_grad
)
ctx.save_for_backward(px_grad, py_grad)
assert len(pxy_grads) == 2
pxy_grads[0] = px_grad
......@@ -290,8 +291,9 @@ def mutual_information_recursion(
px, py = px.contiguous(), py.contiguous()
pxy_grads = [None, None]
scores = MutualInformationRecursionFunction.apply(px, py, pxy_grads,
boundary, return_grad)
scores = MutualInformationRecursionFunction.apply(
px, py, pxy_grads, boundary, return_grad
)
px_grad, py_grad = pxy_grads
return (scores, (px_grad, py_grad)) if return_grad else scores
......@@ -388,16 +390,18 @@ def joint_mutual_information_recursion(
p = torch.empty(B, S + 1, T + 1, device=px_tot.device, dtype=px_tot.dtype)
# note, tot_probs is without grad.
tot_probs = _fast_rnnt.mutual_information_forward(px_tot, py_tot, boundary, p)
tot_probs = _fast_rnnt.mutual_information_forward(
px_tot, py_tot, boundary, p
)
# this is a kind of "fake gradient" that we use, in effect to compute
# occupation probabilities. The backprop will work regardless of the
# actual derivative w.r.t. the total probs.
ans_grad = torch.ones(B, device=px_tot.device, dtype=px_tot.dtype)
(px_grad,
py_grad) = _fast_rnnt.mutual_information_backward(px_tot, py_tot, boundary, p,
ans_grad)
(px_grad, py_grad) = _fast_rnnt.mutual_information_backward(
px_tot, py_tot, boundary, p, ans_grad
)
px_grad = px_grad.reshape(1, B, -1)
py_grad = py_grad.reshape(1, B, -1)
......
......@@ -170,7 +170,7 @@ def get_rnnt_logprobs(
am.transpose(1, 2), # (B, C, T)
dim=1,
index=symbols.unsqueeze(2).expand(B, S, T),
) # (B, S, T)
) # (B, S, T)
if rnnt_type == "regular":
px_am = torch.cat(
......@@ -291,7 +291,9 @@ def rnnt_loss_simple(
T = T0 if rnnt_type != "regular" else T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2, dtype=px.dtype, device=px.device,
(T - 1) / 2,
dtype=px.dtype,
device=px.device,
).expand(B, 1, 1)
else:
offset = (boundary[:, 3] - 1) / 2
......@@ -495,7 +497,9 @@ def rnnt_loss(
T = T0 if rnnt_type != "regular" else T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2, dtype=px.dtype, device=px.device,
(T - 1) / 2,
dtype=px.dtype,
device=px.device,
).expand(B, 1, 1)
else:
offset = (boundary[:, 3] - 1) / 2
......@@ -770,9 +774,7 @@ def do_rnnt_pruning(
lm_pruning = torch.gather(
lm,
dim=1,
index=ranges.reshape(B, T * s_range, 1).expand(
(B, T * s_range, C)
),
index=ranges.reshape(B, T * s_range, 1).expand((B, T * s_range, C)),
).reshape(B, T, s_range, C)
return am_pruning, lm_pruning
......@@ -1057,7 +1059,9 @@ def rnnt_loss_pruned(
T = T0 if rnnt_type != "regular" else T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2, dtype=px.dtype, device=px.device,
(T - 1) / 2,
dtype=px.dtype,
device=px.device,
).expand(B, 1, 1)
else:
offset = (boundary[:, 3] - 1) / 2
......@@ -1248,7 +1252,7 @@ def get_rnnt_logprobs_smoothed(
am.transpose(1, 2), # (B, C, T)
dim=1,
index=symbols.unsqueeze(2).expand(B, S, T),
) # (B, S, T)
) # (B, S, T)
if rnnt_type == "regular":
px_am = torch.cat(
......@@ -1413,7 +1417,9 @@ def rnnt_loss_smoothed(
T = T0 if rnnt_type != "regular" else T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2, dtype=px.dtype, device=px.device,
(T - 1) / 2,
dtype=px.dtype,
device=px.device,
).expand(B, 1, 1)
else:
offset = (boundary[:, 3] - 1) / 2
......
......@@ -39,7 +39,7 @@ class BuildExtension(build_ext):
cmake_args = "-DCMAKE_BUILD_TYPE=Release -DFT_BUILD_TESTS=OFF"
if make_args == "" and system_make_args == "":
make_args = ' -j '
make_args = " -j "
if "PYTHON_EXECUTABLE" not in cmake_args:
print(f"Setting PYTHON_EXECUTABLE to {sys.executable}")
......@@ -89,17 +89,17 @@ def get_package_version():
latest_version = latest_version.strip('"')
return latest_version
def get_requirements():
with open("requirements.txt", encoding="utf8") as f:
requirements = f.read().splitlines()
return requirements
package_name = "fast_rnnt"
with open(
"fast_rnnt/python/fast_rnnt/__init__.py", "a"
) as f:
with open("fast_rnnt/python/fast_rnnt/__init__.py", "a") as f:
f.write(f"__version__ = '{get_package_version()}'\n")
setuptools.setup(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment