gaoqiong / composable_kernel · Commits

Commit 40016f20, authored May 18, 2021 by Jing Zhang
Parent: 8c84c0b1

    add m/n repeats
Showing 5 changed files with 196 additions and 110 deletions:

    composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp  (+82, -29)
    composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp  (+80, -73)
    composable_kernel/include/tensor_operation/xdlops_gemm.hpp  (+28, -0)
    composable_kernel/include/utility/amd_xdlops.hpp  (+2, -4)
    driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp  (+4, -4)
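In short, the commit splits each block tile into per-wave repeats: the blockwise GEMM gains MRepeat/NRepeat dimensions, issues one xdlops call per (m-repeat, n-repeat) pair, and the driver doubles the block tile to match. Below is a minimal, standalone consistency sketch of that decomposition, using values visible in this diff (MRepeat = NRepeat = 2 hard-coded in the gridwise kernel, 32x32 waves and a 64-thread block from the driver); the snippet is illustrative only and not part of the repository.

    #include <cstdio>

    int main()
    {
        // values taken from this commit: MRepeat/NRepeat are hard-coded in the gridwise
        // kernel, the tile sizes come from the driver; treat them as one example config
        constexpr int MPerBlock = 64, NPerBlock = 64; // GemmMPerBlock / GemmNPerBlock
        constexpr int MPerWave  = 32, NPerWave  = 32; // GemmMPerWave / GemmNPerWave
        constexpr int MRepeat   = 2,  NRepeat   = 2;
        constexpr int WaveSize  = 64;

        constexpr int M1     = MPerBlock / MRepeat; // per-repeat tile along M
        constexpr int N1     = NPerBlock / NRepeat;
        constexpr int MWaves = M1 / MPerWave;       // waves along M inside one repeat
        constexpr int NWaves = N1 / NPerWave;

        static_assert(MPerBlock == MRepeat * MWaves * MPerWave, "M decomposition");
        static_assert(NPerBlock == NRepeat * NWaves * NPerWave, "N decomposition");
        static_assert(MWaves * NWaves * WaveSize == 64, "matches BlockSize = 64 in the driver");

        std::printf("M: %d = %d repeats * %d wave(s) * %d per wave\n", MPerBlock, MRepeat, MWaves, MPerWave);
        std::printf("N: %d = %d repeats * %d wave(s) * %d per wave\n", NPerBlock, NRepeat, NWaves, NPerWave);
    }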
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
@@ -30,10 +30,20 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
     static constexpr index_t WaveSize = 64;

-    static constexpr index_t MPerBlock = ABlockDesc{}.GetLength(I1); // A is transposed
-    static constexpr index_t NPerBlock = BBlockDesc{}.GetLength(I1);
-
-    static constexpr index_t MWaves = MPerBlock / MPerWave;
-    static constexpr index_t NWaves = NPerBlock / NPerWave;
+    static constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
+    static constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
+
+    static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
+    static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
+
+    // static constexpr index_t MPerBlock = M0 * M1; // A is transposed
+    // static constexpr index_t NPerBlock = N0 * N1;
+
+    static constexpr index_t MWaves = M1 / MPerWave;
+    static constexpr index_t NWaves = N1 / NPerWave;
+
+    static constexpr index_t MRepeat = M0;
+    static constexpr index_t NRepeat = N0;

     __device__ constexpr auto GetOutputLayout() const { return xdlops_gemm.GetOutputLayout(); }
@@ -59,13 +69,13 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
         {
             const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
             const index_t k_offset = xdlops_gemm.GetBlkId(laneId) * xdlops_gemm.mfma_type.k_base;

-            return make_tuple(k_offset, m_offset);
+            return make_tuple(k_offset, 0, m_offset);
         }
         else
         {
             const index_t m_offset = waveId_m * MPerWave + laneId;
             const index_t k_offset = 0;

-            return make_tuple(k_offset, m_offset);
+            return make_tuple(k_offset, 0, m_offset);
         }
     }
@@ -81,26 +91,30 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
         {
             const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId);
             const index_t k_offset = xdlops_gemm.GetBlkId(laneId) * xdlops_gemm.mfma_type.k_base;

-            return make_tuple(k_offset, n_offset);
+            return make_tuple(k_offset, 0, n_offset);
         }
         else
         {
             const index_t n_offset = waveId_n * NPerWave + laneId;
             const index_t k_offset = 0;

-            return make_tuple(k_offset, n_offset);
+            return make_tuple(k_offset, 0, n_offset);
         }
     }

-    template <index_t AStride = MPerWave, index_t BStride = NPerWave>
-    __device__ static CIndex CalculateCThreadOriginDataIndex(index_t blk_i)
+    __device__ static CIndex CalculateCThreadOriginDataIndex(const index_t m_repeat_id,
+                                                             const index_t n_repeat_id,
+                                                             const index_t blk_i)
     {
         const index_t waveId = get_thread_local_1d_id() / WaveSize;

         const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(blk_i);

-        const index_t row = (waveId / NWaves) * AStride + thread_mtx_on_blk.row;
-        const index_t col = (waveId % NWaves) * BStride + thread_mtx_on_blk.col;
+        const index_t waveId_m = waveId / NWaves;
+        const index_t waveId_n = waveId % NWaves;
+
+        const index_t row = m_repeat_id * M1 + waveId_m * MPerWave + thread_mtx_on_blk.row;
+        const index_t col = n_repeat_id * N1 + waveId_n * NPerWave + thread_mtx_on_blk.col;

         return CIndex{row, col};
     }
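For reference, the updated origin computation reduces to the arithmetic below. This is a host-side sketch, not library code: ThreadOffset and c_thread_origin are hypothetical stand-ins for xdlops_gemm.GetBeginOfThreadBlk(blk_i) and the CIndex result, and the example sizes (M1 = N1 = 32, a single wave) come from the configuration used in this commit.

    #include <cstdio>

    // hypothetical stand-in for the offset returned by xdlops_gemm.GetBeginOfThreadBlk(blk_i)
    struct ThreadOffset { int row; int col; };

    // mirrors the updated CalculateCThreadOriginDataIndex: the repeat id selects an
    // M1/N1-sized slab, the wave id selects the MPerWave/NPerWave tile inside it
    void c_thread_origin(int m_repeat_id, int n_repeat_id, int waveId,
                         int M1, int N1, int MPerWave, int NPerWave, int NWaves,
                         ThreadOffset t)
    {
        const int waveId_m = waveId / NWaves;
        const int waveId_n = waveId % NWaves;

        const int row = m_repeat_id * M1 + waveId_m * MPerWave + t.row;
        const int col = n_repeat_id * N1 + waveId_n * NPerWave + t.col;

        std::printf("repeat (%d,%d), wave %d -> C origin (%d,%d)\n",
                    m_repeat_id, n_repeat_id, waveId, row, col);
    }

    int main()
    {
        // example configuration from this diff: M1 = N1 = 32 and a single wave per block
        for (int m0 = 0; m0 < 2; ++m0)
            for (int n0 = 0; n0 < 2; ++n0)
                c_thread_origin(m0, n0, /*waveId=*/0, 32, 32, 32, 32, /*NWaves=*/1, ThreadOffset{0, 0});
    }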
@@ -115,8 +129,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static_assert
(
ABlockDesc
{}.
GetLength
(
I0
)
==
BBlockDesc
{}.
GetLength
(
I0
),
"wrong! K dimension not consistent"
);
static_assert
(
MPerWave
*
MWaves
==
MPerBlock
,
"GemmMWaves * MPerWave != M"
);
static_assert
(
NPerWave
*
NWaves
==
NPerBlock
,
"GemmNWaves * NPerWave != N"
);
//
static_assert(MPerWave * MWaves == MPerBlock, "GemmMWaves * MPerWave != M");
//
static_assert(NPerWave * NWaves == NPerBlock, "GemmNWaves * NPerWave != N");
static_assert
(
BlockSize
==
MWaves
*
NWaves
*
WaveSize
,
"BlockSize != MWaves * NWaves * WaveSize
\n
"
);
...
...
@@ -136,39 +150,78 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
         static_for<0, KPerBlock, KPerWave>{}([&](auto k) {
             a_thread_copy_.Run(ABlockDesc{},
-                               make_tuple(k, I0),
+                               make_tuple(k, I0, I0),
                                a_block_buf,
                                a_thread_desc_,
-                               make_tuple(I0, I0),
+                               make_tuple(I0, I0, I0),
                                a_thread_buf);

             b_thread_copy_.Run(BBlockDesc{},
-                               make_tuple(k, I0),
+                               make_tuple(k, I0, I0),
                                b_block_buf,
                                b_thread_desc_,
-                               make_tuple(I0, I0),
+                               make_tuple(I0, I0, I0),
                                b_thread_buf);

-            xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf);
+            xdlops_gemm.template Run2<decltype(a_thread_desc_),
+                                      decltype(b_thread_desc_),
+                                      decltype(c_thread_desc_),
+                                      0,
+                                      0>(a_thread_buf, b_thread_buf, c_thread_buf);
+
+            b_thread_copy_.Run(BBlockDesc{},
+                               make_tuple(k, I1, I0),
+                               b_block_buf,
+                               b_thread_desc_,
+                               make_tuple(I0, I1, I0),
+                               b_thread_buf);
+
+            xdlops_gemm.template Run2<decltype(a_thread_desc_),
+                                      decltype(b_thread_desc_),
+                                      decltype(c_thread_desc_),
+                                      0,
+                                      1>(a_thread_buf, b_thread_buf, c_thread_buf);
+
+            a_thread_copy_.Run(ABlockDesc{},
+                               make_tuple(k, I1, I0),
+                               a_block_buf,
+                               a_thread_desc_,
+                               make_tuple(I0, I1, I0),
+                               a_thread_buf);
+
+            xdlops_gemm.template Run2<decltype(a_thread_desc_),
+                                      decltype(b_thread_desc_),
+                                      decltype(c_thread_desc_),
+                                      1,
+                                      0>(a_thread_buf, b_thread_buf, c_thread_buf);
+
+            xdlops_gemm.template Run2<decltype(a_thread_desc_),
+                                      decltype(b_thread_desc_),
+                                      decltype(c_thread_desc_),
+                                      1,
+                                      1>(a_thread_buf, b_thread_buf, c_thread_buf);
         });
     }
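The unrolled body above is the MRepeat = NRepeat = 2 instance of a per-K-step schedule: load the repeat-0 slices of A and B, then walk the repeat grid, loading an operand slice the first time its repeat index is needed. A standalone sketch of that schedule follows, printing the operation order instead of issuing the real copies and Run2 calls (which need the device-side buffers):

    #include <cstdio>

    int main()
    {
        constexpr int MRepeat = 2, NRepeat = 2;

        // per K-step schedule implied by the unrolled calls above:
        // copy A(0), copy B(0), Run2<0,0>, copy B(1), Run2<0,1>, copy A(1), Run2<1,0>, Run2<1,1>
        std::printf("copy A repeat 0; copy B repeat 0\n");
        for (int m0 = 0; m0 < MRepeat; ++m0)
        {
            if (m0 > 0)
                std::printf("copy A repeat %d\n", m0);
            for (int n0 = 0; n0 < NRepeat; ++n0)
            {
                // each B repeat is loaded once, on its first use, and stays resident
                if (m0 == 0 && n0 > 0)
                    std::printf("copy B repeat %d\n", n0);
                std::printf("Run2<%d,%d>\n", m0, n0);
            }
        }
    }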
     private:
     // A[K, M]
-    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<KPerWave>{}, Number<1>{}));
+    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<KPerWave>{}, Number<MRepeat>{}, I1));

     // B[K, N]
-    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<KPerWave>{}, Number<1>{}));
+    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<KPerWave>{}, Number<NRepeat>{}, I1));
+
+    static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));

     using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
                                                                 FloatA,
                                                                 ABlockDesc,
                                                                 decltype(a_thread_desc_),
-                                                                Sequence<KPerWave, 1>,
-                                                                Sequence<0, 1>,
-                                                                1,
+                                                                Sequence<KPerWave, 1, 1>,
+                                                                Sequence<0, 1, 2>,
+                                                                2,
                                                                 1,
                                                                 1>;
@@ -176,9 +229,9 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
                                                                 FloatB,
                                                                 BBlockDesc,
                                                                 decltype(b_thread_desc_),
-                                                                Sequence<KPerWave, 1>,
-                                                                Sequence<0, 1>,
-                                                                1,
+                                                                Sequence<KPerWave, 1, 1>,
+                                                                Sequence<0, 1, 2>,
+                                                                2,
                                                                 1,
                                                                 1>;
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
@@ -278,41 +278,42 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         // register
         // sanity check
         static_assert(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0, "wrong!");

-        // constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
-        // constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
-
-        // constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
-        //     a_k_m_block_desc,
-        //     make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
-        //                make_unmerge_transform(make_tuple(
-        //                    Number<MRepeat>{}, Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{}))),
-        //     make_tuple(Sequence<0>{}, Sequence<1>{}),
-        //     make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
-
-        // constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
-        //     b_k_n_block_desc,
-        //     make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
-        //                make_unmerge_transform(make_tuple(
-        //                    Number<NRepeat>{}, Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{}))),
-        //     make_tuple(Sequence<0>{}, Sequence<1>{}),
-        //     make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
+        constexpr index_t MRepeat = 2;
+        constexpr index_t NRepeat = 2;
+
+        static_assert(MPerBlock % (MPerWave * MRepeat) == 0 && NPerBlock % (NPerWave * NRepeat) == 0,
+                      "wrong!");
+
+        constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
+            a_k_m_block_desc,
+            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
+                       make_unmerge_transform(
+                           make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{}))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
+
+        constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
+            b_k_n_block_desc,
+            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
+                       make_unmerge_transform(
+                           make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{}))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

         // constexpr auto c_m0_m1_n0_n1_thread_desc =
         //     make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
         //         Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));

         const auto blockwise_gemm =
             BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
                                                  FloatAB,
                                                  FloatAB,
-                                                 decltype(a_k_m_block_desc),
-                                                 decltype(b_k_n_block_desc),
+                                                 decltype(a_k_m0_m1_block_desc),
+                                                 decltype(b_k_n0_n1_block_desc),
                                                  MPerWave,
                                                  NPerWave,
                                                  KPerWave>{};

         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
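The unmerge transform above splits the block's flat M (and, symmetrically, N) index into a (repeat, intra-repeat) pair. A small sketch of the mapping it implies, assuming the usual outer-index-slowest convention of make_unmerge_transform and the MPerBlock = 64, MRepeat = 2 values used in this commit:

    #include <cstdio>

    int main()
    {
        constexpr int MPerBlock = 64;
        constexpr int MRepeat   = 2;
        constexpr int M1        = MPerBlock / MRepeat; // length of the inner dimension

        // make_unmerge_transform(make_tuple(MRepeat, MPerBlock / MRepeat)) splits a flat
        // m index as m = m0 * M1 + m1, with the repeat index m0 varying slowest
        const int samples[] = {0, 1, 31, 32, 63};
        for (int m : samples)
            std::printf("m = %2d -> (m0, m1) = (%d, %2d)\n", m, m / M1, m % M1);
    }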
@@ -483,50 +484,56 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         // static_assert(BlkSize == 16 && NumBlks == 4, "");

-        // force unrolling the output loop to get ride of scratches
-        static_for<0, NumBlks, 1>{}([&](auto i) {
-            StaticBuffer<AddressSpace::Vgpr, float, BlkSize> c_thread_buf_;
-
-            static_for<0, BlkSize, 1>{}([&](auto j) {
-                c_thread_buf_(j) = c_thread_buf.template AsType<float>()[Number<i * BlkSize + j>{}];
-            });
-
-            // calculate origin of thread output tensor on global memory
-            //     blockwise GEMM c matrix starting index
-            const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(i);
-
-            const index_t k_thread_data_on_global =
-                m_block_data_idx_on_global + c_thread_mtx_on_block[I0];
-
-            const index_t b_thread_data_on_global =
-                n_block_data_idx_on_global + c_thread_mtx_on_block[I1];
-
-            constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
-
-            ThreadwiseDynamicTensorSliceTransfer_v1r3<FloatAcc,
-                                                      FloatC,
-                                                      decltype(c_m0_m1_m2_n_thread_desc),
-                                                      decltype(c_m0_m1_m2_n_global_desc),
-                                                      Sequence<M0, 1, M2, 1>,
-                                                      Sequence<0, 1, 2, 3>, // CThreadTransferSrcDstAccessOrder
-                                                      3,                    // CThreadTransferSrcDstVectorDim
-                                                      1,                    // CThreadTransferDstScalarPerVector
-                                                      CGlobalMemoryDataOperation,
-                                                      1,
-                                                      true>{
-                c_m0_m1_m2_n_global_desc,
-                make_multi_index(k_thread_data_on_global / (M2 * M1),
-                                 k_thread_data_on_global % (M2 * M1) / M2,
-                                 k_thread_data_on_global % M2,
-                                 b_thread_data_on_global)}
-                .Run(c_m0_m1_m2_n_thread_desc,
-                     make_tuple(I0, I0, I0, I0),
-                     c_thread_buf_,
-                     c_m0_m1_m2_n_global_desc,
-                     c_global_buf,
-                     c_m0_m1_n0_n1_global_tensor_iterator_hacks);
-        });
+        static_for<0, MRepeat, 1>{}([&](auto m_i) {
+            static_for<0, NRepeat, 1>{}([&](auto n_i) {
+                // force unrolling the output loop to get ride of scratches
+                static_for<0, NumBlks, 1>{}([&](auto i) {
+                    StaticBuffer<AddressSpace::Vgpr, float, BlkSize> c_thread_buf_;
+
+                    static_for<0, BlkSize, 1>{}([&](auto j) {
+                        c_thread_buf_(j) = c_thread_buf.template AsType<float>()[
+                            Number<m_i * (NRepeat * BlkSize * NumBlks) +
+                                   n_i * (BlkSize * NumBlks) + i * BlkSize + j>{}];
+                    });
+
+                    // calculate origin of thread output tensor on global memory
+                    //     blockwise GEMM c matrix starting index
+                    const auto c_thread_mtx_on_block =
+                        blockwise_gemm.CalculateCThreadOriginDataIndex(m_i, n_i, i);
+
+                    const index_t k_thread_data_on_global =
+                        m_block_data_idx_on_global + c_thread_mtx_on_block[I0];
+
+                    const index_t b_thread_data_on_global =
+                        n_block_data_idx_on_global + c_thread_mtx_on_block[I1];
+
+                    constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
+
+                    ThreadwiseDynamicTensorSliceTransfer_v1r3<FloatAcc,
+                                                              FloatC,
+                                                              decltype(c_m0_m1_m2_n_thread_desc),
+                                                              decltype(c_m0_m1_m2_n_global_desc),
+                                                              Sequence<M0, 1, M2, 1>,
+                                                              Sequence<0, 1, 2, 3>, // CThreadTransferSrcDstAccessOrder
+                                                              3,                    // CThreadTransferSrcDstVectorDim
+                                                              1,                    // CThreadTransferDstScalarPerVector
+                                                              CGlobalMemoryDataOperation,
+                                                              1,
+                                                              true>{
+                        c_m0_m1_m2_n_global_desc,
+                        make_multi_index(k_thread_data_on_global / (M2 * M1),
+                                         k_thread_data_on_global % (M2 * M1) / M2,
+                                         k_thread_data_on_global % M2,
+                                         b_thread_data_on_global)}
+                        .Run(c_m0_m1_m2_n_thread_desc,
+                             make_tuple(I0, I0, I0, I0),
+                             c_thread_buf_,
+                             c_m0_m1_m2_n_global_desc,
+                             c_global_buf,
+                             c_m0_m1_n0_n1_global_tensor_iterator_hacks);
+                });
+            });
+        });
     }
 }
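The new gather index flattens the (m_i, n_i, i, j) accumulator coordinates in repeat-major order. A quick standalone check of that formula follows, using the BlkSize = 16 and NumBlks = 4 hinted at by the commented-out assert above (treated here as example values only):

    #include <cassert>
    #include <cstdio>

    int main()
    {
        // example sizes; BlkSize/NumBlks really come from the xdlops output layout
        constexpr int MRepeat = 2, NRepeat = 2, NumBlks = 4, BlkSize = 16;

        int expected = 0;
        for (int m_i = 0; m_i < MRepeat; ++m_i)
            for (int n_i = 0; n_i < NRepeat; ++n_i)
                for (int i = 0; i < NumBlks; ++i)
                    for (int j = 0; j < BlkSize; ++j)
                    {
                        // index used to gather c_thread_buf into the per-blk staging buffer
                        const int idx = m_i * (NRepeat * BlkSize * NumBlks) +
                                        n_i * (BlkSize * NumBlks) + i * BlkSize + j;
                        assert(idx == expected); // contiguous, repeat-major layout
                        ++expected;
                    }
        std::printf("flattened accumulator covers %d elements\n", expected);
    }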
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
@@ -599,6 +599,34 @@ struct XdlopsGemm
         });
     }

+    template <class ADesc,
+              class BDesc,
+              class CDesc,
+              index_t m0,
+              index_t n0,
+              class FloatA,
+              class FloatB,
+              class FloatC>
+    __device__ void Run2(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
+    {
+        static_assert(is_same<data_type, float>::value || is_same<data_type, half_t>::value ||
+                          is_same<data_type, ushort>::value,
+                      "base data_type must be float, half, ushort!");
+
+        static_assert(KPerWave % KPerXdlops == 0, "KPerWave cannot be divided by KPerXdlops");
+
+        static_for<0, KPerWave, KPerXdlops>{}([&](auto k) {
+            constexpr index_t a_offset = ADesc{}.CalculateOffset(make_multi_index(k, m0, 0));
+            constexpr index_t b_offset = BDesc{}.CalculateOffset(make_multi_index(k, n0, 0));
+            constexpr index_t c_offset = CDesc{}.CalculateOffset(make_multi_index(m0, n0));
+
+            mfma_type.template run<MPerXdlops, NPerXdlops>(
+                p_a_wave[Number<a_offset>{}],
+                p_b_wave[Number<b_offset>{}],
+                p_c_thread.template AsType<float16_t>()(Number<c_offset>{}));
+        });
+    }
+
     __device__ static MatrixIndex GetBeginOfThreadBlk(index_t i)
     {
         const index_t xdlops_i = i / GetNumBlksPerXdlops();
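Run2 addresses the thread-side buffers through compile-time offsets computed from the packed [KPerWave, Repeat, 1] thread descriptors defined in the blockwise GEMM. A plain-C++ sketch of that offset arithmetic: packed_offset is a hypothetical helper standing in for ADesc{}.CalculateOffset(make_multi_index(k, m0, 0)), and KPerWave = 2 matches GemmKPerWave in the driver file changed by this commit.

    #include <cstdio>

    // offset into a packed (row-major) [D0, D1, D2] descriptor; hypothetical helper
    // mirroring what CalculateOffset returns for a packed thread descriptor
    constexpr int packed_offset(int i0, int i1, int i2, int D1, int D2)
    {
        return (i0 * D1 + i1) * D2 + i2;
    }

    int main()
    {
        constexpr int KPerWave = 2, MRepeat = 2; // example sizes from this commit's driver config
        for (int k = 0; k < KPerWave; ++k)
            for (int m0 = 0; m0 < MRepeat; ++m0)
                std::printf("a_offset(k=%d, m0=%d) = %d\n",
                            k, m0, packed_offset(k, m0, 0, MRepeat, 1));
    }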
composable_kernel/include/utility/amd_xdlops.hpp
@@ -278,11 +278,9 @@ struct intrin_mfma_f32_32x32x2f32;
 template <>
 struct intrin_mfma_f32_32x32x2f32<32, 32>
 {
-    __device__ static void Run(const float& reg_a, const float& reg_b, vector_type<float, 16>& reg_c)
+    __device__ static void Run(const float& reg_a, const float& reg_b, float16_t& reg_c)
     {
-        reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
-            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
+        reg_c = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(reg_a, reg_b, reg_c, 0, 0, 0);
     }
 };
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
@@ -104,21 +104,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
 #else
     constexpr index_t BlockSize = 64;

-    constexpr index_t GemmMPerBlock = 32;
-    constexpr index_t GemmNPerBlock = 32;
+    constexpr index_t GemmMPerBlock = 64;
+    constexpr index_t GemmNPerBlock = 64;
     constexpr index_t GemmKPerBlock = 8;
     constexpr index_t GemmMPerWave  = 32;
     constexpr index_t GemmNPerWave  = 32;
     constexpr index_t GemmKPerWave  = 2;

-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<4, 1>;
+    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<4, 2>;
     using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;

     constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
     constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;

-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN   = Sequence<4, 1>;
+    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN   = Sequence<4, 2>;
     using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;

     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;