Commit 97e13de2 authored by Tri Dao

Cast q.get_device() to char to avoid compiler warning (narrowing)

parent ed553e92
@@ -253,7 +253,8 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     bool loop = max_seqlen_k > blocksize_c;
     // Otherwise the kernel will be launched from cuda:0 device
-    at::cuda::CUDAGuard device_guard{q.get_device()};
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
     auto opts = q.options();
@@ -412,7 +413,8 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     bool loop = max_seqlen_k > blocksize_c;
     // Otherwise the kernel will be launched from cuda:0 device
-    at::cuda::CUDAGuard device_guard{q.get_device()};
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
     // It's possible the softmax_lse_ from the fwd has a different length since blocksize_c could be different.
     auto softmax_lse = softmax_lse_.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, max_seqlen_q)}).contiguous();
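For context, the warning comes from C++ list-initialization rules: brace-initializing at::cuda::CUDAGuard with the wider integer returned by q.get_device() is a narrowing conversion into c10::DeviceIndex (an 8-bit type in PyTorch builds of that era), which compilers flag under -Wnarrowing. A minimal standalone sketch of the pattern, where DeviceGuard and get_device() are hypothetical stand-ins for the PyTorch types rather than the real API:

    #include <cstdint>

    struct DeviceGuard {
        int8_t index;  // stand-in for c10::DeviceIndex (an int8_t here)
    };

    int get_device() { return 0; }  // stand-in for q.get_device(), returns a wider int

    int main() {
        // DeviceGuard bad{get_device()};     // -Wnarrowing: int -> int8_t inside braces
        DeviceGuard ok{(char)get_device()};   // explicit cast, no narrowing warning on
                                              // typical targets where char is signed
        return ok.index;
    }

Casting to char (matching the 8-bit device index) makes the conversion explicit, so the brace-initialization no longer triggers the warning; the device-guard behavior is unchanged.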