gaoqiong / flash-attention

Commit 97e13de2
Authored Oct 24, 2022 by Tri Dao

Cast q.get_device() to char to avoid compiler warning (narrowing)

Parent: ed553e92
Showing 1 changed file with 4 additions and 2 deletions
csrc/flash_attn/fmha_api.cpp (+4, -2)

@@ -253,7 +253,8 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     bool loop = max_seqlen_k > blocksize_c;

     // Otherwise the kernel will be launched from cuda:0 device
-    at::cuda::CUDAGuard device_guard{q.get_device()};
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

     auto opts = q.options();
@@ -412,7 +413,8 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     bool loop = max_seqlen_k > blocksize_c;

     // Otherwise the kernel will be launched from cuda:0 device
-    at::cuda::CUDAGuard device_guard{q.get_device()};
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

     // It's possible the softmax_lse_ from the fwd has a different length since blocksize_c could be different.
     auto softmax_lse = softmax_lse_.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, max_seqlen_q)}).contiguous();
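The warning this commit silences comes from C++ list-initialization rules: brace-initialization rejects implicit conversions that can lose information. In ATen, q.get_device() returns a wide integer (int64_t), while CUDAGuard's constructor takes a narrow device-index type (c10::DeviceIndex, an 8-bit integer), so the braced call triggers -Wnarrowing. Below is a minimal standalone sketch of the same situation; DeviceGuard and the get_device() stand-in are hypothetical illustrations, not the real PyTorch classes.

// Minimal sketch (hypothetical types, not the actual c10/ATen classes):
// brace-initialization forbids implicit narrowing conversions.
#include <cstdint>

struct DeviceGuard {
    // Stand-in for at::cuda::CUDAGuard, whose constructor takes a narrow device index.
    explicit DeviceGuard(int8_t device_index) : index_(device_index) {}
    int8_t index_;
};

// Stand-in for q.get_device(), which returns a wide integer index.
int64_t get_device() { return 0; }

int main() {
    // DeviceGuard g{get_device()};       // -Wnarrowing: int64_t -> int8_t inside braces
    DeviceGuard g{(char)get_device()};    // explicit cast, as in the commit, silences the warning
    return g.index_;
}

An explicit static_cast to the index type would silence the same diagnostic; the (char) cast used in the commit relies on char being an 8-bit type that matches the device-index width.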