Unverified Commit ce58fd6f authored by Shilong Zhang's avatar Shilong Zhang Committed by GitHub
Browse files

add device DeviceGuard (#1402)

parent c5018463
...@@ -39,6 +39,7 @@ Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes, ...@@ -39,6 +39,7 @@ Tensor ms_deform_attn_forward(const Tensor &value, const Tensor &spatial_shapes,
CHECK_CUDA_INPUT(level_start_index) CHECK_CUDA_INPUT(level_start_index)
CHECK_CUDA_INPUT(sampling_loc) CHECK_CUDA_INPUT(sampling_loc)
CHECK_CUDA_INPUT(attn_weight) CHECK_CUDA_INPUT(attn_weight)
at::DeviceGuard guard(value.device());
return ms_deform_attn_cuda_forward(value, spatial_shapes, level_start_index, return ms_deform_attn_cuda_forward(value, spatial_shapes, level_start_index,
sampling_loc, attn_weight, im2col_step); sampling_loc, attn_weight, im2col_step);
#else #else
...@@ -66,6 +67,7 @@ void ms_deform_attn_backward(const Tensor &value, const Tensor &spatial_shapes, ...@@ -66,6 +67,7 @@ void ms_deform_attn_backward(const Tensor &value, const Tensor &spatial_shapes,
CHECK_CUDA_INPUT(grad_value) CHECK_CUDA_INPUT(grad_value)
CHECK_CUDA_INPUT(grad_sampling_loc) CHECK_CUDA_INPUT(grad_sampling_loc)
CHECK_CUDA_INPUT(grad_attn_weight) CHECK_CUDA_INPUT(grad_attn_weight)
at::DeviceGuard guard(value.device());
ms_deform_attn_cuda_backward(value, spatial_shapes, level_start_index, ms_deform_attn_cuda_backward(value, spatial_shapes, level_start_index,
sampling_loc, attn_weight, grad_output, sampling_loc, attn_weight, grad_output,
grad_value, grad_sampling_loc, grad_value, grad_sampling_loc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment