gaoqiong / composable_kernel

Commit 74f0d5de — "save debugging progress"
Authored Feb 14, 2023 by aska-0096
Parent: 5df713ef

Showing 10 changed files with 383 additions and 146 deletions.
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp   +5   -5
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc         +26  -0
include/ck/host_utility/kernel_launch.hpp                                                          +3   -3
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp                                      +134 -56
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp  +12 -17
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp          +139 -50
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp                                        +47  -11
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp                        +3   -2
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp                                                 +2   -2
library/include/ck/library/utility/host_tensor_generator.hpp                                       +12  -0
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp

@@ -125,12 +125,12 @@ using DeviceGemmInstance =
         S<4, 64, 1>,
         // B1BlockTransfer
         S<1, 0, 2>,
         S<1, 0, 2>,
         1,
         2,
         8,
         8,
         false,
-        1,              // CShuffleMXdlPerWavePerShuffle
-        2,              // CShuffleNXdlPerWavePerShuffle
+        1,              // CShuffleMWmmaPerWavePerShuffle
+        2,              // CShuffleNWmmaPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         4,              // CShuffleBlockTransferScalarPerVector_NPerBlock
         MaskingSpec>;   // MaskingSpecialization
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc

@@ -117,6 +117,26 @@ int run(int argc, char* argv[])
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
         break;
+    case 4: // A, B0, B1 1
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
+        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
+        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
+        break;
+    case 5: // Rand: b1 ; unit: a b0   fail
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
+        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
+        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
+        break;
+    case 6: // Rand: b0 ; unit: a b1   fail
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
+        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
+        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
+        break;
+    case 7: // Rand: a ; unit: b0 b1   pass
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
+        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
+        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
+        break;
     default:
         a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});

@@ -220,6 +240,12 @@ int run(int argc, char* argv[])
                                              a_g_m_k,
                                              b0_g_k_n,
                                              acc0_g_m_n,
                                              a_element_op,
                                              b0_element_op,
                                              acc0_element_op);

    ref_gemm0_invoker.Run(ref_gemm0_argument);
+    // for(int i =0; i< 128; i++){
+    //     for(int j =0; j< 128; j++){
+    //         printf("%0.2lf ", acc0_g_m_n.mData[i*128 +j]);
+    //     }
+    //     printf("\n");
+    // }

    // masking
    const auto mask = DeviceGemmInstance::C0MatrixMask(N);
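The new cases 4–7 fill A, B0 and B1 with either constant-one, random, or diagonal data so each operand can be isolated while debugging. A rough, simplified sketch of what those generator functors produce (the real definitions live in library/include/ck/library/utility/host_tensor_generator.hpp and may differ in details such as value ranges and which indices the diagonal check uses):

#include <cstdio>
#include <cstdlib>

template <typename T>
struct GeneratorTensor_1_sketch // fills every element with a constant (1 by default)
{
    int value = 1;
    template <typename... Is>
    T operator()(Is...) { return static_cast<T>(value); }
};

template <typename T>
struct GeneratorTensor_2_sketch // uniform random integers in [min_value, max_value)
{
    int min_value = 0;
    int max_value = 1;
    template <typename... Is>
    T operator()(Is...)
    {
        return static_cast<T>(min_value + std::rand() % (max_value - min_value));
    }
};

template <typename T>
struct GeneratorTensor_Diagonal_sketch // identity-like: 1 where all indices match, 0 elsewhere
{
    template <typename X, typename... Is>
    T operator()(X i, Is... rest)
    {
        const bool on_diag = ((i == rest) && ...);
        return on_diag ? static_cast<T>(1) : static_cast<T>(0);
    }
};

int main()
{
    GeneratorTensor_2_sketch<float> rand_gen{-2, 2};
    std::printf("%f %f\n", GeneratorTensor_1_sketch<float>{}(0, 3, 5), rand_gen(0, 3, 5));
    std::printf("%f\n", GeneratorTensor_Diagonal_sketch<float>{}(2, 2, 2)); // 1 on the diagonal
    return 0;
}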
include/ck/host_utility/kernel_launch.hpp

@@ -29,12 +29,12 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
               block_dim.y,
               block_dim.z);

-        const int nrepeat = 10;
+        const int nrepeat = 1;

-        printf("Warm up 1 time\n");
+        // printf("Warm up 1 time\n");

         // warm up
-        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+        // kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);

         printf("Start running %d times...\n", nrepeat);
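For context, launch_and_time_kernel follows the usual HIP warm-up-then-time pattern; switching nrepeat to 1 and commenting out the warm-up launch means a kernel instrumented with device-side printf runs exactly once, so its output is not duplicated. A minimal standalone sketch of that timing pattern (dummy_kernel and the sizes are placeholders, not code from this repository):

#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void dummy_kernel(float* p) { p[threadIdx.x] = threadIdx.x; }

int main()
{
    float* d = nullptr;
    hipMalloc(&d, 64 * sizeof(float));

    const int nrepeat = 10; // the hunk above switches this to 1 for debugging

    dummy_kernel<<<1, 64>>>(d); // warm-up launch; commented out in the hunk above so
                                // device-side printf output is not emitted twice

    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    hipEventRecord(start, 0);
    for(int i = 0; i < nrepeat; ++i)
        dummy_kernel<<<1, 64>>>(d);
    hipEventRecord(stop, 0);
    hipEventSynchronize(stop);

    float ms = 0.f;
    hipEventElapsedTime(&ms, start, stop);
    std::printf("avg %f ms per launch\n", ms / nrepeat);

    hipFree(d);
    return 0;
}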
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

@@ -16,19 +16,36 @@ template <index_t BlockSize,
           typename FloatA,
           typename FloatB,
           typename FloatAcc,
-          typename AK0MK1BlockDesc,
-          typename BK0NK1BlockDesc,
+          typename ABlockDesc,
+          typename BBlockDesc,
+          index_t MPerBlock,
+          index_t NPerBlock,
+          index_t KPerBlock,
           index_t MPerWMMA,
           index_t NPerWMMA,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPack>
-/* A: K0PerBlock x MPerBlock x K1
+          index_t KPack,
+          bool TransposeC = false>
+/* Option: Read from LDS, big buffer hold all threads required data
+ * Source
+ * A: K0PerBlock x MPerBlock x K1
  * B: K0PerBlock x NPerBlock x K1
- * C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
+ * Destination
+ * C, non-transpose
+ *   thread level: MRepeat x NRepeat x MAccVgprs
+ *   block  level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
  * KPACK == WMMA_K = 16
+ *
+ * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS)
+ * Source:
+ * A(if skip LDS): MRepeat x KPack
+ * B(if skip LDS): NRepeat x KPack
+ * Destination
+ * C, non-transpose
+ *   block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
  */
-struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
+struct BlockwiseGemmWMMA
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -42,18 +59,10 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
     // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
     static constexpr index_t WaveSize = 32;

-    static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
-    static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
-    static constexpr index_t KPerBlock =
-        BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
+    static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I4);
+    static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I4);

-    static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
-    static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
-    static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
-    static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);

-    static constexpr auto wmma_gemm = WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{};
+    static constexpr auto wmma_gemm =
+        WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};

     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
@@ -79,6 +88,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
         return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
     }

+    // Default, Block buffer in LDS, thread level offset enabled
     __device__ static auto CalculateAThreadOriginDataIndex()
     {
         const auto wave_idx = GetWaveIdx();
@@ -129,23 +139,63 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
         return make_tuple(c_thread_m, c_thread_n);
     }

-    // using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
-    // __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle(
-    //     Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
-    //     Tuple4 b_origin = CalculateBThreadOriginDataIndex())
-    //     : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
-    __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle()
-    {
-        static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
-                          BK0NK1BlockDesc::IsKnownAtCompileTime(),
+    template <index_t m0, index_t n0>
+    __device__ static auto CalculateCThreadOriginDataIndex7D(Number<m0>, Number<n0>)
+    {
+        const auto wave_idx = GetWaveIdx();
+        const auto waveId_m = wave_idx[I0];
+        const auto waveId_n = wave_idx[I1];
+        const auto blk_idx  = wmma_gemm.GetBeginOfThreadBlk3D();
+
+        return make_tuple(
+            Number<m0>{}, blk_idx[I0], waveId_m, Number<n0>{}, waveId_n, blk_idx[I1], blk_idx[I2]);
+    }
+
+    using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
+
+    __host__ __device__ BlockwiseGemmWMMA(Tuple5 a_origin = CalculateAThreadOriginDataIndex(),
+                                          Tuple5 b_origin = CalculateBThreadOriginDataIndex())
+        : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
+    {
+        static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");

         static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
                       "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");

         static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
                           NPerBlock % (NPerWMMA * NRepeat) == 0,
                       "wrong!");
+        // printf("tid %03d, Mat-B offset %d\n", get_thread_local_1d_id()%32, CalculateBThreadOriginDataIndex().At(Number<3>{}));
     }

+    // transposed WMMA output C' = B' * A'
+    __host__ __device__ static constexpr auto
+    GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
+    {
+        constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
+            wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
+
+        // constexpr auto NSubGroup          = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
+        // constexpr auto MThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
+        constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
+
+        return make_naive_tensor_descriptor_packed(
+            //         |MRepeat |MWave |MSubGroup |NRepeat |NWave |NThreadPerSubGroup |MAccVgprs
+            make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
+    }
+
     // Thread level, register decriptor. Vector-write
@@ -171,9 +221,31 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
                        MAccVgprs));
     }

+    // Provide dimension size
+    template <typename CGridDesc_M_N>
     __host__ __device__ static constexpr auto
-    GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
+    MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
+        const CGridDesc_M_N& c_grid_desc_m_n)
     {
+        const auto M = c_grid_desc_m_n.GetLength(I0);
+        const auto N = c_grid_desc_m_n.GetLength(I1);
+
+        const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
+            transform_tensor_descriptor(
+                c_grid_desc_m_n,
+                make_tuple(
+                    make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
+                    make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
+
+        return wmma_gemm.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
+            c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
+    }
+
+    // transposed WMMA output C' = B' * A'
+    __host__ __device__ static constexpr auto
+    GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
+    {
         constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
             make_naive_tensor_descriptor_packed(
                 make_tuple(Number<MRepeat>{},
@@ -184,37 +256,31 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
                            Number<NPerWMMA>{}));

         return wmma_gemm
-            .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
+            .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
                 c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
     }

-    __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1()
+    // Provide dimension size
+    __host__ __device__ static constexpr auto
+    GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
     {
-        return transform_tensor_descriptor(
-            AK0MK1BlockDesc{},
-            make_tuple(make_pass_through_transform(Number<A_K0>{}),
-                       make_unmerge_transform(
-                           make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{})),
-                       make_pass_through_transform(Number<A_K1>{})),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
-    }
+        constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
+            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
+                                                           Number<MWaves>{},
+                                                           Number<MPerWMMA>{},
+                                                           Number<NRepeat>{},
+                                                           Number<NWaves>{},
+                                                           Number<NPerWMMA>{}));

-    __host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1()
-    {
-        return transform_tensor_descriptor(
-            BK0NK1BlockDesc{},
-            make_tuple(make_pass_through_transform(Number<B_K0>{}),
-                       make_unmerge_transform(
-                           make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{})),
-                       make_pass_through_transform(Number<B_K1>{})),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+        return wmma_gemm.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
+            c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
     }

     // Describe how data allocated in thread copy src buffer
     // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
-    static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1();
-    static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1();
+    static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
+    static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;

     template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
     __device__ void Run(const ABlockBuffer& a_block_buf,
@@ -235,6 +301,9 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
                                a_thread_desc_,
                                make_tuple(I0, m0, I0, I0, I0),
                                a_thread_buf);
+            // static_for<0, a_thread_buf.size(), 1>{}([&](auto i) {
+            //     a_thread_buf(i) = 1;
+            // });

             static_for<0, NRepeat, 1>{}([&](auto n0) {
                 // read B
@@ -254,6 +323,9 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
                     b_thread_vec.template AsType<FloatB>()(i) = b_thread_buf[Number<
                         b_thread_desc_.CalculateOffset(make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
+                    // a_thread_vec.template AsType<FloatA>()(i) = 1;
+                    // b_thread_vec.template AsType<FloatB>()(i) = 1;
                 });

                 using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
@@ -262,6 +334,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
                 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+                // printf("GPU Gemm0 input, Tid %03d, A%2d = %04x, B%2d = %0x4\n",
+                //        get_thread_local_1d_id(),
+                //        i.value, *(reinterpret_cast<uint16_t*>(&a_thread_vec.template AsType<FloatA>()(i))),
+                //        i.value, *(reinterpret_cast<uint16_t*>(&b_thread_vec.template AsType<FloatB>()(i))));

                 wmma_gemm.template Run(
                     a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
                     b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
@@ -304,10 +382,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
                                                          B_K1,
                                                          B_K1>;

-    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
-    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
-    // AThreadCopy a_thread_copy_;
-    // BThreadCopy b_thread_copy_;
+    AThreadCopy a_thread_copy_;
+    BThreadCopy b_thread_copy_;
 };

 // block wise level pipe designed for inline asm
@@ -601,7 +677,9 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
             b_thread_desc_.GetElementSpaceSize());

+        // TODO: Fix it, MRepeat < NRepeat
         constexpr auto RepeatDiff = MRepeat - NRepeat;

+        // Read all Mrepeat, Nrepeat
         static_for<0, NRepeat, 1>{}([&](auto iN) {
             b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
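The renamed BlockwiseGemmWMMA now receives MPerBlock/NPerBlock/KPerBlock as explicit template parameters instead of deriving them from a K0xMxK1 descriptor, while the wave decomposition stays the same. A tiny standalone check of that arithmetic with one illustrative configuration (the numbers are examples, not taken from a specific instance in this commit):

#include <cassert>

int main()
{
    constexpr int MPerBlock = 128, NPerBlock = 128;
    constexpr int MPerWMMA  = 16,  NPerWMMA  = 16;
    constexpr int MRepeat   = 4,   NRepeat   = 2;

    constexpr int MWaves = MPerBlock / (MRepeat * MPerWMMA); // 2
    constexpr int NWaves = NPerBlock / (NRepeat * NPerWMMA); // 4

    // the constructor's static_asserts require the block tile to be evenly covered
    static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0, "");
    static_assert(NPerBlock % (NPerWMMA * NRepeat) == 0, "");

    constexpr int WaveSize  = 32;                            // hard-coded in the struct above
    constexpr int BlockSize = MWaves * NWaves * WaveSize;    // expected thread count
    assert(BlockSize == 256);
    return 0;
}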
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp

@@ -145,11 +145,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
                                                        B0Spec,
                                                        B1Spec,
                                                        CSpec>;

-    // K1 = Max Vector Access Pixels
-    static constexpr auto K1Number = Number<K1>{};
-
-    static constexpr auto matrix_padder =
-        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock * K1};

     static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
                                               const std::vector<index_t>& a_gs_ms_ks_strides_vec)
@@ -167,13 +162,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
                                                            Number<K1>{});
     }

     static auto MakeB1GridDescriptor_BL0_N_BL1(const std::vector<index_t>& b1_gs_ns_ls_lengths_vec,
                                                const std::vector<index_t>& b1_gs_ns_ls_strides_vec)
     {
         return Transform::MakeB1GridDescriptor_BK0_N_BK1(
             Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, b1_gs_ns_ls_strides_vec),
             Number<L1>{});
     }
@@ -462,8 +455,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
-            const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) *
-                           arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

             auto launch_kernel = [&](auto has_main_k_block_loop) {
                 constexpr bool has_main_loop = has_main_k_block_loop.value;

                 const auto kernel = kernel_batched_gemm_softmax_gemm_wmma_cshuffle<GridwiseOp,
                                                                                    ADataType,
@@ -482,7 +473,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
                                                   ComputeBasePtrOfStridedBatch,
                                                   C0MatrixMask,
                                                   typename GridwiseOp::DefaultBlock2CTileMap,
-                                                  has_main_loop>;
+                                                  has_main_k_block_loop>;

                 return launch_and_time_kernel(stream_config,
                                               kernel,
@@ -754,11 +745,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
-            << K0PerBlock << ", "
-            << K1 << ", "
-            << MPerBlock << ", "
-            << NPerWMMA << ", "
+            << MPerBlock << ", "
+            << NPerBlock << ", "
+            << L0PerBlock << ", "
+            << L1 << ", "
             << getGemmSpecializationString(GemmSpec) << ", "
             << "ASpec" << getTensorSpecializationString(ASpec) << ", "
             << "B0Spec" << getTensorSpecializationString(B0Spec) << ", "
             << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
             << "CSpec" << getTensorSpecializationString(CSpec) << ", "
             << getMaskingSpecializationString(MaskingSpec) << ">"
             << " NumPrefetch: " << NumPrefetch << ", "
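The launch path above selects the kernel template through a compile-time boolean wrapped in the has_main_k_block_loop lambda argument. A small standalone illustration of that tag-dispatch pattern (the loop criterion and the names below are illustrative, not the exact CK helpers):

#include <cstdio>
#include <type_traits>

template <bool HasMainKBlockLoop>
void run_kernel() // stands in for the templated GPU kernel selected above
{
    std::printf("has_main_k_block_loop = %s\n", HasMainKBlockLoop ? "true" : "false");
}

int main()
{
    const int K = 256, K0PerBlock = 8, K1 = 8;
    const bool has_main_loop = (K / (K0PerBlock * K1)) > 1; // illustrative criterion

    auto launch = [&](auto has_main_k_block_loop) {
        constexpr bool v = decltype(has_main_k_block_loop)::value;
        run_kernel<v>();
    };

    if(has_main_loop)
        launch(std::integral_constant<bool, true>{});
    else
        launch(std::integral_constant<bool, false>{});
    return 0;
}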
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp

@@ -190,22 +190,89 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
     static constexpr auto I6 = Number<6>{};
     static constexpr auto I7 = Number<7>{};

     // K1Value should be Number<...>
     static constexpr auto AK0 = Number<K0PerBlock>{};
     static constexpr auto AK1 = Number<K1Value>{};
     static constexpr auto BK0 = Number<K0PerBlock>{};
     static constexpr auto BK1 = Number<K1Value>{};

-    static constexpr auto L0 = Number<L0PerBlock>{};
-    static constexpr auto L1 = Number<L1Value>{};
+    static constexpr auto Gemm0MWaves = MPerBlock / (MPerWmma * MRepeat);
+    static constexpr auto Gemm0LWaves = L0PerBlock * L1Value / (LPerWmma * LRepeat);
+
+    static constexpr auto AL0 = Number<L0PerBlock / 2>{};
+    static constexpr auto AL1 = Number<L1Value>{};
+    static constexpr auto BL0 = Number<L0PerBlock>{};
+    static constexpr auto BL1 = Number<L1Value>{};

     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

     using GridwiseGemmPipe = remove_cvref_t<
         decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

+    template <typename A0BlockDesc_AK0_M_AK1>
+    __host__ __device__ static constexpr auto
+    MakeA0BlockDescriptor_K0_M0_M1_M2_K1(const A0BlockDesc_AK0_M_AK1&)
+    {
+        constexpr index_t A_K0   = A0BlockDesc_AK0_M_AK1{}.GetLength(I0);
+        constexpr index_t A_K1   = A0BlockDesc_AK0_M_AK1{}.GetLength(I2);
+        constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
+
+        return transform_tensor_descriptor(
+            A0BlockDesc_AK0_M_AK1{},
+            make_tuple(make_pass_through_transform(Number<A_K0>{}),
+                       make_unmerge_transform(
+                           make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
+                       make_pass_through_transform(Number<A_K1>{})),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+    }
+
+    template <typename B0BlockDesc_BK0_L_BK1>
+    __host__ __device__ static constexpr auto
+    MakeB0BlockDescriptor_K0_L0_L1_L2_K1(const B0BlockDesc_BK0_L_BK1&)
+    {
+        constexpr index_t B_K0   = B0BlockDesc_BK0_L_BK1{}.GetLength(I0);
+        constexpr index_t B_K1   = B0BlockDesc_BK0_L_BK1{}.GetLength(I2);
+        constexpr index_t LWaves = LPerBlock / (LRepeat * LPerWmma);
+
+        return transform_tensor_descriptor(
+            B0BlockDesc_BK0_L_BK1{},
+            make_tuple(make_pass_through_transform(Number<B_K0>{}),
+                       make_unmerge_transform(
+                           make_tuple(Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
+                       make_pass_through_transform(Number<B_K1>{})),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+    }
+
+    template <typename A1BlockDesc_AL0_M_AL1>
+    __host__ __device__ static constexpr auto
+    MakeA1BlockDescriptor_L0_M0_M1_M2_L1(const A1BlockDesc_AL0_M_AL1&)
+    {
+        constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0);
+        constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2);
+
+        return transform_tensor_descriptor(
+            A1BlockDesc_AL0_M_AL1{},
+            make_tuple(make_pass_through_transform(Number<A_L0>{}),
+                       make_unmerge_transform(make_tuple(Number<MRepeat>{}, I1, I1)),
+                       make_pass_through_transform(Number<A_L1>{})),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+    }
+
+    template <typename B1BlockDesc_BL0_N_BL1>
+    __host__ __device__ static constexpr auto
+    MakeB1BlockDescriptor_L0_N0_N1_N2_L1(const B1BlockDesc_BL0_N_BL1&)
+    {
+        constexpr index_t B_K0   = B1BlockDesc_BL0_N_BL1{}.GetLength(I0);
+        constexpr index_t B_K1   = B1BlockDesc_BL0_N_BL1{}.GetLength(I2);
+        constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
+
+        return transform_tensor_descriptor(
+            B1BlockDesc_BL0_N_BL1{},
+            make_tuple(make_pass_through_transform(Number<B_K0>{}),
+                       make_unmerge_transform(
+                           make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
+                       make_pass_through_transform(Number<B_K1>{})),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+    }
+
     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
         // A matrix in LDS memory, dst of blockwise copy
@@ -226,8 +293,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
     {
         // B1 matrix in LDS memory, dst of blockwise copy
         return make_naive_tensor_descriptor(
-            make_tuple(L0, Number<NPerBlock>{}, L1),
-            make_tuple(Number<NPerBlock + B1BlockLdsExtraN>{} * L1, L1, I1));
+            make_tuple(BL0, Number<NPerBlock>{}, BL1),
+            make_tuple(Number<NPerBlock + B1BlockLdsExtraN>{} * BL1, BL1, I1));
     }

     __host__ __device__ static constexpr auto
@@ -374,7 +441,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         static constexpr auto b1_block_desc_bl0_n_bl1 = GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1();

-        static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), L1);
+        static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), BL1);

         static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
@@ -451,7 +518,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        /*******************************************************************************/
        // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
        // constexpr auto max_lds_align = K1Value;
        constexpr auto a_block_desc_k0perblock_mperblock_k1  = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b0_block_desc_k0perblock_lperblock_k1 = GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1();
@@ -491,7 +558,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                                                B0ElementwiseOperation,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
-                                               Sequence<BK0, NPerBlock, BK1>,
+                                               Sequence<BK0, LPerBlock, BK1>,
                                                B0BlockTransferThreadClusterLengths_K0_L_K1,
                                                B0BlockTransferThreadClusterArrangeOrder,
                                                FloatB0,
@@ -520,23 +587,27 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        constexpr auto WmmaK = 16;
        constexpr auto KPack = math::integer_least_multiple(K1Value, WmmaK);

-        auto blockwise_gemm0 = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<
-            BlockSize,
+        auto blockwise_gemm0 = BlockwiseGemmWMMA<
+            BlockSize,
             FloatA,
             FloatB0,
             FloatAcc0,
-            decltype(a_block_desc_k0perblock_mperblock_k1),
-            decltype(b0_block_desc_k0perblock_lperblock_k1),
+            decltype(MakeA0BlockDescriptor_K0_M0_M1_M2_K1(a_block_desc_k0perblock_mperblock_k1)),
+            decltype(MakeB0BlockDescriptor_K0_L0_L1_L2_K1(b0_block_desc_k0perblock_lperblock_k1)),
+            MPerBlock,
+            LPerBlock,
+            K0PerBlock * K1Value,
             MPerWmma,
             LPerWmma,
             MRepeat,
             LRepeat,
-            KPack>{};
+            KPack,
+            true>{};
+        // C' = B' x A'

        // Prepare Register for A*B0 matrix
        auto acc0_thread_buf = blockwise_gemm0.GetCThreadBuffer();

        // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
        constexpr auto acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
            blockwise_gemm0
                .GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
@@ -550,7 +621,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        constexpr auto acc0_thread_desc_l0perblock_mperblock_l1 = transform_tensor_descriptor(
            acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
-            make_tuple(make_merge_transform_v3_division_mod(make_tuple(lrepeat, lrepeat, lsubgroup)),
+            make_tuple(make_merge_transform_v3_division_mod(make_tuple(lrepeat, lwave, lsubgroup)),
                       make_merge_transform_v3_division_mod(make_tuple(mrepeat, mwave, mthreadpersubgroup)),
                       make_pass_through_transform(laccvgprs)),
            make_tuple(Sequence<3, 4, 5>{}, Sequence<0, 1, 2>{}, Sequence<6>{}),
@@ -587,7 +658,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        constexpr auto t_lwave     = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I4);
        constexpr auto t_lsubgroup = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I5);
        constexpr auto t_laccvgprs = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I6);
+        if(get_thread_local_1d_id() == 0){
+            printf("t_mrepeat %d, t_mwave %d, t_mthreadpersubgroup %d, t_lrepeat %d, t_lwave %d, t_lsubgroup %d, t_laccvgprs %d\n",
+                   t_mrepeat.value,
+                   t_mwave.value,
+                   t_mthreadpersubgroup.value,
+                   t_lrepeat.value,
+                   t_lwave.value,
+                   t_lsubgroup.value,
+                   t_laccvgprs.value);
+        }

        // get acc0 thread map
        constexpr auto m0_l_m1_to_m_l_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(t_mrepeat * t_mwave, t_mthreadpersubgroup)),
@@ -628,11 +708,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        /*******************************************************************************/
        // B1 matrix in LDS memory, dst of blockwise copy
        constexpr auto b1_block_desc_l0perblock_nperblock_l1 = GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1();
-        constexpr auto b1_block_slice_copy_step = make_multi_index(L0PerBlock, 0, 0);
+        constexpr auto b1_block_slice_copy_step = make_multi_index(BL0, 0, 0);

        // A1 matrix in VGPR
        constexpr auto A1ThreadSlice_L0PerBlock_MPerBlock_L1 =
-            make_tuple(Number<L0PerBlock * L1Value / laccvgprs>{},
+            make_tuple(Number<AL0 * AL1 / laccvgprs>{},
                       Number<mrepeat * mwave * mthreadpersubgroup>{},
                       Number<laccvgprs>{});
        // Data duplicated dimension
@@ -665,10 +745,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        // B1 matrix blockwise copy
        auto b1_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-                                                                    B0ElementwiseOperation,
+                                                                    B1ElementwiseOperation,
                                                                     tensor_operation::element_wise::PassThrough,
                                                                     InMemoryDataOperationEnum::Set,
-                                                                    Sequence<L0, NPerBlock, L1>,
+                                                                    Sequence<BL0, NPerBlock, BL1>,
                                                                     B1BlockTransferThreadClusterLengths_L0_N_L1,
                                                                     B1BlockTransferThreadClusterArrangeOrder,
                                                                     FloatB1,
@@ -700,22 +780,25 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            b1_block_desc_l0perblock_nperblock_l1.GetElementSpaceSize());

-        auto blockwise_gemm1 = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<
-            BlockSize,
+        auto blockwise_gemm1 = BlockwiseGemmWMMA<
+            BlockSize,
             FloatA,
             FloatB1,
             FloatAcc1,
-            decltype(a1_thread_desc_l0perblock_mperblock_l1),
-            decltype(b1_block_desc_l0perblock_nperblock_l1),
+            decltype(MakeA1BlockDescriptor_L0_M0_M1_M2_L1(a1_thread_desc_l0perblock_mperblock_l1)),
+            decltype(MakeB1BlockDescriptor_L0_N0_N1_N2_L1(b1_block_desc_l0perblock_nperblock_l1)),
+            MPerBlock,
+            NPerBlock,
+            BL0 * BL1,
             MPerWmma,
             NPerWmma,
             MRepeat,
             NRepeat,
-            KPack>{};
+            KPack>{make_tuple(0, 0, 0, 0, 0)};

        auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();

        const index_t num_gemm1_l_block_outer_loop = b0_grid_desc_k0_l_k1.GetLength(I1) / LPerBlock;
-        constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / (L0PerBlock * L1Value);
+        constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / (BL0 * BL1);

        // Initialize C
        StaticBuffer<AddressSpaceEnum::Vgpr, FloatAcc1, acc1_thread_buf.Size(), true> c_thread_buf;
@@ -811,13 +894,17 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                block_sync_lds();
                // gemm0 end
+               // gemm0 incorrect

                // Tiled softmax start
                // softmax
                SoftmaxBuf& max = blockwise_softmax.max_value_buf;
                SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;

+               // printf("GPU Gemm 0, Tid %03d, GPU acc0 = %lf\n", get_thread_local_1d_id(), acc0_thread_buf[I0]);
+               // static_for<0, acc0_thread_buf.Size(), 1>{}([&](auto i) {
+               //     printf("GPU Gemm0, Tid %03d, GPU acc%d = %lf\n", get_thread_local_1d_id(), i.value, acc0_thread_buf[i]);
+               // });
                blockwise_softmax.Run(acc0_thread_buf, workspace_buf);
+               // printf("GPU SoftMax, Tid %03d, GPU acc0 = %lf\n", get_thread_local_1d_id(), acc0_thread_buf[I0]);

                // TODO: may convert to log domain
                running_max_new = mathext::max(max, running_max);
@@ -862,6 +949,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                block_sync_lds();

+               // printf("GPU permute lanex, Tid %03d, GPU 0 = %04x\n", get_thread_local_1d_id(), *(reinterpret_cast<const uint16_t*>(&a1_thread_buf[I0])));
                blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);

                block_sync_lds();
@@ -934,11 +1023,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        // write out to C, implement shuffle
        {
            constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
-               blockwise_gemm0
+               blockwise_gemm1
                    .GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

            // This API Provide All dimension (size) you need
            constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
-               blockwise_gemm0
+               blockwise_gemm1
                    .GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

            constexpr auto MWave     = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1);
            constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2);
@@ -973,7 +1062,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
-           const auto c_thread_mtx_on_block = blockwise_gemm0.CalculateCThreadOriginDataIndex(I0, I0);
+           const auto c_thread_mtx_on_block = blockwise_gemm1.CalculateCThreadOriginDataIndex(I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
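Between gemm0 and gemm1 the kernel keeps running_max / running_sum values and rescales the partial accumulator whenever a new L-tile raises the maximum, which is the standard streaming-softmax bookkeeping. A host-side reference of that update rule for a single output row (simplified illustration, not the kernel code):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// out = softmax(scores) * V, where the scores and V rows arrive one L-tile at a time.
std::vector<float> tiled_softmax_gemv(const std::vector<std::vector<float>>& score_tiles,
                                      const std::vector<std::vector<std::vector<float>>>& v_tiles,
                                      int n)
{
    std::vector<float> acc(n, 0.0f); // running numerator: sum_l exp(s_l - m) * v_l
    float running_max = -std::numeric_limits<float>::infinity();
    float running_sum = 0.0f;

    for(std::size_t t = 0; t < score_tiles.size(); ++t)
    {
        const auto& s = score_tiles[t]; // scores of this L-tile
        const auto& v = v_tiles[t];     // v[l][j], j in [0, n)

        const float tile_max = *std::max_element(s.begin(), s.end());
        const float new_max  = std::max(running_max, tile_max);

        // rescale what was accumulated under the old maximum
        const float correction = std::exp(running_max - new_max);
        for(float& x : acc)
            x *= correction;
        running_sum *= correction;

        for(std::size_t l = 0; l < s.size(); ++l)
        {
            const float p = std::exp(s[l] - new_max);
            running_sum += p;
            for(int j = 0; j < n; ++j)
                acc[j] += p * v[l][j];
        }
        running_max = new_max;
    }

    for(float& x : acc)
        x /= running_sum; // final normalization
    return acc;
}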
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp

@@ -140,6 +140,39 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
    using GridwiseGemmPipe = remove_cvref_t<
        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

+   template <typename ABlockDesc_AK0_M_AK1>
+   __host__ __device__ static constexpr auto
+   MakeABlockDescriptor_K0_M0_M1_M2_K1(const ABlockDesc_AK0_M_AK1&)
+   {
+       constexpr index_t A_K0   = ABlockDesc_AK0_M_AK1{}.GetLength(I0);
+       constexpr index_t A_K1   = ABlockDesc_AK0_M_AK1{}.GetLength(I2);
+       constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
+
+       return transform_tensor_descriptor(
+           ABlockDesc_AK0_M_AK1{},
+           make_tuple(make_pass_through_transform(Number<A_K0>{}),
+                      make_unmerge_transform(
+                          make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
+                      make_pass_through_transform(Number<A_K1>{})),
+           make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+           make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+   }
+
+   template <typename BBlockDesc_BK0_N_BK1>
+   __host__ __device__ static constexpr auto
+   MakeBBlockDescriptor_K0_N0_N1_N2_K1(const BBlockDesc_BK0_N_BK1&)
+   {
+       constexpr index_t B_K0   = BBlockDesc_BK0_N_BK1{}.GetLength(I0);
+       constexpr index_t B_K1   = BBlockDesc_BK0_N_BK1{}.GetLength(I2);
+       constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
+
+       return transform_tensor_descriptor(
+           BBlockDesc_BK0_N_BK1{},
+           make_tuple(make_pass_through_transform(Number<B_K0>{}),
+                      make_unmerge_transform(
+                          make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
+                      make_pass_through_transform(Number<B_K1>{})),
+           make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+           make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+   }
+
    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;
@@ -414,12 +447,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

-       auto blockwise_gemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<
-           BlockSize,
+       auto blockwise_gemm = BlockwiseGemmWMMA<
+           BlockSize,
            FloatA,
            FloatB,
            FloatAcc,
-           decltype(a_block_desc_k0perblock_mperblock_k1),
-           decltype(b_block_desc_k0perblock_nperblock_k1),
+           decltype(MakeABlockDescriptor_K0_M0_M1_M2_K1(a_block_desc_k0perblock_mperblock_k1)),
+           decltype(MakeBBlockDescriptor_K0_N0_N1_N2_K1(b_block_desc_k0perblock_nperblock_k1)),
+           MPerBlock,
+           NPerBlock,
+           K0PerBlock * K1,
            MPerWmma,
            NPerWmma,
            MRepeat,
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

@@ -1382,6 +1382,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
        // copy data from src_buf into dst_vector
        static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
+           // idx_md err. as dst access 2 strided elements while src visit 1 per loop
            constexpr index_t src_offset =
                src_desc.CalculateOffset(src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
@@ -1396,13 +1397,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
            if(get_thread_local_1d_id() % 32 > 16){
                // apply type convert
                dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
-               dst_buf(Number<dst_offset + dst_buf.size() / 2>{}) = __builtin_amdgcn_permlanex16(
-                   type_convert<DstData>(dst_buf(Number<dst_offset + dst_buf.size() / 2>{})),
+               dst_buf(Number<dst_offset + DstScalarPerVector>{}) = __builtin_amdgcn_permlanex16(
+                   type_convert<DstData>(dst_buf(Number<dst_offset + DstScalarPerVector>{})),
                    type_convert<DstData>(v),
                    LowEightRowlaneIdx,
                    HighEightRowLaneIdx,
                    1,
                    0);
            }
            else
            {
                // apply type convert
-               dst_buf(Number<dst_offset + dst_buf.size() / 2>{}) = type_convert<DstData>(v);
+               dst_buf(Number<dst_offset + DstScalarPerVector>{}) = type_convert<DstData>(v);
                dst_buf(Number<dst_offset>{}) = __builtin_amdgcn_permlanex16(
                    type_convert<DstData>(dst_buf(Number<dst_offset>{})),
                    type_convert<DstData>(v),
                    LowEightRowlaneIdx,
                    HighEightRowLaneIdx,
                    1,
                    0);
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp

@@ -517,12 +517,12 @@ struct WmmaGemm
    __host__ __device__ static auto CalculateAThreadOriginDataIndex()
    {
-       return GetSwizzledLaneIdLow();
+       return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
    }

    __host__ __device__ static auto CalculateBThreadOriginDataIndex()
    {
-       return GetLaneIdUnderSubGroup();
+       return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
    }

    __device__ static CIndex GetBeginOfThreadBlk()
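Flipping TransposeC swaps which lane-index helper seeds the A and B thread origins, because the transposed path computes the output as B' x A' and relies on (A*B)^T = B^T * A^T. A small host-side check of that identity (illustrative only):

#include <cassert>

int main()
{
    const int M = 2, K = 3, N = 2;
    double A[2][3] = {{1, 2, 3}, {4, 5, 6}};
    double B[3][2] = {{7, 8}, {9, 10}, {11, 12}};

    double C[2][2] = {}; // C = A * B
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                C[m][n] += A[m][k] * B[k][n];

    double Ct[2][2] = {}; // Ct = B^T * A^T
    for(int n = 0; n < N; ++n)
        for(int m = 0; m < M; ++m)
            for(int k = 0; k < K; ++k)
                Ct[n][m] += B[k][n] * A[m][k];

    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            assert(C[m][n] == Ct[n][m]); // transposed product matches element-wise
    return 0;
}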
library/include/ck/library/utility/host_tensor_generator.hpp

@@ -55,6 +55,18 @@ struct GeneratorTensor_1<int8_t>
    }
};

+template <typename T>
+struct GeneratorTensor_dec1
+{
+    T value = 0.1;
+
+    template <typename... Is>
+    T operator()(Is...)
+    {
+        return value;
+    }
+};
+
template <typename T>
struct GeneratorTensor_2
{
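A hypothetical usage sketch of the new GeneratorTensor_dec1 (the functor simply returns value for every index; note that for integral T the initializer 0.1 truncates to 0):

#include <cstdio>

// copied from the hunk above
template <typename T>
struct GeneratorTensor_dec1
{
    T value = 0.1;

    template <typename... Is>
    T operator()(Is...)
    {
        return value;
    }
};

int main()
{
    GeneratorTensor_dec1<float> gen_f;
    GeneratorTensor_dec1<int> gen_i; // T value = 0.1 becomes 0 for integral T
    std::printf("%f %d\n", gen_f(0, 1, 2), gen_i(0, 1, 2)); // index arguments are ignored
    return 0;
}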