Commit c76d952e authored by Chin-Yun Yu's avatar Chin-Yun Yu Committed by Facebook GitHub Bot
Browse files

feat: add guard in `lfilter` for a non-default cuda device (#3432)

Summary:
Should resolve https://github.com/pytorch/audio/issues/3425

cc mthrok

Pull Request resolved: https://github.com/pytorch/audio/pull/3432

Differential Revision: D46656180

Pulled By: mthrok

fbshipit-source-id: 5c534bee2f143ef5cb5e50ec74828012dbcab7e9
parent c5877157
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/torch.h>
template <typename scalar_t>
@@ -58,6 +59,8 @@ void cuda_lfilter_core_loop(
  TORCH_CHECK(in.size(2) + a_flipped.size(1) - 1 == padded_out.size(2));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(in));
  const dim3 threads(256);
  const dim3 blocks((N * C + threads.x - 1) / threads.x);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment