gaoqiong / composable_kernel

Commit 0b547a33 (unverified)
Merge pull request #309 from ramjana/wavelet_model

fixed clang format errors

Authored Jun 29, 2022 by Raman R jana; committed by GitHub, Jun 29, 2022.
Parents: 599497b0, 702c3379

Changes: 6 files, +432 additions, -448 deletions
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp                    +1 / -1
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp  +1 / -1
include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp   +8 / -9
include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp                +54 / -61
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp   +367 / -374
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp                 +1 / -2
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp (+1 / -1)

(diff collapsed in the original view; formatting-only change)
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp (+1 / -1)

@@ -57,7 +57,7 @@ struct ThreadGroupTensorSliceTransfer_v6r1
```cpp
            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        // static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
        //               "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
```
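The assertion in this hunk enforces that the per-thread slice, multiplied by the thread-cluster shape, tiles the slicing window exactly. A minimal host-side sketch of the same invariant, using hypothetical compile-time lengths (plain C++ arrays stand in for the ck `Sequence` arithmetic):

```cpp
#include <cstddef>

// Hypothetical shapes: a 2D slicing window covered by a cluster of
// threads, each thread copying a smaller sub-slice.
constexpr std::size_t slice_lengths[2]          = {128, 16}; // whole window
constexpr std::size_t thread_cluster_lengths[2] = {32, 8};   // threads per dim
constexpr std::size_t thread_slice_lengths[2]   = {4, 2};    // per-thread work

// The invariant the ck static_assert checks, per dimension:
// thread_slice_lengths * ThreadClusterLengths == SliceLengths.
static_assert(thread_slice_lengths[0] * thread_cluster_lengths[0] == slice_lengths[0] &&
                  thread_slice_lengths[1] * thread_cluster_lengths[1] == slice_lengths[1],
              "threads must exactly tile the slicing window");

int main() { return 0; }
```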
include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp (+8 / -9)

@@ -10,7 +10,6 @@
```cpp
#include "gridwise_gemm_xdl_waveletmodel_cshuffle.hpp"
#include "gemm_specialization.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
```

@@ -438,7 +437,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle
```cpp
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
#if 0
            {
```
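`Run` here is the usual composable_kernel invoker entry point: launch the gridwise kernel described by an `Argument`, optionally time it, and return the elapsed milliseconds. A hedged sketch of that pattern in plain C++ (all types below are hypothetical stand-ins, not the ck API):

```cpp
#include <chrono>
#include <cstdio>

// Hypothetical stand-ins for the ck types used above.
struct StreamConfig { bool time_kernel = true; };
struct Argument { int M, N, K; };

struct Invoker
{
    // Mirrors the shape of DeviceOp::Invoker::Run: launch the kernel
    // described by `arg` and return elapsed time in milliseconds.
    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
    {
        auto t0 = std::chrono::steady_clock::now();
        (void)arg; // ... kernel launch would go here ...
        auto t1 = std::chrono::steady_clock::now();
        return stream_config.time_kernel
                   ? std::chrono::duration<float, std::milli>(t1 - t0).count()
                   : 0.0f;
    }
};

int main()
{
    Invoker invoker;
    float ms = invoker.Run(Argument{3840, 4096, 4096});
    std::printf("elapsed: %f ms\n", ms);
}
```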
include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp (+54 / -61)

@@ -7,8 +7,8 @@ namespace ck {
```cpp
template <typename TileLoadThreadGroup, index_t NumGemmKPrefetchStage>
struct GridwiseGemmLoadWave;

// 1-stage prefetch
template <typename TileLoadThreadGroup>
struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
{
    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */)
```

@@ -53,21 +53,21 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
```cpp
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        // move to 1
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // LDS write 0
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        if constexpr(HasMainLoop)
        {
            index_t i = 0;

            do
            {
                // sync for Load threads()
                block_sync_lds();

                // global read i + 1
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
```

@@ -81,11 +81,10 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
```cpp
                // sync with math threads()
                block_sync_lds();

                // LDS write i+1
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                ++i;
            } while(i < (num_loop - 1));
        }
```

@@ -95,9 +94,7 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
```cpp
            block_sync_lds();

            // GEMM num_loop
        }
    }
};
```
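The load wave above is a classic 1-stage software-prefetch pipeline: while the math wave consumes tile i from LDS, the load wave reads tile i+1 from global memory, then both synchronize before the LDS staging buffer is overwritten. A minimal host-side sketch of the same loop structure, with hypothetical single-int "tiles" standing in for the ck blockwise-copy objects:

```cpp
#include <cstdio>
#include <vector>

int main()
{
    const int num_loop = 4;
    std::vector<int> grid_tiles = {10, 20, 30, 40}; // tile i's data, simplified
    int window = 0; // read position         (MoveSrcSliceWindow)
    int reg    = 0; // register staging      (RunRead target)
    int lds    = 0; // shared-memory staging (RunWrite target)

    reg = grid_tiles[window]; // global read 0
    ++window;                 // move to 1
    lds = reg;                // LDS write 0

    for(int i = 0; i < num_loop - 1; ++i)
    {
        // block_sync_lds(): consumer is done with tile i before we refill reg
        reg = grid_tiles[window];           // global read i + 1
        if(window + 1 < num_loop) ++window; // move window forward
        std::printf("math consumes tile %d from LDS\n", lds);
        // block_sync_lds(): tile i fully consumed
        lds = reg; // LDS write i + 1
    }
    std::printf("math consumes tile %d from LDS\n", lds); // tail: GEMM num_loop
}
```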
@@ -108,10 +105,7 @@ template <typename TileMathThreadGroup>
```cpp
struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
{
    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
```

@@ -155,7 +149,6 @@ struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
```cpp
            // GEMM num_loop - 1
            block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }
    }
};
```
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp (+367 / -374)

@@ -128,42 +128,35 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
```cpp
    static constexpr auto AK1 = Number<AK1Value>{};
    static constexpr auto BK1 = Number<BK1Value>{};

    struct TileLoadThreadGroup
    {
        __device__ static constexpr index_t GetNumOfThread() { return TileLoadThreadGroupSize; }

        __device__ static constexpr bool IsBelong()
        {
            return (get_thread_local_1d_id() >= TileLoadThreadGroupSize);
        }

        __device__ static index_t GetThreadId()
        {
            return get_thread_local_1d_id() - TileMathThreadGroupSize;
        }
    };

    struct TileMathThreadGroup
    {
        __device__ static constexpr index_t GetNumOfThread() { return TileMathThreadGroupSize; }

        __device__ static constexpr bool IsBelong()
        {
            return get_thread_local_1d_id() < TileMathThreadGroupSize;
        }

        __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
    };

    using CShuffleBlockTransferThreadGroup = ThisThreadBlock<TileMathThreadGroupSize>;

    // load and math+store Wave pipelines.
    // TODO: build pipelines blocks scheduling parallel tasks
    using GridwiseGemmLoad = GridwiseGemmLoadWave<TileLoadThreadGroup, NumGemmKPrefetchStage>;
    using GridwiseGemmMath = GridwiseGemmMathWave<TileMathThreadGroup, NumGemmKPrefetchStage>;

    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
```
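The wavelet model statically splits one thread block into two cooperating groups: thread ids in [0, TileMathThreadGroupSize) run the math pipeline, and the remaining ids run the load pipeline, re-indexed from zero. A hedged sketch of that partitioning in plain C++ (group sizes are hypothetical; thread ids are simulated in a loop):

```cpp
#include <cstdio>

// Hypothetical group sizes: math threads occupy ids [0, 128),
// load threads occupy ids [128, 256).
constexpr int TileMathThreadGroupSize = 128;
constexpr int TileLoadThreadGroupSize = 128;

bool is_math_thread(int tid) { return tid < TileMathThreadGroupSize; }
bool is_load_thread(int tid) { return tid >= TileMathThreadGroupSize; }

// Group-local id: load threads re-index from zero, as in
// TileLoadThreadGroup::GetThreadId() above.
int load_group_thread_id(int tid) { return tid - TileMathThreadGroupSize; }

int main()
{
    const int block_size = TileMathThreadGroupSize + TileLoadThreadGroupSize;
    for(int tid = 0; tid < block_size; tid += 64) // sample a few "threads"
    {
        if(is_math_thread(tid))
            std::printf("tid %3d -> math wave, local id %d\n", tid, tid);
        else if(is_load_thread(tid))
            std::printf("tid %3d -> load wave, local id %d\n", tid, load_group_thread_id(tid));
    }
}
```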
@@ -177,8 +170,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
```cpp
    {
        // B matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
            make_tuple(BK0, Number<NPerBlock>{}, BK1),
            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
    }

    __host__ __device__ static constexpr auto
```

@@ -360,7 +353,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
```cpp
        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1, BK1);
```

@@ -392,10 +384,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
```cpp
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

        if(TileLoadThreadGroup::IsBelong())
        {
            // LoadWave
            const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
            const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
```

@@ -463,7 +455,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
```cpp
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

            GridwiseGemmLoad::template RunLoadWavePipeline<HasMainKBlockLoop>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                a_blockwise_copy,
                a_grid_buf,
```
@@ -477,20 +470,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
```cpp
                b_block_slice_copy_step,
                num_k_block_main_loop);

            block_sync_lds();
            block_sync_lds();
        }
        else if(TileMathThreadGroup::IsBelong())
        {
            // branch early for math wave
            constexpr index_t KPack =
                math::max(math::lcm(AK1, BK1),
                          MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

            auto blockwise_gemm =
                BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<TileMathThreadGroupSize,
                                                                    FloatAB,
                                                                    FloatGemmAcc,
                                                                    decltype(a_block_desc_ak0_m_ak1),
```
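KPack is the K-direction vector width fed to each MFMA: it must cover both LDS layouts (the lcm of AK1 and BK1) and at least one MFMA's k_per_blk, hence the max of the two. A worked example with assumed values (AK1 = 8, BK1 = 4, k_per_blk = 16 are hypothetical):

```cpp
#include <algorithm> // std::max
#include <cstdio>
#include <numeric>   // std::lcm

int main()
{
    // Hypothetical tile parameters; the real values come from the
    // gridwise GEMM template arguments and the selected MFMA.
    constexpr int AK1       = 8;
    constexpr int BK1       = 4;
    constexpr int k_per_blk = 16;

    // KPack = max(lcm(AK1, BK1), k_per_blk), as in the hunk above.
    constexpr int KPack = std::max(std::lcm(AK1, BK1), k_per_blk);
    static_assert(KPack == 16, "lcm(8,4) = 8, max(8,16) = 16");
    std::printf("KPack = %d\n", KPack);
}
```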
@@ -506,11 +498,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
```cpp
                p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            // TODO re-architect LDS+math stages
            GridwiseGemmMath::template RunMathWavePipeline<HasMainKBlockLoop>(a_block_buf,
                                                                              b_block_buf,
                                                                              blockwise_gemm,
                                                                              c_thread_buf,
                                                                              num_k_block_main_loop);

            // GEMM definition
            // c_mtx += transpose(a_mtx) * b_mtx
```

@@ -570,8 +559,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
```cpp
                        N1,      // N1 = NWave
                        N2))),   // N2 = NPerXdl
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<>{},
                           Sequence<>{},
                           Sequence<0, 2, 4, 5, 6>{},
                           Sequence<>{},
                           Sequence<1, 3, 7>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
```

@@ -602,8 +593,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
```cpp
                    make_multi_index(n_thread_data_on_block));

            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
                                                   FloatCShuffle,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
```

@@ -621,8 +612,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
```cpp
                                                   1,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>{
                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    make_multi_index(0,
                                     0,
                                     m_thread_data_on_block_idx[I1],
```
@@ -683,21 +673,24 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
```cpp
            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

            // TODO
            // 1. we do not need to do LDS swizzle to align global writes writing cache lines
            //    v_mfma cmat, amat, bmat, cmat - c-mat register layout is 1xN elements
            //    (N is vertical or strided dimension)
            //    v_mfma cmat, bmat, amat, cmat - c-mat register layout is Mx1 elements
            //    (M is coalescing dimension)
            //    by enumerating M index in amat, bmat you can align cmat register(s) to
            //    contiguous M elements, for example:
            //    1st mfma instruction output space : 0 4 8 12 16 ....
            //    2nd mfma instruction output space : 1 5 9 13 17 ....
            //    3rd mfma instruction output space : 2 6 10 14 18 ....
            //    4th mfma instruction output space : 3 7 11 15 19 ....
            //    you can pack 4 registers' output space into 2WORD and do the global write
            //    (no LDS swizzling required)
            // 2. avoid using s_barrier in this case where not all 256 threads are required
            //    to swizzle the c layout

            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
```
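The TODO above observes that consecutive MFMA outputs interleave along M with stride 4, so four instructions' registers together cover a contiguous M run that could be written out directly. A small host-side illustration of that index pattern (the stride-4 layout is taken from the comment; the loop bounds are hypothetical):

```cpp
#include <cstdio>

int main()
{
    constexpr int num_mfma   = 4; // instructions cooperating on one M run
    constexpr int elems_each = 5; // elements shown per instruction

    // Instruction j writes M indices j, j+4, j+8, ... (stride num_mfma),
    // matching "1st mfma instruction output space : 0 4 8 12 16 ...." above.
    for(int j = 0; j < num_mfma; ++j)
    {
        std::printf("mfma %d output M indices:", j + 1);
        for(int e = 0; e < elems_each; ++e)
            std::printf(" %2d", j + e * num_mfma);
        std::printf("\n");
    }
    // Taking one element from each of the 4 instructions' registers yields the
    // contiguous run {0, 1, 2, 3}, which can be packed and written to global
    // memory directly, with no LDS swizzle.
}
```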
@@ -732,4 +725,4 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
```cpp
        }
    }
}; // GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle

} // namespace ck
```
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp (+1 / -2)

@@ -249,8 +249,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
```cpp
        }();

        using BlockwiseGemm =
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                FloatAB,
                                                                FloatAcc,
                                                                decltype(a_k0_m_k1_block_desc),
                                                                decltype(b_k0_n_k1_block_desc),
```