Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
272b7574
Commit
272b7574
authored
Feb 21, 2023
by
danyao12
Browse files
fix drop==0 compiler issue in prototype1
parent
63c2d069
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
56 deletions
+21
-56
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_fp16.cpp
...ax_gemm/batched_multihead_attention_backward_pt1_fp16.cpp
+3
-3
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
+18
-53
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_pt1_fp16.cpp
View file @
272b7574
...
@@ -400,9 +400,9 @@ int run(int argc, char* argv[])
...
@@ -400,9 +400,9 @@ int run(int argc, char* argv[])
break
;
break
;
case
4
:
case
4
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_
Diagonal
<
DataType
>
{});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_
1
<
DataType
>
{
1
});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_
Diagonal
<
DataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_
1
<
DataType
>
{
1
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
2
});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
break
;
break
;
case
5
:
case
5
:
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
DataType
>
{
1
});
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
View file @
272b7574
...
@@ -1265,7 +1265,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
...
@@ -1265,7 +1265,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
ushort
p_dropout_in_16bits
=
uint16_t
(
std
::
floor
(
p_dropout
*
65535.0
));
const
ushort
p_dropout_in_16bits
=
uint16_t
(
std
::
floor
(
p_dropout
*
65535.0
));
const
bool
is_dropout
=
p_drop
>
0.0
f
;
const
tensor_operation
::
element_wise
::
Scale
scale_rp_dropout
(
s_element_op
.
Value
()
*
const
tensor_operation
::
element_wise
::
Scale
scale_rp_dropout
(
s_element_op
.
Value
()
*
rp_dropout
);
rp_dropout
);
...
@@ -1718,36 +1717,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
...
@@ -1718,36 +1717,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
block_work_idx
[
I0
],
I0
,
I0
/* all WGs start from o_block_idx = 0 */
,
I0
)
+
block_work_idx
[
I0
],
I0
,
I0
/* all WGs start from o_block_idx = 0 */
,
I0
)
+
y_thread_data_on_block_idx
;
y_thread_data_on_block_idx
;
// // performs for y
// auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
// DataType,
// DataType,
// YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
// decltype(y_thread_desc_m0_m1_o0_o1),
// decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
// Sequence<0, 1, 2, 3>,
// 3, // SrcVectorDim
// YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
// 1, // SrcScalarStrideInVector
// true /* ResetCoordAfterRun */,
// true /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
// y_thread_data_on_grid_idx);
// // performs for ygrad
// auto ygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
// DataType,
// DataType,
// decltype(YDotYGrad_M_O::ygrad_block_desc_m_o),
// decltype(ygrad_thread_desc_m_o),
// decltype(ygrad_thread_desc_m_o.GetLengths()),
// Sequence<0, 1>,
// 1, // SrcVectorDim
// YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
// 1, // SrcScalarStrideInVector
// true /* ResetCoordAfterRun */,
// true /* InvalidElementAsNaN */>(YDotYGrad_M_O::ygrad_block_desc_m_o,
// ygrad_thread_data_on_block_idx);
// performs for y
// performs for y
auto
y_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
auto
y_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
DataType
,
DataType
,
...
@@ -1986,29 +1955,25 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
...
@@ -1986,29 +1955,25 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
blockwise_softmax
.
RunWithPreCalcStats
(
s_slash_p_thread_buf
,
lse_thread_buf
);
blockwise_softmax
.
RunWithPreCalcStats
(
s_slash_p_thread_buf
,
lse_thread_buf
);
// save z to global
// save z to global
if
(
is_dropout
)
if
(
p_z_grid
)
{
{
if
(
p_z_grid
)
// P_dropped
{
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
// P_dropped
decltype
(
z_tenor_buffer
),
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
true
>(
decltype
(
z_tenor_buffer
),
s_slash_p_thread_buf
,
ph
,
z_tenor_buffer
);
true
>(
s_slash_p_thread_buf
,
ph
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_thread_copy_vgpr_to_global
.
Run
(
z_tenor_buffer
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_grid_buf
);
z_tenor_buffer
,
}
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
else
z_grid_buf
);
{
}
// P_dropped
else
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
true
>(
{
s_slash_p_thread_buf
,
ph
);
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
true
>(
s_slash_p_thread_buf
,
ph
);
}
}
}
block_sync_lds
();
// wait for gemm1 LDS read
block_sync_lds
();
// wait for gemm1 LDS read
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment