gaoqiong / flash-attention · Commits

Commit c0daa62e, authored Jun 26, 2022 by Tri Dao
Parent: ea38d3d2

Add type check (fp16) in the forward pass

Showing 1 changed file with 3 additions and 0 deletions.

csrc/flash_attn/fmha_api.cpp (+3, -0)
@@ -130,6 +130,9 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
     bool is_dropout = p_dropout > 0.0;
     Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
 
+    TORCH_CHECK(qkv.dtype() == torch::kFloat16);
+    TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
+
     TORCH_CHECK(qkv.is_cuda())
     TORCH_CHECK(cu_seqlens.is_cuda())
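For context, TORCH_CHECK raises a c10::Error (surfaced as a RuntimeError on the Python side) when its condition is false, so the two new checks make mha_fwd fail fast on a non-fp16 qkv or a non-int32 cu_seqlens instead of launching the kernel with unsupported dtypes. As a sketch only, using a hypothetical standalone helper name that is not part of this commit, the same validation with explicit messages could be written as:

#include <torch/extension.h>

// Hypothetical helper (not in the actual diff) mirroring the input
// validation in mha_fwd: qkv must be an fp16 CUDA tensor and
// cu_seqlens an int32 CUDA tensor.
static void check_fwd_inputs(const at::Tensor &qkv,
                             const at::Tensor &cu_seqlens) {
    // Dtype checks added by this commit, here with explanatory messages.
    TORCH_CHECK(qkv.dtype() == torch::kFloat16,
                "qkv must be a half-precision (fp16) tensor");
    TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32,
                "cu_seqlens must be an int32 tensor");
    // Device checks that already existed in mha_fwd before this commit.
    TORCH_CHECK(qkv.is_cuda(), "qkv must reside on a CUDA device");
    TORCH_CHECK(cu_seqlens.is_cuda(), "cu_seqlens must reside on a CUDA device");
}

The message arguments are optional; the commit itself uses the bare condition form, which reports the failed condition text but no custom explanation.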