gaoqiong / composable_kernel
Commit 63c2d069, authored Feb 21, 2023 by danyao12

Merge branch 'attn-bwd-develop' into attn-bwd-dropout-pt1

Parents: a5bad9f2, 82ce7f4e
Showing 4 changed files with 24 additions and 29 deletions.
- example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp (+3, -3)
- include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp (+2, -2)
- include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp (+18, -23)
- library/include/ck/library/reference_tensor_operation/cpu/reference_dropout.hpp (+1, -1)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp

@@ -401,9 +401,9 @@ int run(int argc, char* argv[])
         break;
     case 4:
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
-        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{2});
+        k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
+        v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
+        ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
         break;
     case 5:
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
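This case-4 change replaces the diagonal K/V inputs and the constant-2 output gradient with all-ones tensors, so the intermediates of the attention backward pass are easy to check by hand. A minimal sketch of what the two generators plausibly compute (assumed semantics for illustration, not the CK source):

#include <cstddef>

// Fills every element with the same constant, e.g. GeneratorTensor_1<float>{1}.
template <typename T>
struct GeneratorTensor_1
{
    T value{1};

    template <typename... Is>
    T operator()(Is...) const
    {
        return value;
    }
};

// 1 where all coordinates coincide, 0 elsewhere (an identity-like pattern).
template <typename T>
struct GeneratorTensor_Diagonal
{
    template <typename... Is>
    T operator()(std::size_t i0, Is... is) const
    {
        const bool on_diagonal = ((static_cast<std::size_t>(is) == i0) && ...);
        return on_diagonal ? T{1} : T{0};
    }
};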
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp

@@ -44,7 +44,7 @@ struct BlockwiseDropout
            static_for<0, KRepeat, 1>{}([&](auto iK) {
                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
                in_thread_buf(offset) =
-                    execute_dropout(tmp[tmp_index] < p_dropout_16bits, in_thread_buf(offset));
+                    execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
                tmp_index = tmp_index + 1;
            });
        });
@@ -79,7 +79,7 @@ struct BlockwiseDropout
            static_for<0, KRepeat, 1>{}([&](auto iK) {
                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
                in_thread_buf(offset) =
-                    execute_dropout(tmp[tmp_index] < p_dropout_16bits, in_thread_buf(offset));
+                    execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
                z_thread_buf(offset) = tmp[tmp_index];
                tmp_index = tmp_index + 1;
            });
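Both hunks make the same one-character fix: the 16-bit random draw tmp[tmp_index] is now compared with <= rather than < against the keep threshold p_dropout_16bits, matching the CPU reference in reference_dropout.hpp below. A minimal sketch of the assumed semantics (the names here are illustrative stand-ins, not the CK API):

#include <cstdint>

struct DropoutSketch
{
    uint16_t threshold; // keep threshold, roughly floor(keep_prob * 65535)
    float    rescale;   // 1 / keep_prob (inverted dropout)

    // Mirrors the execute_dropout(pred, x) pattern above: the predicate
    // selects between the rescaled value and zero.
    float execute_dropout(bool keep, float x) const
    {
        return keep ? x * rescale : 0.0f;
    }

    float apply(uint16_t draw, float x) const
    {
        // '<=' is the change this commit makes: a draw equal to the
        // threshold now counts as "keep" instead of "drop".
        return execute_dropout(draw <= threshold, x);
    }
};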
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp

@@ -1191,7 +1191,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
        const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
        const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
        const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
-        const bool is_dropout = p_drop > 0.0f;
        const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
                                                                     rp_dropout);
@@ -1866,8 +1865,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
            blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);

            // save z to global
-            if(is_dropout)
-            {
            if(p_z_grid)
            {
                // P_dropped
@@ -1876,8 +1873,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                    true>(
                    s_slash_p_thread_buf, ph, z_tenor_buffer);
                z_thread_copy_vgpr_to_global.Run(
                    z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                    make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                    z_tenor_buffer,
                    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
@@ -1889,7 +1885,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
                    s_slash_p_thread_buf, ph);
            }
-            }

            block_sync_lds(); // wait for gemm1 LDS read
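The first hunk deletes the now-unused is_dropout flag (the z-tensor write is gated on p_z_grid alone after this change); the dropout scalars themselves are unchanged. For concreteness, the arithmetic with an assumed drop probability of 0.2 (an illustrative value, not one taken from this commit):

#include <cassert>
#include <cmath>
#include <cstdint>

int main()
{
    // Scalars from the first hunk, evaluated for p_drop = 0.2f.
    const float p_drop     = 0.2f;
    const float p_dropout  = 1.0f - p_drop;    // keep probability, ~0.8f
    const float rp_dropout = 1.0f / p_dropout; // rescale factor, ~1.25f

    // 16-bit keep threshold: floor(0.8 * 65535.0) == 52428.
    const uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
    assert(p_dropout_in_16bits == 52428);

    // With the '<=' comparison, 52429 of the 65536 possible draws are kept,
    // i.e. keep probability ~0.80, and survivors are scaled by ~1.25 so the
    // expected value of each element is preserved (inverted dropout).
    return 0;
}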
library/include/ck/library/reference_tensor_operation/cpu/reference_dropout.hpp

@@ -48,7 +48,7 @@ struct ReferenceDropout : public device::BaseOperator
        {
            arg.out_.ForEach([&](auto& self, auto idx) {
-                self(idx) = arg.ref_(idx) < arg.p_dropout_in_16bits_
-                                ? arg.in_(idx) * arg.rp_dropout_
-                                : 0;
+                self(idx) = arg.ref_(idx) <= arg.p_dropout_in_16bits_
+                                ? arg.in_(idx) * arg.rp_dropout_
+                                : 0;
            });
            return 0;
        }
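The reference implementation gets the matching <=, so CPU verification agrees with the blockwise GPU path on draws that land exactly on the threshold. A standalone restatement of the semantics above (the free-function packaging is hypothetical; CK wraps this logic in the operator's Argument and Invoker):

#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> reference_dropout(const std::vector<float>& in,
                                     const std::vector<uint16_t>& ref, // per-element 16-bit draws
                                     uint16_t p_dropout_in_16bits,
                                     float rp_dropout)
{
    std::vector<float> out(in.size());
    for(std::size_t i = 0; i < in.size(); ++i)
    {
        // Keep and rescale when the draw does not exceed the threshold;
        // '<=' here matches the GPU path after this commit.
        out[i] = ref[i] <= p_dropout_in_16bits ? in[i] * rp_dropout : 0.0f;
    }
    return out;
}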