Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d2eed8e6
Unverified
Commit
d2eed8e6
authored
Feb 17, 2023
by
guangzlu
Committed by
GitHub
Feb 17, 2023
Browse files
Merge branch 'attn-bwd-dropout' into fwd-drop-verify2
parents
043c8ff3
e9e6081a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
5 additions
and
1 deletion
+5
-1
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
..._softmax_gemm/run_grouped_multihead_attention_forward.inc
+3
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp
...vice_grouped_multihead_attention_forward_xdl_cshuffle.hpp
+1
-0
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
...wise_batched_multihead_attention_forward_xdl_cshuffle.hpp
+1
-0
No files found.
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
View file @
d2eed8e6
...
@@ -10,7 +10,8 @@ int run(int argc, char* argv[])
...
@@ -10,7 +10,8 @@ int run(int argc, char* argv[])
bool
input_permute
=
false
;
bool
input_permute
=
false
;
bool
output_permute
=
true
;
bool
output_permute
=
true
;
float
p_drop
=
0.1
;
float
p_drop
=
0.2
;
float
p_dropout
=
1
-
p_drop
;
float
p_dropout
=
1
-
p_drop
;
uint16_t
p_dropout_in_16bits
=
uint16_t
(
std
::
floor
(
p_dropout
*
65535.0
));
uint16_t
p_dropout_in_16bits
=
uint16_t
(
std
::
floor
(
p_dropout
*
65535.0
));
float
rp_dropout
=
1.0
/
p_dropout
;
float
rp_dropout
=
1.0
/
p_dropout
;
...
@@ -251,6 +252,7 @@ int run(int argc, char* argv[])
...
@@ -251,6 +252,7 @@ int run(int argc, char* argv[])
{
seed
,
offset
});
// dropout random seed and offset, offset should be
{
seed
,
offset
});
// dropout random seed and offset, offset should be
// at least the number of elements on a thread
// at least the number of elements on a thread
// specify workspace for problem_desc
// specify workspace for problem_desc
DeviceMem
problem_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
DeviceMem
problem_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
&
argument
));
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_forward_xdl_cshuffle.hpp
View file @
d2eed8e6
...
@@ -623,6 +623,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -623,6 +623,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
const
auto
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
const
auto
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
z_grid_desc_m_n
);
z_grid_desc_m_n
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
View file @
d2eed8e6
...
@@ -1019,6 +1019,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1019,6 +1019,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
}
}
else
else
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment