gaoqiong / composable_kernel · Commit 4ddda63b
Authored Feb 16, 2023 by aska-0096 (parent: 74f0d5de)

Commit message: sanity check pass
Showing 6 changed files with 109 additions and 112 deletions.
  example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp  (+7, -7)
  example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc        (+1, -7)
  include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp                                     (+0, -14)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp         (+63, -77)
  include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp                       (+12, -7)
  include/ck/utility/data_type.hpp                                                                  (+26, -0)
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp

@@ -108,26 +108,26 @@ using DeviceGemmInstance =
          1,              // MRepeat
          8,              // LRepeat
          4,              // NRepeat
-         S<4, 64, 1>,    // ABlockTransfer
+         S<4, 64, 1>,    // ABlockTransfer MK -> K0 M K1
          S<1, 0, 2>,
          S<1, 0, 2>,
          2,
          8,
          8,
          true,
-         S<4, 64, 1>,    // B0BlockTransfer
+         S<4, 64, 1>,    // B0BlockTransfer LK -> K0 L K1
          S<1, 0, 2>,
          S<1, 0, 2>,
          2,
          8,
          8,
          true,
-         S<4, 64, 1>,    // B1BlockTransfer
-         S<1, 0, 2>,
-         S<1, 0, 2>,
-         2,
-         8,
+         S<4, 8, 8>,     // B1BlockTransfer LN -> L0 N L1
+         S<0, 2, 1>,
+         S<0, 2, 1>,
+         1,
+         8,
          1,
          false,
          1,              // CShuffleMWmmaPerWavePerShuffle
          2,              // CShuffleNWmmaPerWavePerShuffle
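The only functional change in this file is the B1BlockTransfer configuration: the thread cluster moves from S<4, 64, 1> to S<4, 8, 8>, and the arrange/access orders from S<1, 0, 2> to S<0, 2, 1>. Both cluster shapes cover the same thread count, 4 * 64 * 1 = 4 * 8 * 8 = 256. A minimal standalone sketch of that invariant (the S-like alias and the 256-thread block size are assumptions based on this example, not taken from the commit):

    // Hypothetical check: a thread-cluster shape S<l0, l1, l2> must tile the
    // whole thread block, so its element product should equal the block size.
    #include <cstddef>

    template <std::size_t... Ls>
    struct S
    {
        static constexpr std::size_t product = (Ls * ...); // C++17 fold
    };

    static_assert(S<4, 64, 1>::product == 256, "old B1 cluster covers the block");
    static_assert(S<4, 8, 8>::product == 256, "new B1 cluster covers the block");

    int main() { return 0; }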
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc

@@ -127,7 +127,7 @@ int run(int argc, char* argv[])
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
         break;
-    case 6: // Rand: b0 ; unit: a b1 fail
+    case 6: // Rand: b0 ; unit: a b1 pass
         a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});

@@ -240,12 +240,6 @@ int run(int argc, char* argv[])
             a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);

         ref_gemm0_invoker.Run(ref_gemm0_argument);
-        // for(int i =0; i< 128; i++){
-        //     for(int j =0; j< 128; j++){
-        //         printf("%0.2lf ", acc0_g_m_n.mData[i*128 +j]);
-        //     }
-        //     printf("\n");
-        // }

         // masking
         const auto mask = DeviceGemmInstance::C0MatrixMask(N);
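The case preceding case 6 randomizes B1 against a unit-valued B0; case 6, whose comment flips from "fail" to "pass" in this commit, randomizes B0 against unit A and B1. If GeneratorTensor_1 fills a tensor with a constant (1 by default) and GeneratorTensor_2{min, max} fills it with random integers in [min, max), as these helpers behave elsewhere in CK's example utilities, the inputs look like the following host-side sketch (the generator semantics are an assumption here, not shown in this diff):

    // Illustrative stand-ins for CK's GeneratorTensor_1 / GeneratorTensor_2.
    #include <cstdlib>
    #include <vector>

    struct GeneratorTensor_1
    {
        float value = 1.f; // fill with a constant, 1 by default
        float operator()() const { return value; }
    };

    struct GeneratorTensor_2
    {
        int min_value; // inclusive
        int max_value; // exclusive
        float operator()() const
        {
            return static_cast<float>(min_value + std::rand() % (max_value - min_value));
        }
    };

    template <typename Gen>
    void generate(std::vector<float>& tensor, Gen gen)
    {
        for(auto& x : tensor)
            x = gen();
    }

    int main()
    {
        std::vector<float> b0(128 * 128);
        generate(b0, GeneratorTensor_2{-2, 2}); // random B0 in {-2, -1, 0, 1}
        std::vector<float> a(128 * 128);
        generate(a, GeneratorTensor_1{});       // unit A
        return 0;
    }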
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

@@ -171,8 +171,6 @@ struct BlockwiseGemmWMMA
        static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
                          NPerBlock % (NPerWMMA * NRepeat) == 0,
                      "wrong!");
-       // printf("tid %03d, Mat-B offset %d\n", get_thread_local_1d_id()%32,
-       //        CalculateBThreadOriginDataIndex().At(Number<3>{}));
    }

    // transposed WMMA output C' = B' * A'

@@ -301,9 +299,6 @@ struct BlockwiseGemmWMMA
                    a_thread_desc_,
                    make_tuple(I0, m0, I0, I0, I0),
                    a_thread_buf);
-               // static_for<0, a_thread_buf.size(), 1>{}([&](auto i) {
-               //     a_thread_buf(i) = 1;
-               // });

                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    // read B

@@ -323,9 +318,6 @@ struct BlockwiseGemmWMMA
                        b_thread_vec.template AsType<FloatB>()(i) =
                            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
-                       // a_thread_vec.template AsType<FloatA>()(i) = 1;
-                       // b_thread_vec.template AsType<FloatB>()(i) = 1;
                    });

                    using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;

@@ -334,12 +326,6 @@ struct BlockwiseGemmWMMA
                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
-                   // printf("GPU Gemm0 input, Tid %03d, A%2d = %04x, B%2d = %0x4\n",
-                   //        get_thread_local_1d_id(),
-                   //        i.value, *(reinterpret_cast<uint16_t*>(&a_thread_vec.template AsType<FloatA>()(i))),
-                   //        i.value, *(reinterpret_cast<uint16_t*>(&b_thread_vec.template AsType<FloatB>()(i))));

                    wmma_gemm.template Run(
                        a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
                        b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
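This file's hunks only delete debug scaffolding; the surviving static_assert encodes the tiling constraint that a block tile must be an integer number of WMMA tiles per repeat in each direction. A compile-time illustration with made-up extents (16x16 WMMA tiles are typical, but none of these values come from this commit):

    // Hypothetical extents, chosen only to show the divisibility constraint.
    constexpr int MPerBlock = 128, MPerWMMA = 16, MRepeat = 1;
    constexpr int NPerBlock = 128, NPerWMMA = 16, NRepeat = 4;

    static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
                      NPerBlock % (NPerWMMA * NRepeat) == 0,
                  "block tile must be a whole number of WMMA tiles per repeat");

    int main() { return 0; }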
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
@@ -658,16 +658,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        constexpr auto t_lwave =
            thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I4);
        constexpr auto t_lsubgroup =
            thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I5);
        constexpr auto t_laccvgprs =
            thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I6);
-       if(get_thread_local_1d_id() == 0){
-           printf("t_mrepeat %d, t_mwave %d, t_mthreadpersubgroup %d, t_lrepeat %d, "
-                  "t_lwave %d, t_lsubgroup %d, t_laccvgprs %d\n",
-                  t_mrepeat.value, t_mwave.value, t_mthreadpersubgroup.value,
-                  t_lrepeat.value, t_lwave.value, t_lsubgroup.value, t_laccvgprs.value);
-       }

        // get acc0 thread map
        constexpr auto m0_l_m1_to_m_l_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(t_mrepeat * t_mwave, t_mthreadpersubgroup)),
@@ -744,27 +734,27 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        // B1 matrix blockwise copy
-       auto b1_blockwise_copy =
-           ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-                                               B1ElementwiseOperation,
-                                               tensor_operation::element_wise::PassThrough,
-                                               InMemoryDataOperationEnum::Set,
-                                               Sequence<BL0, NPerBlock, BL1>,
-                                               B1BlockTransferThreadClusterLengths_L0_N_L1,
-                                               B1BlockTransferThreadClusterArrangeOrder,
-                                               FloatB1,
-                                               FloatB1,
-                                               decltype(b1_grid_desc_l0_n_l1),
-                                               decltype(b1_block_desc_l0perblock_nperblock_l1),
-                                               B1BlockTransferSrcAccessOrder,
-                                               Sequence<1, 0, 2>,
-                                               B1BlockTransferSrcVectorDim,
-                                               2,
-                                               B1BlockTransferSrcScalarPerVector,
-                                               B1BlockTransferDstScalarPerVector_L1,
-                                               1,
-                                               1,
-                                               B1ThreadTransferSrcResetCoordinateAfterRun,
-                                               true, // DstResetCoord
+       auto b1_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
+           ThisThreadBlock,
+           /* typename SrcElementwiseOperation,              */ B1ElementwiseOperation,
+           /* typename DstElementwiseOperation,              */ tensor_operation::element_wise::PassThrough,
+           /* InMemoryDataOperationEnum DstInMemOp,          */ InMemoryDataOperationEnum::Set,
+           /* typename BlockSliceLengths,                    */ Sequence<BL0, NPerBlock, BL1>,
+           /* typename ThreadClusterLengths,                 */ B1BlockTransferThreadClusterLengths_L0_N_L1,
+           /* typename ThreadClusterArrangeOrder,            */ B1BlockTransferThreadClusterArrangeOrder,
+           /* typename SrcData,                              */ FloatB1,
+           /* typename DstData,                              */ FloatB1,
+           /* typename SrcDesc,                              */ decltype(b1_grid_desc_l0_n_l1),
+           /* typename DstDesc,                              */ decltype(b1_block_desc_l0perblock_nperblock_l1),
+           /* typename SrcDimAccessOrder,                    */ B1BlockTransferSrcAccessOrder,
+           /* typename DstDimAccessOrder,                    */ Sequence<1, 0, 2>,
+           /* index_t SrcVectorDim,                          */ B1BlockTransferSrcVectorDim,
+           /* index_t DstVectorDim,                          */ 2,
+           /* index_t SrcScalarPerVector,                    */ B1BlockTransferSrcScalarPerVector,
+           /* index_t DstScalarPerVector,                    */ B1BlockTransferDstScalarPerVector_L1,
+           /* index_t SrcScalarStrideInVector,               */ 1,
+           /* index_t DstScalarStrideInVector,               */ 1,
+           /* bool ThreadTransferSrcResetCoordinateAfterRun, */ B1ThreadTransferSrcResetCoordinateAfterRun,
+           /* bool ThreadTransferDstResetCoordinateAfterRun, */ true, // DstResetCoord
            NumGemmKPrefetchStage>(
                b1_grid_desc_l0_n_l1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
@@ -793,7 +783,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            NPerWmma,
            MRepeat,
            NRepeat,
-           KPack>{make_tuple(0, 0, 0, 0, 0)};
+           KPack,
+           true>{make_tuple(0, 0, 0, 0, 0)};

        auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
@@ -899,12 +890,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            // softmax
            SoftmaxBuf& max = blockwise_softmax.max_value_buf;
            SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
-           // printf("GPU Gemm 0, Tid %03d, GPU acc0 = %lf\n", get_thread_local_1d_id(), acc0_thread_buf[I0]);
-           // static_for<0, acc0_thread_buf.Size(), 1>{}([&](auto i) {
-           //     printf("GPU Gemm0, Tid %03d, GPU acc%d = %lf\n", get_thread_local_1d_id(), i.value, acc0_thread_buf[i]);
-           // });
            blockwise_softmax.Run(acc0_thread_buf, workspace_buf);
-           // printf("GPU SoftMax, Tid %03d, GPU acc0 = %lf\n", get_thread_local_1d_id(), acc0_thread_buf[I0]);

            // TODO: may convert to log domain
            running_max_new = mathext::max(max, running_max);
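The logic surviving around the deleted prints is the streaming (online) softmax update: blockwise_softmax.Run produces a block-local max and sum, and running_max_new = mathext::max(max, running_max) folds the block's max into the running row max so earlier partial results can be rescaled. A scalar sketch of that update step, using illustrative names rather than CK's API:

    #include <algorithm>
    #include <cmath>

    // One streaming-softmax step for a single row: fold a new block's local
    // max/sum into the running statistics and rescale the running accumulator.
    void online_softmax_step(float local_max, float local_sum,
                             float& running_max, float& running_sum, float& acc)
    {
        const float new_max   = std::max(local_max, running_max);
        const float old_scale = std::exp(running_max - new_max); // rescales history
        const float new_scale = std::exp(local_max - new_max);   // rescales new block
        running_sum = running_sum * old_scale + local_sum * new_scale;
        acc         = acc * old_scale; // the new block's contribution is added separately
        running_max = new_max;
    }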
@@ -949,8 +936,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            block_sync_lds();
-           // printf("GPU permute lanex, Tid %03d, GPU 0 = %04x\n", get_thread_local_1d_id(),
-           //        *(reinterpret_cast<const uint16_t*>(&a1_thread_buf[I0])));

            blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);

            block_sync_lds();
@@ -1022,18 +1007,18 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        /*******************************************************************************/
        // write out to C, implement shuffle
        {
-           constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
-               blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
+           constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
+               blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();

            // This API Provide All dimension (size) you need
-           constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
-               blockwise_gemm1.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
+           constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp =
+               blockwise_gemm1.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();

-           constexpr auto MWave              = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1);
-           constexpr auto MSubGroup          = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2);
-           constexpr auto NWave              = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4);
-           constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5);
-           constexpr auto MAccVgprs          = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6);
+           constexpr auto MWave              = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1);
+           constexpr auto MThreadPerSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2);
+           constexpr auto NWave              = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4);
+           constexpr auto NSubGroup          = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5);
+           constexpr auto NAccVgprs          = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6);

            // LDS descriptor, shuffle and write out in MRepeat x NRepeat times
            constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
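The bulk of this hunk and those that follow is a systematic rename: the accumulator-VGPR dimension of the C descriptor moves from the M side (..._NThreadPerSubGroup_MAccVgprs) to the N side (..._NSubGroup_NAccVgprs), matching the "transposed WMMA output C' = B' * A'" comment in blockwise_gemm_wmma.hpp. A sketch of the two seven-dimension orderings, and why the unmerge maps in the next hunk change from {0,1,2,6}/{3,4,5} to {0,1,2}/{3,4,5,6} (dimension indices only; nothing here is CK code):

    // Old layout: accumulator VGPRs ride on the M side of the descriptor.
    enum OldDims { MRepeat_, MWave_, MSubGroup_, NRepeat_, NWave_, NThreadPerSubGroup_, MAccVgprs_ };
    // New layout: accumulator VGPRs ride on the N side.
    enum NewDims { MRepeat, MWave, MThreadPerSubGroup, NRepeat, NWave, NSubGroup, NAccVgprs };

    // In the old layout the M output gathers dims {0, 1, 2, 6} and N gathers
    // {3, 4, 5}; in the new layout M gathers {0, 1, 2} and N gathers {3, 4, 5, 6}.
    static_assert(MAccVgprs_ == 6 && NAccVgprs == 6, "acc dim is last in both layouts");

    int main() { return 0; }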
@@ -1043,22 +1028,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                static_cast<FloatCShuffle*>(p_shared),
                c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());

-           constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
+           constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
                transform_tensor_descriptor(
                    c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
                    make_tuple(
                        make_freeze_transform(I0),
                        make_unmerge_transform(make_tuple(
                            Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
                            MWave,                               // MWave
-                           MSubGroup,                           // MSubGroup * MAccVgprs = MPerWmma
-                           MAccVgprs)),
+                           MThreadPerSubGroup)),                // MThreadPerSubGroup = MPerWmma
                        make_freeze_transform(I0),
                        make_unmerge_transform(make_tuple(
                            Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
                            NWave,                               // NWave
-                           NThreadPerSubGroup))),               // NThreadPerSubGroup = NPerWmma
+                           NSubGroup,                           // NSubGroup * NAccVgprs = NPerWmma
+                           NAccVgprs))),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-                   make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{}));
+                   make_tuple(Sequence<>{}, Sequence<0, 1, 2>{}, Sequence<>{}, Sequence<3, 4, 5, 6>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
@@ -1067,30 +1053,30 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

-           const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
+           const auto m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor =
                make_single_stage_tensor_adaptor(
-                   make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
-                   make_tuple(Sequence<0, 1, 2, 3>{}),
+                   make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MThreadPerSubGroup))),
+                   make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

-           const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
+           const auto n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor =
                make_single_stage_tensor_adaptor(
-                   make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
-                   make_tuple(Sequence<0, 1, 2>{}),
+                   make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NSubGroup, NAccVgprs))),
+                   make_tuple(Sequence<0, 1, 2, 3>{}),
                    make_tuple(Sequence<0>{}));

-           const auto m_thread_data_on_block_idx =
-               m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
+           const auto m_thread_data_on_block_idx =
+               m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

-           const auto n_thread_data_on_block_idx =
-               n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
+           const auto n_thread_data_on_block_idx =
+               n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<
                    FloatAcc1,
                    FloatCShuffle,
-                   decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
-                   decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
+                   decltype(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
+                   decltype(c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
                    ck::tensor_operation::element_wise::PassThrough,
                    Sequence<CShuffleMRepeatPerShuffle,
                             I1,
@@ -1098,21 +1084,21 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                             CShuffleNRepeatPerShuffle,
                             I1,
                             I1,
-                            MAccVgprs>,
+                            NAccVgprs>,
                    Sequence<0, 1, 2, 3, 4, 5, 6>,
                    6,
-                   1, // vector write pixel
+                   8, // vector write pixel
                    InMemoryDataOperationEnum::Set,
                    1,
                    true>
-                   {c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
+                   {c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
                    make_multi_index(0,
                                     m_thread_data_on_block_idx[I1],
                                     m_thread_data_on_block_idx[I2],
                                     0,
                                     n_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I2],
-                                    m_thread_data_on_block_idx[I3]),
+                                    n_thread_data_on_block_idx[I3]),
                    ck::tensor_operation::element_wise::PassThrough{}};

            // shuffle: blockwise copy C from LDS to global
@@ -1144,7 +1130,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            // space filling curve for local reg & global memory
            // space filling curve for threadwise C in VGPR
-           constexpr auto sfc_c_vgpr = SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, MAccVgprs>,
+           constexpr auto sfc_c_vgpr = SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, NAccVgprs>,
                                                          Sequence<0, 1, 2, 3, 4, 5, 6>,
                                                          Sequence<CShuffleMRepeatPerShuffle,
                                                                   1,
@@ -1152,7 +1138,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                                                                   CShuffleNRepeatPerShuffle,
                                                                   1,
                                                                   1,
-                                                                  MAccVgprs>>{};
+                                                                  NAccVgprs>>{};

            // space filling curve for shuffled blockwise C in global mem
            constexpr auto sfc_c_global =
@@ -1172,10 +1158,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                block_sync_lds();

                // each thread write its data from VGPR to LDS
-               c_thread_copy_vgpr_to_lds.Run(
-                   c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
+               c_thread_copy_vgpr_to_lds.Run(
+                   c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
                    sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                    c_thread_buf,
-                   c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
+                   c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
                    c_shuffle_block_buf);

                // make sure it's safe to read from LDS
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
@@ -1394,19 +1394,24 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
                // apply element-wise operation
                element_op_(v, src_buf[Number<src_offset>{}]);

-               if(get_thread_local_1d_id() % 32 > 16){
+               if(get_thread_local_1d_id() % 32 < 16){
                    // apply type convert
                    dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
-                   dst_buf(Number<dst_offset + DstScalarPerVector>{}) = __builtin_amdgcn_permlanex16(
-                       type_convert<DstData>(dst_buf(Number<dst_offset + DstScalarPerVector>{})),
-                       type_convert<DstData>(v), LowEightRowlaneIdx, HighEightRowLaneIdx, 1, 0);
                }
                else{
                    // apply type convert
                    dst_buf(Number<dst_offset + DstScalarPerVector>{}) = type_convert<DstData>(v);
-                   dst_buf(Number<dst_offset>{}) = __builtin_amdgcn_permlanex16(
-                       type_convert<DstData>(dst_buf(Number<dst_offset>{})),
-                       type_convert<DstData>(v), LowEightRowlaneIdx, HighEightRowLaneIdx, 1, 0);
                }
+
+               SrcData d = 0;
+               int temp  = 0;
+               temp = __builtin_amdgcn_permlanex16(
+                   temp, type_convert<int>(v), LowEightRowlaneIdx, HighEightRowLaneIdx, 1, 0);
+               d = type_convert<float>(temp);
+               if(get_thread_local_1d_id() % 32 < 16){
+                   dst_buf(Number<dst_offset + DstScalarPerVector>{}) = type_convert<DstData>(d);
+               }
+               else{
+                   dst_buf(Number<dst_offset>{}) = type_convert<DstData>(d);
+               }
            });
        });
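__builtin_amdgcn_permlanex16 is an AMD cross-lane intrinsic that exchanges values between the two 16-lane halves of a 32-lane wave; the rewrite routes v through an int temp because the builtin operates on 32-bit integer lanes. A host-side emulation of the exchange pattern the new code depends on (illustrative only: it models the end result, not the intrinsic's row/bank select semantics):

    #include <array>
    #include <cstdio>

    // Host-side emulation of the half-wave exchange used above: each lane in
    // the lower 16 lanes receives the value held by its partner lane in the
    // upper 16 lanes, and vice versa. (The real intrinsic selects lanes via
    // the LowEightRowlaneIdx / HighEightRowLaneIdx masks.)
    constexpr int wave_size = 32;

    std::array<float, wave_size> permlanex16_like(const std::array<float, wave_size>& v)
    {
        std::array<float, wave_size> out{};
        for(int lane = 0; lane < wave_size; ++lane)
            out[lane] = v[(lane + 16) % wave_size];
        return out;
    }

    int main()
    {
        std::array<float, wave_size> v{};
        for(int lane = 0; lane < wave_size; ++lane)
            v[lane] = static_cast<float>(lane);

        const auto d = permlanex16_like(v);
        // lane 0 now holds lane 16's value, lane 16 holds lane 0's, etc.
        std::printf("lane 0 got %g, lane 16 got %g\n", d[0], d[16]);
        return 0;
    }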
include/ck/utility/data_type.hpp
@@ -964,6 +964,32 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
    return u.fp32;
}

+template <>
+inline __host__ __device__ constexpr int type_convert<int, float>(float x)
+{
+    union
+    {
+        float fp32;
+        int int32;
+    } u = {x};
+    return u.int32;
+}
+
+template <>
+inline __host__ __device__ constexpr float type_convert<float, int>(int x)
+{
+    union
+    {
+        int int32;
+        float fp32;
+    } u = {x};
+    return u.fp32;
+}
+
// convert fp32 to bfp16
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
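These two new specializations are bit reinterpretations, not numeric conversions: type_convert<int, float> returns the IEEE-754 bit pattern of the float as an int, and type_convert<float, int> reverses it. They let a float value ride through the int-typed __builtin_amdgcn_permlanex16 call in threadwise_tensor_slice_transfer.hpp and come back bit-exact. A standalone round-trip check (memcpy is the strictly portable host-side equivalent of the union trick):

    #include <cassert>
    #include <cstring>

    // Same bit reinterpretation as the new type_convert specializations:
    // float -> int -> float must reproduce the exact bit pattern.
    int   bits_of(float x) { int i;   std::memcpy(&i, &x, sizeof i); return i; }
    float float_of(int i)  { float x; std::memcpy(&x, &i, sizeof x); return x; }

    int main()
    {
        const float v           = -1.5f;
        const int   transported = bits_of(v);     // what permlanex16 would move
        const float round_trip  = float_of(transported);
        assert(std::memcmp(&v, &round_trip, sizeof v) == 0); // bit-exact
        return 0;
    }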