gaoqiong / composable_kernel · Commits

Commit c26b46de, authored Dec 13, 2022 by Anthony Chang

    format

parent 15f1d4ad

Showing 3 changed files with 21 additions and 44 deletions (+21 -44)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp  (+0 -1)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp  (+4 -5)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp  (+17 -38)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp
@@ -45,7 +45,6 @@ Kernel outputs:

#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
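Note: the S alias in this hunk is just shorthand for a compile-time index sequence. A minimal usage sketch, assuming a translation unit that already includes ck/utility/sequence.hpp and <type_traits>; the alias name ThreadClusterLengths_Example is illustrative, not from the diff:

    // S<Is...> expands to ck::Sequence<Is...>, so compile-time tile shapes can be written compactly.
    using ThreadClusterLengths_Example = S<4, 64, 1>;
    static_assert(std::is_same_v<ThreadClusterLengths_Example, ck::Sequence<4, 64, 1>>, "");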
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
@@ -384,7 +384,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle

        return transform_tensor_descriptor(
            ygrad_grid_desc_m_o,
            make_tuple(make_unmerge_transform(make_tuple(Y_M0, Y_M1)),
                       make_pass_through_transform(O)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }
@@ -456,7 +457,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle

    //
    // static auto MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock()
    //
    // dQ = alpha * dS * K
    //
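Note: the "dQ = alpha * dS * K" comment above is the backward relation this device op computes for the query gradient. A minimal host-side reference sketch of that contraction, assuming row-major storage and shapes dS: M x N, K: N x KDim, dQ: M x KDim (names and layout are illustrative, not the device kernel's):

    // Reference only: dQ[m][k] = alpha * sum_n dS[m][n] * K[n][k]
    template <typename T>
    void reference_dq(const T* ds, const T* k, T* dq, int M, int N, int KDim, float alpha)
    {
        for(int m = 0; m < M; ++m)
            for(int kd = 0; kd < KDim; ++kd)
            {
                float acc = 0.f;
                for(int n = 0; n < N; ++n)
                    acc += static_cast<float>(ds[m * N + n]) * static_cast<float>(k[n * KDim + kd]);
                dq[m * KDim + kd] = static_cast<T>(alpha * acc);
            }
    }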
@@ -514,7 +514,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -514,7 +514,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
VGradGridDesc_N_O
=
decltype
(
MakeVGradGridDescriptor_N_O
({},
{}));
using
VGradGridDesc_N_O
=
decltype
(
MakeVGradGridDescriptor_N_O
({},
{}));
using
YGradGridDesc_M0_O_M1
=
decltype
(
MakeYGradGridDescriptor_M0_O_M1
(
YGridDesc_M_O
{}));
using
YGradGridDesc_M0_O_M1
=
decltype
(
MakeYGradGridDescriptor_M0_O_M1
(
YGridDesc_M_O
{}));
constexpr
static
auto
make_MaskOutPredicate
()
constexpr
static
auto
make_MaskOutPredicate
()
...
@@ -700,8 +700,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle

              vgrad_grid_desc_n_o_{DeviceOp::MakeVGradGridDescriptor_N_O(
                  b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
              /* PTrans descriptor will be constructed in kernel */
              ygrad_grid_desc_m0_o_m1_{DeviceOp::MakeYGradGridDescriptor_M0_O_M1(y_grid_desc_m_o_)},
              // batch offsets
              a_grid_desc_g_m_k_{Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths,
                                                                      a_gs_ms_ks_strides)},
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
@@ -214,43 +214,22 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle

    // PGrad Gemm
    struct PGradGemmTile_M_N_O_
    {
    };

    template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
    struct YDotYGrad_M_O_
    {
        static constexpr index_t SrcScalarPerVetor = 16 / sizeof(DataType);
        static constexpr auto ThreadClusterLength_O = Number<BlockSliceLength_O_ / SrcScalarPerVetor>{};
        static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
        static constexpr auto ThreadSliceLength_O   = Number<SrcScalarPerVetor>{};
        static constexpr auto ThreadSliceLength_M   =
            Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};

        // static_assert(BlockSliceLength_O_ % SrcScalarPerVetor == 0, "");
        // static_assert(BlockSize_ % ThreadClusterLength_O == 0, "");
        static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
        static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");

        using BlockwiseSumReduce = PartitionedBlockwiseReduction<FloatGemmAcc,
                                                                 BlockSize_,
                                                                 Sequence<ThreadClusterLength_M, ThreadClusterLength_O>,
                                                                 Sequence<0, 1>,
                                                                 reduce::Add,
                                                                 false>; // propagateNaN

        // using ThreadReduceSrcDesc_M_O = decltype(make_naive_tensor_descriptor_packed(
        //     make_tuple(ThreadSliceLength_M, ThreadSliceLength_O)));
        // using ThreadReduceDstDesc_M =
        //     decltype(make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceLength_M)));
        // using ThreadwiseSumReduce =
        //     ThreadwiseReduction<FloatGemmAcc,
        //                         ThreadReduceSrcDesc_M_O,
        //                         ThreadReduceDstDesc_M,
        //                         reduce::Add,
        //                         false>; // propagateNaN

        using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
                                        DataType,
                                        ThreadSliceLength_M * ThreadSliceLength_O,
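Note: to make the YDotYGrad_M_O_ tile arithmetic concrete, a worked example with assumed parameters BlockSize_ = 256, BlockSliceLength_M_ = BlockSliceLength_O_ = 128 and a 2-byte DataType (fp16); these numbers are illustrative, not a tuning taken from the kernel instances:

    constexpr int SrcScalarPerVector_ex    = 16 / 2;                                // 8 elements per 16-byte read
    constexpr int ThreadClusterLength_O_ex = 128 / SrcScalarPerVector_ex;           // 16
    constexpr int ThreadClusterLength_M_ex = 256 / ThreadClusterLength_O_ex;        // 16
    constexpr int ThreadSliceLength_O_ex   = SrcScalarPerVector_ex;                 // 8
    constexpr int ThreadSliceLength_M_ex   = 128 * ThreadClusterLength_O_ex / 256;  // 8
    static_assert(ThreadClusterLength_O_ex * ThreadSliceLength_O_ex == 128, "covers the O slice");
    static_assert(ThreadClusterLength_M_ex * ThreadSliceLength_M_ex == 128, "covers the M slice");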
@@ -1271,11 +1250,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle

        //
        // dP
        //
        constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(make_tuple(
            I1, I1, YDotYGrad_M_O::ThreadSliceLength_M, I1, YDotYGrad_M_O::ThreadSliceLength_O));

        constexpr auto y_thread_cluster_desc = make_cluster_descriptor(
            Sequence<I1,
                     YDotYGrad_M_O::ThreadClusterLength_M,
@@ -1316,10 +1292,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle

        constexpr auto y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl =
            make_naive_tensor_descriptor(make_tuple(I1, P_M0, P_M1, P_M2),
                                         make_tuple(P_M0 * P_M1 * P_M2, P_M1 * P_M2, P_M2, I1));

        // y_dot_ygrad thread buffer for calculating sgrad; reuse LSE thread descriptor because
        // per-thread LSE data and y_dot_ygrad is tiled the same way
        constexpr auto y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl =
            lse_thread_desc_mblock_mrepeat_mwave_mperxdl;

        // TODO ANT: dP Gemm can reuse first blockwise gemm and pipeline
        auto y_dot_ygrad_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
            FloatGemmAcc,
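Note: the block descriptor above pairs lengths (I1, P_M0, P_M1, P_M2) with packed strides, so the linear offset it produces is plain row-major arithmetic. A minimal sketch of that mapping (a free function written for illustration, not the CK descriptor machinery; the first index is always 0 because the mblock length is 1):

    constexpr int linear_offset_ex(int i1, int i2, int i3, int P_M1, int P_M2)
    {
        // stride(i1) = P_M1 * P_M2, stride(i2) = P_M2, stride(i3) = 1
        return i1 * (P_M1 * P_M2) + i2 * P_M2 + i3;
    }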
@@ -1404,7 +1382,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle

        block_sync_lds();

        static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
            const auto idx_on_block = y_thread_data_on_block_idx[I1] + iM;
            y_dot_ygrad_block_accum_buf.AtomicAdd(
                idx_on_block, true, y_dot_ygrad_thread_accum_buf[iM]);
        });
        block_sync_lds();
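Note: the static_for/AtomicAdd loop above folds each thread's partial Y * YGrad products into a per-row LDS accumulation buffer; the quantity being built is the row-wise dot product later used for the SGrad computation mentioned in the comments. A serial host-side sketch of that reduction (buffer names and the single-threaded loop are illustrative, not the blockwise reduction):

    // Reference only: d[m] = sum_o y[m][o] * ygrad[m][o], assuming row-major M x O buffers.
    template <typename T>
    void reference_y_dot_ygrad(const T* y, const T* ygrad, float* d, int M, int O)
    {
        for(int m = 0; m < M; ++m)
        {
            float acc = 0.f;
            for(int o = 0; o < O; ++o)
                acc += static_cast<float>(y[m * O + o]) * static_cast<float>(ygrad[m * O + o]);
            d[m] = acc; // the kernel accumulates this per row with AtomicAdd into LDS
        }
    }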
@@ -1416,7 +1395,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle

        }
#endif

        // distribute y_dot_ygrad to threads; LDS accum buffer can be safely accessed after barrier
        y_dot_ygrad_thread_copy_lds_to_vgpr.Run(
            y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
            y_dot_ygrad_block_accum_buf,
@@ -1667,9 +1646,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle

        } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop

        // TODO ANT:
        // shuffle dQ and write
#if 0
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
@@ -1865,7 +1844,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle

                }
            });
        }
#endif
    }
};