gaoqiong / composable_kernel_ROCM · Commits

Commit 56de337f, authored Mar 29, 2024 by Jun Liu

    Merge branch 'amd-develop' into amd-master

Parents: 41b920e2, 687d2b7e

Changes: 161 files in this commit; showing 20 changed files with 2759 additions and 251 deletions.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp                 +5     -5
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp                       +9     -2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp                             +405   -5
include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp         +1     -1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp                                    +545   -183
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp                    +135   -0
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp       +1066  -0
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp                                             +85    -35
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp     +391   -0
include/ck/utility/amd_buffer_addressing.hpp                                                   +2     -1
include/ck/utility/amd_inline_asm.hpp                                                          +6     -18
include/ck/utility/data_type.hpp                                                               +15    -0
include/ck/utility/type_convert.hpp                                                            +57    -0
include/ck/wrapper/layout.hpp                                                                  +9     -0
include/ck/wrapper/operations/copy.hpp                                                         +3     -0
include/ck/wrapper/operations/gemm.hpp                                                         +6     -0
include/ck/wrapper/tensor.hpp                                                                  +9     -0
include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp                                        +3     -0
include/ck/wrapper/utils/kernel_utils.hpp                                                      +3     -0
include/ck/wrapper/utils/layout_utils.hpp                                                      +4     -1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp

@@ -264,7 +264,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
                    const BGridDesc_N_K& b_grid_desc_n_k,
                    const DsGridDesc_M_N& ds_grid_desc_m_n,
                    const EGridDesc_M_N& e_grid_desc_m_n,
-                   const Block2ETileMap& block_2_etile_map)
+                   const Block2ETileMap&)
     {
         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
@@ -310,10 +310,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
         }

         // check block-to-E-tile
-        if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
-        {
-            return false;
-        }
+        // if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
+        // {
+        //     return false;
+        // }

         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
         // check tensor size: cannot be larger than 2GB each
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp

@@ -17,18 +17,21 @@ enum struct PipelineVersion
     v2,
     // v3 is only used in the Stream-K implementation.
     v4,
+    weight_only,
 };

 template <PipelineVersion PipelineVer,
           index_t NumPrefetch     = 1,
-          LoopScheduler LoopSched = LoopScheduler::Default>
+          LoopScheduler LoopSched = LoopScheduler::Default,
+          bool AEnableLds         = true,
+          bool BEnableLds         = true>
 constexpr auto GridwiseGemmPipeline_Selector()
 {
     if constexpr(PipelineVer == PipelineVersion::v1)
     {
         if constexpr(LoopSched == LoopScheduler::Default)
         {
-            return GridwiseGemmPipeline_v1<NumPrefetch>{};
+            return GridwiseGemmPipeline_v1<NumPrefetch, AEnableLds, BEnableLds>{};
         }
         else if constexpr(LoopSched == LoopScheduler::Interwave)
         {
@@ -43,6 +46,10 @@ constexpr auto GridwiseGemmPipeline_Selector()
     {
         return GridwiseGemmPipeline_v4<NumPrefetch>{};
     }
+    else if constexpr(PipelineVer == PipelineVersion::weight_only)
+    {
+        return GridwiseGemmPipeline_v1_WeightOnly<NumPrefetch, AEnableLds, BEnableLds>{};
+    }
     else
     {
         std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
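For context, the two new template parameters are defaulted, so existing call sites that name only the pipeline version, prefetch count and loop scheduler keep compiling unchanged, while callers that stage A or B in registers can opt out of LDS explicitly. A minimal usage sketch (it assumes the ck headers providing GridwiseGemmPipeline_Selector and remove_cvref_t are included, and that the enclosing kernel template supplies the listed parameters, as gridwise_gemm_wmma.hpp does later in this commit):

// Usage sketch only; mirrors how the reworked WMMA gridwise op selects its pipeline.
using GridwiseGemmPipe =
    remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVer,
                                                          NumGemmKPrefetchStage,
                                                          LoopSched,
                                                          AEnableLds,
                                                          BEnableLds>())>;

// A pre-existing three-argument instantiation such as
//     GridwiseGemmPipeline_Selector<PipelineVersion::v1, 1, LoopScheduler::Default>()
// still resolves to the LDS-on-both-operands pipeline, because AEnableLds and
// BEnableLds default to true.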
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp

@@ -9,12 +9,12 @@
 namespace ck {

-template <index_t NumPrefetch>
+template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
 struct GridwiseGemmPipeline_v1;

 // 1-stage prefetch
 template <>
-struct GridwiseGemmPipeline_v1<1>
+struct GridwiseGemmPipeline_v1<1, true, true>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -108,7 +108,7 @@ struct GridwiseGemmPipeline_v1<1>
 // 2-stage prefetch
 template <>
-struct GridwiseGemmPipeline_v1<2>
+struct GridwiseGemmPipeline_v1<2, true, true>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -254,6 +254,406 @@ struct GridwiseGemmPipeline_v1<2>
     }
 };

+template <>
+struct GridwiseGemmPipeline_v1<1, false, true>
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return num_loop > 1;
+    }
+
+    template <bool HasMainLoop,
+              typename AGridDesc,
+              typename ABlockDesc,
+              typename ABlockTransfer,
+              typename AGridBuffer,
+              typename ABlockBuffer,
+              typename ABlockTransferStep,
+              typename BGridDesc,
+              typename BBlockDesc,
+              typename BBlockTransfer,
+              typename BGridBuffer,
+              typename BBlockBuffer,
+              typename BBlockTransferStep,
+              typename BlockwiseGemm,
+              typename CThreadBuffer>
+    __device__ static void Run(const AGridDesc& a_grid_desc,
+                               const ABlockDesc& a_block_desc,
+                               ABlockTransfer& a_blockwise_copy,
+                               const AGridBuffer& a_grid_buf,
+                               ABlockBuffer& a_block_buf,
+                               const ABlockTransferStep& a_block_copy_step,
+                               const BGridDesc& b_grid_desc,
+                               const BBlockDesc& b_block_desc,
+                               BBlockTransfer& b_blockwise_copy,
+                               const BGridBuffer& b_grid_buf,
+                               BBlockBuffer& b_block_buf,
+                               const BBlockTransferStep& b_block_copy_step,
+                               const BlockwiseGemm& blockwise_gemm,
+                               CThreadBuffer& c_thread_buf,
+                               index_t num_loop)
+    {
+        constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
+        auto a_block_buf_switch           = a_block_buf;
+
+        // preload data into LDS
+        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+        a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
+
+        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+        // Initialize C
+        c_thread_buf.Clear();
+
+        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+
+        // main body
+        if constexpr(HasMainLoop)
+        {
+            index_t i = 0;
+
+            do
+            {
+                a_blockwise_copy.Run(
+                    a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);
+
+                block_sync_lds();
+                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+
+                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+
+                block_sync_lds();
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+                a_block_buf = a_block_buf_switch;
+
+                ++i;
+            } while(i < (num_loop - 1));
+        }
+
+        // tail
+        {
+            block_sync_lds();
+            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+            block_sync_lds();
+        }
+    }
+};
+
+template <>
+struct GridwiseGemmPipeline_v1<1, true, false>
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return num_loop > 1;
+    }
+
+    template <bool HasMainLoop,
+              typename AGridDesc,
+              typename ABlockDesc,
+              typename ABlockTransfer,
+              typename AGridBuffer,
+              typename ABlockBuffer,
+              typename ABlockTransferStep,
+              typename BGridDesc,
+              typename BBlockDesc,
+              typename BBlockTransfer,
+              typename BGridBuffer,
+              typename BBlockBuffer,
+              typename BBlockTransferStep,
+              typename BlockwiseGemm,
+              typename CThreadBuffer>
+    __device__ static void Run(const AGridDesc& a_grid_desc,
+                               const ABlockDesc& a_block_desc,
+                               ABlockTransfer& a_blockwise_copy,
+                               const AGridBuffer& a_grid_buf,
+                               ABlockBuffer& a_block_buf,
+                               const ABlockTransferStep& a_block_copy_step,
+                               const BGridDesc& b_grid_desc,
+                               const BBlockDesc& b_block_desc,
+                               BBlockTransfer& b_blockwise_copy,
+                               const BGridBuffer& b_grid_buf,
+                               BBlockBuffer& b_block_buf,
+                               const BBlockTransferStep& b_block_copy_step,
+                               const BlockwiseGemm& blockwise_gemm,
+                               CThreadBuffer& c_thread_buf,
+                               index_t num_loop)
+    {
+        constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
+        auto b_block_buf_switch           = b_block_buf;
+
+        // preload data into LDS
+        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+        b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
+
+        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+        // Initialize C
+        c_thread_buf.Clear();
+
+        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+
+        // main body
+        if constexpr(HasMainLoop)
+        {
+            index_t i = 0;
+
+            do
+            {
+                b_blockwise_copy.Run(
+                    b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
+
+                block_sync_lds();
+                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+
+                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+
+                block_sync_lds();
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+                b_block_buf = b_block_buf_switch;
+
+                ++i;
+            } while(i < (num_loop - 1));
+        }
+
+        // tail
+        {
+            block_sync_lds();
+            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+            block_sync_lds();
+        }
+    }
+};
+
+template <>
+struct GridwiseGemmPipeline_v1<1, false, false>
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return num_loop > 1;
+    }
+
+    template <bool HasMainLoop,
+              typename AGridDesc,
+              typename ABlockDesc,
+              typename ABlockTransfer,
+              typename AGridBuffer,
+              typename ABlockBuffer,
+              typename ABlockTransferStep,
+              typename BGridDesc,
+              typename BBlockDesc,
+              typename BBlockTransfer,
+              typename BGridBuffer,
+              typename BBlockBuffer,
+              typename BBlockTransferStep,
+              typename BlockwiseGemm,
+              typename CThreadBuffer>
+    __device__ static void Run(const AGridDesc& a_grid_desc,
+                               const ABlockDesc& a_block_desc,
+                               ABlockTransfer& a_blockwise_copy,
+                               const AGridBuffer& a_grid_buf,
+                               ABlockBuffer& a_block_buf,
+                               const ABlockTransferStep& a_block_copy_step,
+                               const BGridDesc& b_grid_desc,
+                               const BBlockDesc& b_block_desc,
+                               BBlockTransfer& b_blockwise_copy,
+                               const BGridBuffer& b_grid_buf,
+                               BBlockBuffer& b_block_buf,
+                               const BBlockTransferStep& b_block_copy_step,
+                               const BlockwiseGemm& blockwise_gemm,
+                               CThreadBuffer& c_thread_buf,
+                               index_t num_loop)
+    {
+        constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
+        constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
+        auto b_block_buf_switch           = b_block_buf;
+        auto a_block_buf_switch           = a_block_buf;
+
+        // preload data into LDS
+        a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
+        b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
+
+        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+        // Initialize C
+        c_thread_buf.Clear();
+
+        // main body
+        if constexpr(HasMainLoop)
+        {
+            index_t i = 0;
+
+            do
+            {
+                a_blockwise_copy.Run(
+                    a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);
+                b_blockwise_copy.Run(
+                    b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
+
+                block_sync_lds();
+                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+                block_sync_lds();
+
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+                a_block_buf = a_block_buf_switch;
+                b_block_buf = b_block_buf_switch;
+
+                ++i;
+            } while(i < (num_loop - 1));
+        }
+
+        // tail
+        {
+            block_sync_lds();
+            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+            block_sync_lds();
+        }
+    }
+};
+
+template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
+struct GridwiseGemmPipeline_v1_WeightOnly;
+
+template <>
+struct GridwiseGemmPipeline_v1_WeightOnly<1, true, true>
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return num_loop > 1;
+    }
+
+    template <bool HasMainLoop,
+              typename AGridDesc,
+              typename ABlockDesc,
+              typename ABlockTransfer,
+              typename AGridBuffer,
+              typename ABlockBuffer,
+              typename ABlockTransferStep,
+              typename BGridDesc,
+              typename BBlockDesc,
+              typename BBlockTransfer,
+              typename BGridBuffer,
+              typename BBlockBuffer,
+              typename BBlockTransferStep,
+              typename ScaleGridDesc,
+              typename ScaleGridBuffer,
+              typename BlockwiseGemm,
+              typename CThreadBuffer>
+    __device__ static void Run(const AGridDesc& a_grid_desc,
+                               const ABlockDesc& a_block_desc,
+                               ABlockTransfer& a_blockwise_copy,
+                               const AGridBuffer& a_grid_buf,
+                               ABlockBuffer& a_block_buf,
+                               const ABlockTransferStep& a_block_copy_step,
+                               const BGridDesc& b_grid_desc,
+                               const BBlockDesc& b_block_desc,
+                               BBlockTransfer& b_blockwise_copy,
+                               const BGridBuffer& b_grid_buf,
+                               BBlockBuffer& b_block_buf,
+                               const BBlockTransferStep& b_block_copy_step,
+                               const ScaleGridDesc& scale_grid_desc,
+                               const ScaleGridBuffer& scale_grid_buf,
+                               const BlockwiseGemm& blockwise_gemm,
+                               CThreadBuffer& c_thread_buf,
+                               index_t num_loop)
+    {
+        // Global Prefetch Stage 1
+        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+        // Scale read once
+        b_blockwise_copy.RunScaleRead(scale_grid_desc, scale_grid_buf);
+
+        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+        // Initialize C
+        c_thread_buf.Clear();
+
+        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+        // Dequantization fused in blockwise_copy
+        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+
+        // main body
+        if constexpr(HasMainLoop)
+        {
+            index_t i = 0;
+
+            do
+            {
+                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+
+                block_sync_lds();
+                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+
+                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+
+                block_sync_lds();
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+
+                ++i;
+            } while(i < (num_loop - 1));
+        }
+
+        // tail
+        {
+            block_sync_lds();
+            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+        }
+    }
+};
+
 template <index_t NumPrefetch>
 struct GridwiseGemmPipelineInterwave_v1;
@@ -349,7 +749,7 @@ struct GridwiseGemmPipelineInterwave_v1<1>
 // Note: 2 stage prefetch not optimized for inter-wave loop scheduler
 template <>
-struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2>
+struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2, true, true>
 {
 };
@@ -359,7 +759,7 @@ constexpr auto GridwiseGemmPipeline_v1_Selector()
 {
     if constexpr(LoopSched == LoopScheduler::Default)
     {
-        return GridwiseGemmPipeline_v1<NumPrefetch>{};
+        return GridwiseGemmPipeline_v1<NumPrefetch, true, true>{};
     }
     else if constexpr(LoopSched == LoopScheduler::Interwave)
     {
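The three new specializations differ only in which operand bypasses LDS and is instead double-buffered in registers. The following standalone sketch (plain C++, not CK code; the integer "tiles" merely stand in for block copies and the blockwise GEMM, and all values are made up) mirrors the prefetch, compute and swap ordering of the <1, false, true> case, where A stays in VGPRs and B is staged through LDS:

#include <array>
#include <cstdio>

int main()
{
    constexpr int num_loop = 4;           // stands in for the kernel's K-loop count
    std::array<int, 2> a_vgpr{};          // a_block_buf and a_block_buf_switch stand-ins
    int b_lds = 0;                        // the LDS-resident B tile
    int c_acc = 0;                        // c_thread_buf stand-in

    // "preload": fetch B tile 0 into LDS and A tile 0 into registers
    b_lds     = 1;
    a_vgpr[0] = 1;

    for(int i = 0; i < num_loop - 1; ++i) // mirrors do { ... } while(i < num_loop - 1)
    {
        a_vgpr[1] = i + 2;                // prefetch the next A tile into the spare register buffer
        c_acc += a_vgpr[0] * b_lds;       // "blockwise_gemm.Run" on the current A/B tiles
        b_lds     = i + 2;                // stage the next B tile into LDS
        a_vgpr[0] = a_vgpr[1];            // a_block_buf = a_block_buf_switch
    }
    c_acc += a_vgpr[0] * b_lds;           // tail: final tiles, no further prefetch

    std::printf("processed %d tile pairs, acc = %d\n", num_loop, c_acc);
    return 0;
}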
include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp

@@ -93,7 +93,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage, true, true>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp

@@ -18,11 +18,11 @@
 namespace ck {

 template <typename GridwiseGemm,
-          typename FloatA,
-          typename FloatB,
-          typename FloatC,
-          typename AGridDesc_K0_M_K1,
-          typename BGridDesc_K0_N_K1,
+          typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename AGridDesc,
+          typename BGridDesc,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
@@ -33,31 +33,27 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         kernel_gemm_wmma(
-            const FloatA* __restrict__ p_a_grid,
-            const FloatB* __restrict__ p_b_grid,
-            FloatC* __restrict__ p_c_grid,
-            const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
-            const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
+            const ADataType* __restrict__ p_a_grid,
+            const BDataType* __restrict__ p_b_grid,
+            CDataType* __restrict__ p_c_grid,
+            const AGridDesc a_grid_desc,
+            const BGridDesc b_grid_desc,
             const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
             const AElementwiseOperation a_element_op,
             const BElementwiseOperation b_element_op,
             const CElementwiseOperation c_element_op,
-            // const
-            // CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
-            //     c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
             const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
-    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+    __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
                                                   p_shared,
-                                                  a_grid_desc_k0_m_k1,
-                                                  b_grid_desc_k0_n_k1,
+                                                  a_grid_desc,
+                                                  b_grid_desc,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                   a_element_op,
                                                   b_element_op,
@@ -67,8 +63,8 @@ __global__ void
     ignore = p_a_grid;
     ignore = p_b_grid;
     ignore = p_c_grid;
-    ignore = a_grid_desc_k0_m_k1;
-    ignore = b_grid_desc_k0_n_k1;
+    ignore = a_grid_desc;
+    ignore = b_grid_desc;
     ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = a_element_op;
     ignore = b_element_op;
@@ -78,21 +74,21 @@ __global__ void
 }

 template <index_t BlockSize,
-          typename FloatA,
-          typename FloatB,
-          typename FloatAcc,
-          typename FloatCShuffle,
-          typename FloatC,
+          typename ADataType,
+          typename BDataType,
+          typename AccDataType,
+          typename CShuffleDataType,
+          typename CDataType,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-          typename AGridDesc_K0_M_K1,
-          typename BGridDesc_K0_N_K1,
+          typename AGridDesc,
+          typename BGridDesc,
           typename CGridDesc_M_N,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           index_t MPerBlock,
           index_t NPerBlock,
-          index_t K0PerBlock,
+          index_t KPerBlock,
           index_t MPerWmma,
           index_t NPerWmma,
           index_t K1Value,
@@ -105,6 +101,7 @@ template <index_t BlockSize,
           index_t ABlockTransferSrcScalarPerVector,
           index_t ABlockTransferDstScalarPerVector_K1,
           bool AThreadTransferSrcResetCoordinateAfterRun,
+          bool AEnableLds,
           bool ABlockLdsExtraM,
           typename BBlockTransferThreadClusterLengths_K0_N_K1,
           typename BBlockTransferThreadClusterArrangeOrder,
@@ -113,6 +110,7 @@ template <index_t BlockSize,
           index_t BBlockTransferSrcScalarPerVector,
           index_t BBlockTransferDstScalarPerVector_K1,
           bool BThreadTransferSrcResetCoordinateAfterRun,
+          bool BEnableLds,
           bool BBlockLdsExtraN,
           index_t CShuffleMRepeatPerShuffle,
           index_t CShuffleNRepeatPerShuffle,
@@ -121,7 +119,7 @@ template <index_t BlockSize,
           index_t NumGemmKPrefetchStage = 1,
           LoopScheduler LoopSched       = make_default_loop_scheduler(),
           PipelineVersion PipelineVer   = PipelineVersion::v1>
-struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
+struct GridwiseGemm_Wmma
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -132,103 +130,277 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
     static constexpr auto I6 = Number<6>{};
     static constexpr auto I7 = Number<7>{};

-    // K1 should be Number<...>
+    // FIX ME: To be deprecated
     static constexpr auto K1 = Number<K1Value>{};

+    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
+    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
+    static constexpr auto WmmaK  = K1 == 16 ? 32 : 16;
+
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = remove_cvref_t<
-        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
+    using GridwiseGemmPipe =
+        remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVer,
+                                                              NumGemmKPrefetchStage,
+                                                              LoopSched,
+                                                              AEnableLds,
+                                                              BEnableLds>())>;

-    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
-    {
-        constexpr auto max_lds_align = K1;
-
-        // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
-            if constexpr(ABlockLdsExtraM)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
-            }
-        }();
-
-        return a_block_desc_k0perblock_mperblock_k1;
-    }
-
-    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
-    {
-        constexpr auto max_lds_align = K1;
-
-        // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
-            if constexpr(BBlockLdsExtraN)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
-                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
-            }
-        }();
-
-        return b_block_desc_k0perblock_nperblock_k1;
-    }
-
-    __host__ __device__ static constexpr auto
-    // *Caution Here repeat is shuffle repeat
-    GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
-    {
-        constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
-        constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);
-
-        constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
-            make_naive_tensor_descriptor_packed(
-                make_tuple(I1,
-                           Number<CShuffleMRepeatPerShuffle * MWave * MPerWmma>{},
-                           I1,
-                           Number<CShuffleNRepeatPerShuffle * NWave * NPerWmma>{}));
-
-        return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
-    }
-
-    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
-    {
-        // LDS allocation for A and B: be careful of alignment
-        constexpr auto a_block_desc_k0perblock_mperblock_k1 =
-            GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
-        constexpr auto b_block_desc_k0perblock_nperblock_k1 =
-            GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
-
-        constexpr auto max_lds_align = K1;
-
-        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
-            a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
-
-        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
-            b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);
-
-        return (a_block_space_size_aligned * sizeof(FloatA) +
-                b_block_space_size_aligned * sizeof(FloatB));
-    }
+    // Describe how data store to (LDS/VGPR) buffer from Global memory
+    __host__ __device__ static constexpr auto MakeABlockDescriptor()
+    {
+        constexpr auto a_block_desc = [&]() {
+            if constexpr(AEnableLds)
+            {
+                // K0->M->K1 Per Block
+                constexpr auto K0PerBlock    = KPerBlock / K1;
+                constexpr auto max_lds_align = K1;
+
+                if constexpr(ABlockLdsExtraM)
+                {
+                    return make_naive_tensor_descriptor(
+                        make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
+                        make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
+                }
+                else
+                {
+                    return make_naive_tensor_descriptor_aligned(
+                        make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+                }
+            }
+            else
+            {
+                constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
+                constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+                // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<KWmmaPerblock>{},
+                               Number<MRepeat>{},
+                               I1,
+                               Number<K0PerWmma>{},
+                               I1,
+                               I1,
+                               K1),
+                    make_tuple(Number<MRepeat>{} * Number<K0PerWmma>{} * K1,
+                               Number<K0PerWmma>{} * K1,
+                               Number<K0PerWmma>{} * K1,
+                               K1,
+                               K1,
+                               K1,
+                               I1));
+            }
+        }();
+
+        return a_block_desc;
+    }
+
+    __host__ __device__ static constexpr auto MakeBBlockDescriptor()
+    {
+        constexpr auto b_block_desc = [&]() {
+            if constexpr(BEnableLds)
+            {
+                // K0->N->K1 Per Block
+                constexpr auto K0PerBlock    = KPerBlock / K1;
+                constexpr auto max_lds_align = K1;
+
+                if constexpr(BBlockLdsExtraN)
+                {
+                    return make_naive_tensor_descriptor(
+                        make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
+                        make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
+                }
+                else
+                {
+                    return make_naive_tensor_descriptor_aligned(
+                        make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+                }
+            }
+            else
+            {
+                constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
+                constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+                // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<KWmmaPerblock>{},
+                               Number<NRepeat>{},
+                               I1,
+                               Number<K0PerWmma>{},
+                               I1,
+                               I1,
+                               K1),
+                    make_tuple(Number<NRepeat>{} * Number<K0PerWmma>{} * K1,
+                               Number<K0PerWmma>{} * K1,
+                               Number<K0PerWmma>{} * K1,
+                               K1,
+                               K1,
+                               K1,
+                               I1));
+            }
+        }();
+
+        return b_block_desc;
+    }
+
+    __host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
+    {
+        constexpr auto a_block_copy_step = [&]() {
+            if constexpr(AEnableLds)
+            {
+                constexpr auto K0PerBlock = KPerBlock / K1;
+
+                return make_multi_index(K0PerBlock, 0, 0);
+            }
+            else
+            {
+                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+
+                return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
+            }
+        }();
+
+        return a_block_copy_step;
+    }
+
+    __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep()
+    {
+        constexpr auto b_block_copy_step = [&]() {
+            if constexpr(BEnableLds)
+            {
+                constexpr auto K0PerBlock = KPerBlock / K1;
+
+                return make_multi_index(K0PerBlock, 0, 0);
+            }
+            else
+            {
+                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+
+                return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
+            }
+        }();
+
+        return b_block_copy_step;
+    }
+
+    // Describe how data read from (LDS/VGPR) buffer
+    template <typename ABlockDesc_>
+    __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
+    {
+        constexpr auto a_wave_desc = [&]() {
+            if constexpr(AEnableLds)
+            {
+                // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
+                constexpr auto A_K0   = ABlockDesc_{}.GetLength(I0);
+                constexpr auto A_K1   = ABlockDesc_{}.GetLength(I2);
+                constexpr auto A_KRow = I1;
+
+                return transform_tensor_descriptor(
+                    ABlockDesc_{},
+                    make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
+                               make_unmerge_transform(make_tuple(
+                                   Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
+                               make_pass_through_transform(Number<A_K1>{})),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                    make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
+            }
+            else
+            {
+                // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
+                constexpr auto KWmma     = ABlockDesc_{}.GetLength(I0);
+                constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
+                constexpr auto A_KRow    = ABlockDesc_{}.GetLength(I4);
+                constexpr auto A_K1      = ABlockDesc_{}.GetLength(I6);
+
+                // Err: merge transform cause non-constexpr issue
+                // return transform_tensor_descriptor(
+                //     ABlockDesc_{},
+                //     make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
+                //                make_pass_through_transform(Number<MRepeat>{}),
+                //                make_pass_through_transform(I1),
+                //                make_pass_through_transform(I1),
+                //                make_pass_through_transform(Number<A_K1>{})),
+                //     make_tuple(Sequence<0, 3>{},
+                //                Sequence<1>{},
+                //                Sequence<2>{},
+                //                Sequence<4>{},
+                //                Sequence<5>{}),
+                //     make_tuple(
+                //         Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},
+                //         Sequence<4>{}));
+
+                // Workaround, Freeze transform
+                return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
+                                                                      Number<MRepeat>{},
+                                                                      I1,
+                                                                      Number<A_KRow>{},
+                                                                      I1,
+                                                                      Number<A_K1>{}));
+            }
+        }();
+
+        return a_wave_desc;
+    }
+
+    template <typename BBlockDesc_>
+    __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&)
+    {
+        constexpr auto b_wave_desc = [&]() {
+            if constexpr(BEnableLds)
+            {
+                // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
+                constexpr auto B_K0   = BBlockDesc_{}.GetLength(I0);
+                constexpr auto B_K1   = BBlockDesc_{}.GetLength(I2);
+                constexpr auto B_KRow = I1;
+
+                return transform_tensor_descriptor(
+                    BBlockDesc_{},
+                    make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+                               make_unmerge_transform(make_tuple(
+                                   Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
+                               make_pass_through_transform(Number<B_K1>{})),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                    make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
+            }
+            else
+            {
+                // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
+                constexpr auto KWmma     = BBlockDesc_{}.GetLength(I0);
+                constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
+                constexpr auto B_KRow    = BBlockDesc_{}.GetLength(I4);
+                constexpr auto B_K1      = BBlockDesc_{}.GetLength(I6);
+
+                // Workaround, Freeze transform
+                return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
+                                                                      Number<NRepeat>{},
+                                                                      I1,
+                                                                      Number<B_KRow>{},
+                                                                      I1,
+                                                                      Number<B_K1>{}));
+            }
+        }();
+
+        return b_wave_desc;
+    }
+
+    __host__ __device__ static constexpr auto
+    // *Caution Here repeat is shuffle repeat
+    GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
+    {
+        constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
+            make_naive_tensor_descriptor_packed(
+                make_tuple(I1,
+                           Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
+                           I1,
+                           Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));
+
+        return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
+    }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
     template <typename Block2CTileMap>
     __host__ __device__ static constexpr bool
-    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
-                  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
+    CheckValidity(const AGridDesc& a_grid_desc,
+                  const BGridDesc& b_grid_desc,
                   const CGridDesc_M_N& c_grid_desc_m_n,
                   const Block2CTileMap& block_2_ctile_map)
     {
         static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                       "wrong! K1 need to be known at compile-time");
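A quick spot-check of the new compile-time constants introduced in the hunk above, evaluated for a few assumed K1Value settings (the values chosen here are illustrative only; the commit does not fix them):

// WmmaK     = (K1 == 16) ? 32 : 16
// K0PerWmma = WmmaK / 2 / K1      (used by the VGPR-path block descriptors)
static_assert(16 / 2 / 8 == 1, "K1Value = 8  -> WmmaK = 16, K0PerWmma = 1");
static_assert(16 / 2 / 4 == 2, "K1Value = 4  -> WmmaK = 16, K0PerWmma = 2");
static_assert(32 / 2 / 16 == 1, "K1Value = 16 -> WmmaK = 32, K0PerWmma = 1");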
@@ -237,23 +409,66 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                           (NPerBlock % (NRepeat * NPerWmma)) == 0,
                       "Invalid tuning param!");

-        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);
-        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);
-        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
+        const auto GetAProblemsizeMK = [&]() {
+            if constexpr(AEnableLds)
+            {
+                return make_tuple(a_grid_desc.GetLength(I1),
+                                  a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2));
+            }
+            else
+            {
+                return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
+                                      a_grid_desc.GetLength(I5),
+                                  a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
+                                      a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
+            }
+        };
+
+        const auto GetBProblemsizeNK = [&]() {
+            if constexpr(BEnableLds)
+            {
+                return make_tuple(b_grid_desc.GetLength(I1),
+                                  b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2));
+            }
+            else
+            {
+                return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
+                                      b_grid_desc.GetLength(I5),
+                                  b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
+                                      b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6));
+            }
+        };
+
+        const auto M = GetAProblemsizeMK()[I0];
+        const auto N = GetBProblemsizeNK()[I0];
+        const auto K = GetAProblemsizeMK()[I1];

         if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
-             K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
-             K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
-             K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
+             K == GetBProblemsizeNK()[I1]))
         {
+            printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n",
+                   GetAProblemsizeMK()[I0],
+                   GetAProblemsizeMK()[I1],
+                   GetBProblemsizeNK()[I0],
+                   GetBProblemsizeNK()[I1],
+                   c_grid_desc_m_n.GetLength(I0),
+                   c_grid_desc_m_n.GetLength(I1));
+            printf("GridwiseOp err: ProblemSize check");
             return false;
         }

-        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
+        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
         {
+            printf("GridwiseOp err: ProblemSize division");
             return false;
         }

         // check gridwise gemm pipeline
-        const auto num_k_loop = K0 / K0PerBlock;
+        const auto num_k_loop = K / KPerBlock;

         if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
+            printf("GridwiseOp err: Pipeline not support this k_loop");
             return false;
         }
@@ -265,8 +480,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
         constexpr long_index_t TwoGB = (long_index_t{1} << 31);

-        if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatA) <= TwoGB &&
-             b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatB) <= TwoGB))
+        if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
+             b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB))
         {
             return false;
         }
@@ -275,7 +490,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
     __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
-        const index_t num_loop = K / (K0PerBlock * K1);
+        const index_t num_loop = K / KPerBlock;

         return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }
@@ -313,13 +528,44 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

+    struct SharedMemTrait
+    {
+        // LDS allocation for A and B: be careful of alignment
+        static constexpr auto max_lds_align = K1;
+
+        static constexpr auto a_block_space_size_aligned =
+            AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
+                                                      max_lds_align)
+                       : 0;
+        static constexpr auto b_block_space_size_aligned =
+            BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(),
+                                                      max_lds_align)
+                       : 0;
+
+        static constexpr auto a_block_space_offset = 0;
+        static constexpr auto b_block_space_offset = a_block_space_size_aligned;
+
+        // LDS allocation for C shuffle in LDS
+        static constexpr auto c_shuffle_block_space_size =
+            GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
+                .GetElementSpaceSize();
+
+        static constexpr auto c_shuffle_block_space_offset = 0;
+
+        static constexpr auto lds_size =
+            math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType),
+                      a_block_space_size_aligned * sizeof(ADataType) +
+                          b_block_space_size_aligned * sizeof(BDataType));
+    };
+
     template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
-    __device__ static void Run(const FloatA* __restrict__ p_a_grid,
-                               const FloatB* __restrict__ p_b_grid,
-                               FloatC* __restrict__ p_c_grid,
+    __device__ static void Run(const ADataType* __restrict__ p_a_grid,
+                               const BDataType* __restrict__ p_b_grid,
+                               CDataType* __restrict__ p_c_grid,
                                void* __restrict__ p_shared,
-                               const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
-                               const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
+                               const AGridDesc& a_grid_desc,
+                               const BGridDesc& b_grid_desc,
                                const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                                const AElementwiseOperation& a_element_op,
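The new SharedMemTrait sizes a single LDS allocation that is reused by the A/B staging buffers during the GEMM and by the C-shuffle buffer afterwards, which is why lds_size is the larger of the two footprints. A small self-contained sketch of the same max() calculation, with made-up element counts and element sizes (none of these numbers come from the commit):

#include <algorithm>
#include <cstddef>

int main()
{
    // Hypothetical element counts; the real values come from the block
    // descriptors and the CShuffle descriptor in the hunk above.
    constexpr std::size_t a_elems  = 8192;  // a_block_space_size_aligned (assumed)
    constexpr std::size_t b_elems  = 16384; // b_block_space_size_aligned (assumed)
    constexpr std::size_t c_elems  = 2048;  // c_shuffle_block_space_size (assumed)
    constexpr std::size_t ab_bytes = 2;     // e.g. fp16 A/B elements
    constexpr std::size_t c_bytes  = 4;     // e.g. fp32 CShuffle staging elements

    // Same shape as SharedMemTrait::lds_size: the A+B staging area and the
    // C-shuffle staging area share one LDS allocation, so take the larger.
    constexpr std::size_t lds_size =
        std::max(c_elems * c_bytes, a_elems * ab_bytes + b_elems * ab_bytes);

    static_assert(lds_size == 49152, "A+B staging dominates for these sizes");
    return 0;
}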
@@ -331,9 +577,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         /*******************************************************************************/
         // Memory buffer zone.
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
+            p_a_grid, a_grid_desc.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
+            p_b_grid, b_grid_desc.GetElementSpaceSize());
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
@@ -351,24 +597,41 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         const index_t n_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

         /*******************************************************************************/
-        // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy
-        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
-
-        constexpr auto max_lds_align = K1;
-
-        constexpr auto a_block_desc_k0perblock_mperblock_k1 =
-            GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
-
-        constexpr auto b_block_desc_k0perblock_nperblock_k1 =
-            GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
-
-        // A matrix blockwise copy
-        auto a_blockwise_copy =
-            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+        // BlockLevel, A/B Matrix ThreadMapping in WMMA Source buffer, As Destinaion of BlockWise_Copy
+        const auto K = [&](){
+            if constexpr(AEnableLds){
+                return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
+            }
+            else
+            {
+                return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
+                       a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
+            }
+        }();
+
+        constexpr auto a_block_desc = MakeABlockDescriptor();
+        constexpr auto b_block_desc = MakeBBlockDescriptor();
+
+        auto a_block_trait = [&](){
+            // A matrix blockwise copy
+            if constexpr(AEnableLds)
+            {
+                constexpr auto K0PerBlock = KPerBlock / K1;
+                auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                    static_cast<ADataType*>(p_shared),
+                    SharedMemTrait::a_block_space_size_aligned);
+
+                auto a_blockwise_copy =
+                    ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
 /* typename SrcElementwiseOperation, */              AElementwiseOperation,
 /* typename DstElementwiseOperation, */              ck::tensor_operation::element_wise::PassThrough,
 /* InMemoryDataOperationEnum DstInMemOp, */          InMemoryDataOperationEnum::Set,
 /* typename BlockSliceLengths, */                    Sequence<K0PerBlock, MPerBlock, K1>,
 /* typename ThreadClusterLengths, */                 ABlockTransferThreadClusterLengths_K0_M_K1,
 /* typename ThreadClusterArrangeOrder, */            ABlockTransferThreadClusterArrangeOrder,
-/* typename SrcData, */                              FloatA,
-/* typename DstData, */                              FloatA,
-/* typename SrcDesc, */                              decltype(a_grid_desc_k0_m_k1),
-/* typename DstDesc, */                              decltype(a_block_desc_k0perblock_mperblock_k1),
+/* typename SrcData, */                              ADataType,
+/* typename DstData, */                              ADataType,
+/* typename SrcDesc, */                              decltype(a_grid_desc),
+/* typename DstDesc, */                              decltype(a_block_desc),
 /* typename SrcDimAccessOrder, */                    ABlockTransferSrcAccessOrder,
 /* typename DstDimAccessOrder, */                    Sequence<0, 1, 2>,
 /* index_t SrcVectorDim, */                          ABlockTransferSrcVectorDim,
@@ -378,99 +641,197 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
 /* index_t SrcScalarStrideInVector, */               1,
 /* index_t DstScalarStrideInVector, */               1,
 /* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
-/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>(
-                a_grid_desc_k0_m_k1,
-                make_multi_index(0, m_block_data_idx_on_grid, 0),
-                a_element_op,
-                a_block_desc_k0perblock_mperblock_k1,
-                make_multi_index(0, 0, 0),
-                ck::tensor_operation::element_wise::PassThrough{});
-
-        // B matrix blockwise copy
-        auto b_blockwise_copy =
-            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-                                                BElementwiseOperation,
-                                                ck::tensor_operation::element_wise::PassThrough,
-                                                InMemoryDataOperationEnum::Set,
-                                                Sequence<K0PerBlock, NPerBlock, K1>,
-                                                BBlockTransferThreadClusterLengths_K0_N_K1,
-                                                BBlockTransferThreadClusterArrangeOrder,
-                                                FloatB,
-                                                FloatB,
-                                                decltype(b_grid_desc_k0_n_k1),
-                                                decltype(b_block_desc_k0perblock_nperblock_k1),
-                                                BBlockTransferSrcAccessOrder,
-                                                Sequence<0, 1, 2>,
-                                                BBlockTransferSrcVectorDim,
-                                                2,
-                                                BBlockTransferSrcScalarPerVector,
-                                                BBlockTransferDstScalarPerVector_K1,
-                                                1,
-                                                1,
-                                                BThreadTransferSrcResetCoordinateAfterRun,
-                                                true>(
-                b_grid_desc_k0_n_k1,
-                make_multi_index(0, n_block_data_idx_on_grid, 0),
-                b_element_op,
-                b_block_desc_k0perblock_nperblock_k1,
-                make_multi_index(0, 0, 0),
-                ck::tensor_operation::element_wise::PassThrough{});
+/* bool ThreadTransferDstResetCoordinateAfterRun, */ true,
+                                                     NumGemmKPrefetchStage>(
+                    a_grid_desc,
+                    make_multi_index(0, m_block_data_idx_on_grid, 0),
+                    a_element_op,
+                    a_block_desc,
+                    make_multi_index(0, 0, 0),
+                    ck::tensor_operation::element_wise::PassThrough{});
+
+                return make_tuple(a_block_buf, a_blockwise_copy);
+            }
+            else
+            {
+                // Thread-wise copy
+                // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1
+                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+                constexpr auto K0PerWmma     = WmmaK / 2 / K1Value;
+                auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
+                    a_block_desc.GetElementSpaceSize());
+
+                // Limitation: NumDim of Src and Dst descriptor should be identical
+                auto a_blockwise_copy =
+                    ThreadwiseTensorSliceTransfer_v2<ADataType,
+                                                     ADataType,
+                                                     decltype(a_grid_desc),
+                                                     decltype(a_block_desc),
+                                                     Sequence<Number<KWmmaPerBlock>{},
+                                                              Number<MRepeat>{},
+                                                              I1,
+                                                              Number<K0PerWmma>{},
+                                                              I1,
+                                                              I1,
+                                                              Number<K1Value>{}>,
+                                                     Sequence<0, 1, 2, 3, 4, 5, 6>,
+                                                     6,
+                                                     ABlockTransferSrcScalarPerVector,
+                                                     AThreadTransferSrcResetCoordinateAfterRun,
+                                                     true>(
+                    a_grid_desc,
+                    make_multi_index(0,
+                                     m_block_data_idx_on_grid / (MWaves * MPerWmma),
+                                     get_thread_local_1d_id() / 32,
+                                     0,
+                                     (get_thread_local_1d_id() % 32) / 16,
+                                     get_thread_local_1d_id() % 16,
+                                     0));
+
+                return make_tuple(a_block_buf, a_blockwise_copy);
+            }
+        };
+
+        auto b_block_trait = [&](){
+            if constexpr(BEnableLds)
+            {
+                constexpr auto K0PerBlock = KPerBlock / K1;
+                auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                    static_cast<BDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
+                    SharedMemTrait::b_block_space_size_aligned);
+
+                auto b_blockwise_copy =
+                    ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+                                                        BElementwiseOperation,
+                                                        ck::tensor_operation::element_wise::PassThrough,
+                                                        InMemoryDataOperationEnum::Set,
+                                                        Sequence<K0PerBlock, NPerBlock, K1>,
+                                                        BBlockTransferThreadClusterLengths_K0_N_K1,
+                                                        BBlockTransferThreadClusterArrangeOrder,
+                                                        BDataType,
+                                                        BDataType,
+                                                        decltype(b_grid_desc),
+                                                        decltype(b_block_desc),
+                                                        BBlockTransferSrcAccessOrder,
+                                                        Sequence<0, 1, 2>,
+                                                        BBlockTransferSrcVectorDim,
+                                                        2,
+                                                        BBlockTransferSrcScalarPerVector,
+                                                        BBlockTransferDstScalarPerVector_K1,
+                                                        1,
+                                                        1,
+                                                        BThreadTransferSrcResetCoordinateAfterRun,
+                                                        true,
+                                                        NumGemmKPrefetchStage>(
+                    b_grid_desc,
+                    make_multi_index(0, n_block_data_idx_on_grid, 0),
+                    b_element_op,
+                    b_block_desc,
+                    make_multi_index(0, 0, 0),
+                    ck::tensor_operation::element_wise::PassThrough{});
+
+                return make_tuple(b_block_buf, b_blockwise_copy);
+            }
+            else
+            {
+                // Thread-wise copy
+                // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
+                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+                constexpr auto K0PerWmma     = WmmaK / 2 / K1Value;
+                auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
+                    b_block_desc.GetElementSpaceSize());
+
+                // Limitation: NumDim of Src and Dst descriptor should be identical
+                auto b_blockwise_copy =
+                    ThreadwiseTensorSliceTransfer_v2<BDataType,
+                                                     BDataType,
+                                                     decltype(b_grid_desc),
+                                                     decltype(b_block_desc),
+                                                     Sequence<Number<KWmmaPerBlock>{},
+                                                              Number<NRepeat>{},
+                                                              I1,
+                                                              Number<K0PerWmma>{},
+                                                              I1,
+                                                              I1,
+                                                              Number<K1Value>{}>,
+                                                     Sequence<0, 1, 2, 3, 4, 5, 6>,
+                                                     6,
+                                                     BBlockTransferSrcScalarPerVector,
+                                                     BThreadTransferSrcResetCoordinateAfterRun,
+                                                     true>(
+                    b_grid_desc,
+                    make_multi_index(0,
+                                     n_block_data_idx_on_grid / (NWaves * NPerWmma),
+                                     get_thread_local_1d_id() / 32,
+                                     0,
+                                     (get_thread_local_1d_id() % 32) / 16,
+                                     get_thread_local_1d_id() % 16,
+                                     0));
+
+                return make_tuple(b_block_buf, b_blockwise_copy);
+            }
+        };
+
+        auto a_block_buf      = a_block_trait()[I0];
+        auto a_blockwise_copy = a_block_trait()[I1];
+
+        auto b_block_buf      = b_block_trait()[I0];
+        auto b_blockwise_copy = b_block_trait()[I1];

         /*******************************************************************************/
         // GEMM
-        constexpr auto WmmaK = 16;
-
         constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

         auto blockwise_gemm =
-            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
-                                                                  FloatA,
-                                                                  FloatB,
-                                                                  FloatAcc,
-                                                                  decltype(a_block_desc_k0perblock_mperblock_k1),
-                                                                  decltype(b_block_desc_k0perblock_nperblock_k1),
-                                                                  MPerWmma,
-                                                                  NPerWmma,
-                                                                  MRepeat,
-                                                                  NRepeat,
-                                                                  KPack>{};
+            BlockwiseGemmWMMA<BlockSize,
+                              ADataType,
+                              BDataType,
+                              AccDataType,
+                              decltype(MakeAWaveDescriptor(a_block_desc)),
+                              decltype(MakeBWaveDescriptor(b_block_desc)),
+                              MPerBlock,
+                              NPerBlock,
+                              KPerBlock,
+                              MPerWmma,
+                              NPerWmma,
+                              MRepeat,
+                              NRepeat,
+                              KPack,
+                              AEnableLds,
+                              BEnableLds>{};

         // Prepare Register for C matrix
         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

         /*******************************************************************************/
-        // LDS allocation for A and B: be careful of alignment
-        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
-            a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
-
-        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatA*>(p_shared),
-            a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
-
-        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatB*>(p_shared) + a_block_space_size_aligned,
-            b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());
-
         // Shift Per SUB_K
-        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
+        constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
+        constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();

         // gridwise GEMM pipeline
-        const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
-        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
-                                                          a_block_desc_k0perblock_mperblock_k1,
+        const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
+        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
+                                                          a_block_desc,
                                                           a_blockwise_copy,
                                                           a_grid_buf,
                                                           a_block_buf,
                                                           a_block_slice_copy_step,
-                                                          b_grid_desc_k0_n_k1,
-                                                          b_block_desc_k0perblock_nperblock_k1,
+                                                          b_grid_desc,
+                                                          b_block_desc,
                                                           b_blockwise_copy,
                                                           b_grid_buf,
                                                           b_block_buf,
                                                           b_block_slice_copy_step,
                                                           blockwise_gemm,
                                                           c_thread_buf,
-                                                          K0BlockMainLoop);
+                                                          KBlockMainLoop);

         /*******************************************************************************/
         // write out to C, implement shuffle
         {
+            // C mapping in single thread.
             constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                 blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

-            // This API Provide All dimension (size) you need
+            // C mapping in single block
             constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
                 blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
@@ -485,8 +846,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
             GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

         auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatCShuffle*>(p_shared),
-            c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
+            static_cast<CShuffleDataType*>(p_shared) + SharedMemTrait::c_shuffle_block_space_offset,
+            SharedMemTrait::c_shuffle_block_space_size);

         constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor(
             c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
@@ -532,8 +893,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
             // shuffle: threadwise copy C from VGPR to LDS
             auto c_thread_copy_vgpr_to_lds =
-                ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
-                                                   FloatCShuffle,
+                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
+                                                   CShuffleDataType,
                                                    decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
),
decltype
(
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
),
decltype
(
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
),
decltype
(
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
...
@@ -571,8 +932,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -571,8 +932,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
CShuffleNRepeatPerShuffle
*
NWave
*
NPerWmma
>
,
// BlockSliceLengths,
CShuffleNRepeatPerShuffle
*
NWave
*
NPerWmma
>
,
// BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Float
CShuffle
,
// typename SrcData,
CShuffle
DataType
,
// typename SrcData,
FloatC
,
// typename DstData,
CDataType
,
// typename DstData,
decltype
(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
),
decltype
(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
),
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
...
@@ -636,6 +997,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -636,6 +997,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
if
constexpr
(
access_id
<
num_access
-
1
)
if
constexpr
(
access_id
<
num_access
-
1
)
{
{
constexpr
auto
c_global_step
=
sfc_c_global
.
GetForwardStep
(
access_id
);
constexpr
auto
c_global_step
=
sfc_c_global
.
GetForwardStep
(
access_id
);
// move on C
// move on C
c_shuffle_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_shuffle_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_global_step
);
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_global_step
);
...
...
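The pipeline call above also switches the main-loop count from the old K0-based bookkeeping to a K-based one. A small standalone sketch of why the trip count is unchanged when K = K0 * K1; the numeric values here are assumptions for illustration, not taken from this commit:

    // Illustrative only: K = 4096, KPerBlock = 64, K1 = 8 are assumed values.
    constexpr int K = 4096, KPerBlock = 64, K1 = 8;
    constexpr int KBlockMainLoop  = K / KPerBlock;               // new bookkeeping: 64 iterations
    constexpr int K0BlockMainLoop = (K / K1) / (KPerBlock / K1); // old bookkeeping: also 64 iterations
    static_assert(KBlockMainLoop == K0BlockMainLoop, "same number of main-loop iterations");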
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...
@@ -1333,4 +1333,139 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
    ElementwiseOperation element_op_;
};

// Specilized for WMMA
// A single Wave32 is composed by double row
// Data exchange allowed between these two rows
// This RowLane Dst buf will be filled from two Src buf
// SrcA: From specific thread buffer hold by This RowLane on This Row
// SrcB: From specific thread buffer hold by This RowLane on The other Row
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename ElementwiseOperation,
          typename SliceLengths,
          typename DimAccessOrder,
          index_t DstVectorDim,
          index_t DstScalarPerVector,
          uint32_t LowEightRowlaneIdx,
          uint32_t HighEightRowLaneIdx,
          bool IntraRowSwizzlePerm,
          typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                             bool>::type = false>
struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(const Index& src_idx)
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! Desc need to known at compile-time");

        static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
                      "wrong! Not divisible");
        ignore = src_idx;
    }

    template <typename SrcSliceOriginIdx,
              typename DstSliceOriginIdx,
              typename SrcBuffer,
              typename DstBuffer>
    __device__ void Run(const SrcDesc&,
                        const SrcSliceOriginIdx&,
                        const SrcBuffer& src_buf,
                        const DstDesc&,
                        const DstSliceOriginIdx&,
                        DstBuffer& dst_buf) const
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! Desc need to known at compile-time");

        static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value &&
                          is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
                      "wrong! SliceOrigin need to known at compile-time");

        static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
                      "wrong! Buffer need to be StaticBuffer");

        // SrcDesc and src_slice_origin_idx are known at compile-time
        constexpr auto src_desc             = remove_cvref_t<SrcDesc>{};
        constexpr auto dst_desc             = remove_cvref_t<DstDesc>{};
        constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
        constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});

        // scalar per access on each dim
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_scalar_step_in_vector =
            generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});

        using SpaceFillingCurve =
            SpaceFillingCurve<SliceLengths, DimAccessOrder, remove_cv_t<decltype(dst_scalar_per_access)>>;

        static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
                      "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");

        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();

        static_for<0, num_access, 1>{}([&](auto idx_1d) {
            constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);

            // copy data from src_buf into dst_vector
            static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
                // src_desc error, non constexpr, caused by merge transform
                constexpr index_t src_offset = src_desc.CalculateOffset(
                    src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);

                constexpr index_t dst_offset = dst_desc.CalculateOffset(
                    dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);

                SrcData v_this_row, v_theother_row;
                // int type temp value due to intrinsic requirement
                int temp = 0;

                // apply element-wise operation
                element_op_(v_this_row, src_buf[Number<src_offset>{}]);

                // apply intra-row permute.
                if constexpr(IntraRowSwizzlePerm)
                {
                    temp = __builtin_amdgcn_permlane16(
                        temp, type_convert_sp<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
                    v_this_row = type_convert_sp<SrcData>(temp);
                }

                // apply inter-row permute.
                temp = __builtin_amdgcn_permlanex16(
                    temp, type_convert_sp<int>(v_this_row), LowEightRowlaneIdx, HighEightRowLaneIdx, 1, 0);
                v_theother_row = type_convert_sp<SrcData>(temp);

                if(get_thread_local_1d_id() % 32 < 16)
                {
                    // apply type convert
                    dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_this_row);
                    dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
                        type_convert_sp<DstData>(v_theother_row);
                }
                else
                {
                    // apply type convert
                    dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
                        type_convert_sp<DstData>(v_this_row);
                    dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_theother_row);
                }
            });
        });
    }
    ElementwiseOperation element_op_{};
};

} // namespace ck
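The inter-row copy above is driven by the pair of 32-bit permlane selector masks. As a reading aid, here is a hypothetical host-side model of how a 16-lane half-row decodes that mask pair into per-lane source indices; it reflects my reading of the nibble-per-lane selector encoding and is not the GPU intrinsic itself:

    #include <array>
    #include <cstdint>

    // Lane i of the 16-lane group reads nibble i of the 64-bit selector (high:low).
    std::array<int, 16> model_lane_selection(uint32_t low_sel, uint32_t high_sel)
    {
        std::array<int, 16> src_lane{};
        const uint64_t sel = (static_cast<uint64_t>(high_sel) << 32) | low_sel;
        for(int lane = 0; lane < 16; ++lane)
            src_lane[lane] = static_cast<int>((sel >> (4 * lane)) & 0xF);
        return src_lane;
    }
    // With the swizzle pair (0xb3a29180, 0xf7e6d5c4) used above this yields
    // 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, i.e. the two
    // 8-lane halves of a row are interleaved.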
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp
0 → 100644

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"

namespace ck {
namespace detail {
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor
template <index_t SrcVectorDim,
          index_t SrcScalarPerVector,
          index_t DstVectorDim,
          index_t DstScalarPerVector>
struct lambda_scalar_per_access_for_src_and_dst_idle
{
    __host__ __device__ constexpr auto operator()(index_t i) const
    {
        if(i == SrcVectorDim && i == DstVectorDim)
        {
            return math::lcm(SrcScalarPerVector, DstScalarPerVector);
        }
        else if(i == SrcVectorDim)
        {
            return SrcScalarPerVector;
        }
        else if(i == DstVectorDim)
        {
            return DstScalarPerVector;
        }
        else
        {
            return 1;
        }
    }
};
} // namespace detail
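A standalone mirror of the selection rule above with illustrative values (this snippet is not part of the header; it only restates that the dimension shared by both vector dims gets the least common multiple of the two scalars-per-vector):

    #include <numeric> // std::lcm (constexpr since C++17)

    constexpr int scalar_per_access(int i, int src_dim, int src_spv, int dst_dim, int dst_spv)
    {
        if(i == src_dim && i == dst_dim) return std::lcm(src_spv, dst_spv);
        if(i == src_dim) return src_spv;
        if(i == dst_dim) return dst_spv;
        return 1;
    }
    static_assert(scalar_per_access(1, 1, 8, 2, 4) == 8, "source vector dim");
    static_assert(scalar_per_access(2, 1, 8, 2, 4) == 4, "destination vector dim");
    static_assert(scalar_per_access(0, 1, 8, 2, 4) == 1, "all other dims");
    static_assert(scalar_per_access(1, 1, 8, 1, 4) == 8, "shared dim: lcm(8, 4)");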
// Assume:
//   1. src_desc and dst_desc are not known at compile-time
//   2. SrcBuffer and DstBuffer are DynamicBuffer
//   3. src_slice_origin and dst_slice_origin are not known at compile-time,
//   4. Use thread buffer
//   5. Dequantization happened between read and write.
template <typename SliceLengths,
          typename ScaleSliceLengths,
          typename SrcElementwiseOperation,
          typename ScaleElementwiseOperation,
          typename DstElementwiseOperation,
          InMemoryDataOperationEnum DstInMemOp,
          typename SrcData,
          typename ScaleData,
          typename DstData,
          typename SrcDesc,
          typename ScaleDesc,
          typename DstDesc,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          index_t SrcVectorDim,
          index_t DstVectorDim,
          index_t SrcScalarPerVector,
          index_t ScaleScalarPerVector,
          index_t DstScalarPerVector,
          index_t SrcScalarStrideInVector,
          index_t ScaleScalarStrideInVector,
          index_t DstScalarStrideInVector,
          bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
                                           // RunRead(), will be fused with MoveSrcSliceWindow to
                                           // save addr computation
          bool DstResetCoordinateAfterRun, // control whether to move back dst coordinate after each
                                           // RunWrite(), will be fused with MoveDstSliceWindow to
                                           // save addr computation
          index_t NumThreadScratch = 1>
struct ThreadwiseTensorSliceTransfer_v3r1_dequant
{
    static constexpr index_t nDim = SliceLengths::Size();
    using Index                   = MultiIndex<nDim>;

    using SrcCoord   = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
    using ScaleCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
    using DstCoord   = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

    static constexpr auto I0 = Number<0>{};

    __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1_dequant(
        const SrcDesc& src_desc,
        const Index& src_slice_origin,
        const SrcElementwiseOperation& src_element_op,
        const ScaleDesc& scale_desc,
        const Index& scale_slice_origin,
        const ScaleElementwiseOperation& scale_element_op,
        const DstDesc& dst_desc,
        const Index& dst_slice_origin,
        const DstElementwiseOperation& dst_element_op)
        : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
          scale_coord_(make_tensor_coordinate(scale_desc, scale_slice_origin)),
          dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
          src_element_op_(src_element_op),
          scale_element_op_(scale_element_op),
          dst_element_op_(dst_element_op)
    {
    }
    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
    {
        src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
    }

    __device__ void SetScaleSliceOrigin(const ScaleDesc& scale_desc, const Index& scale_slice_origin_idx)
    {
        scale_coord_ = make_tensor_coordinate(scale_desc, scale_slice_origin_idx);
    }

    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
    {
        dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
    }
    template <typename SrcBuffer, index_t ThreadScratchId = 0>
    __device__ void RunRead(const SrcDesc& src_desc,
                            const SrcBuffer& src_buf,
                            Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
                          SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
                      "wrong!");

        static_assert(is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
                      "wrong! SrcBuffer and SrcData data type are inconsistent");

        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;

        constexpr auto src_dim_access_order = SrcDimAccessOrder{};

        constexpr auto ordered_src_access_lengths =
            container_reorder_given_new2old(src_access_lengths, src_dim_access_order);

        // make forward steps
        const auto src_forward_steps = generate_tuple(
            [&](auto i) {
                Index forward_step_idx;
                static_for<0, nDim, 1>{}([&](auto j) {
                    forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
                });
                return make_tensor_coordinate_step(src_desc, forward_step_idx);
            },
            Number<nDim>{});

        // make backward steps
        const auto src_backward_steps = generate_tuple(
            [&](auto i) {
                Index backward_step_idx;
                static_for<0, nDim, 1>{}([&](auto j) {
                    backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
                });
                return make_tensor_coordinate_step(src_desc, backward_step_idx);
            },
            Number<nDim>{});

        // loop over tensor and copy
        static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
            // judge move forward or move backward
            constexpr auto forward_sweep = [&]() {
                StaticallyIndexedArray<bool, nDim> forward_sweep_;

                forward_sweep_(I0) = true;

                static_for<1, nDim, 1>{}([&](auto i) {
                    index_t tmp = ordered_src_access_idx[I0];

                    static_for<1, i, 1>{}([&](auto j) {
                        tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
                    });

                    forward_sweep_(i) = tmp % 2 == 0;
                });

                return forward_sweep_;
            }();

            // calculate src data index
            constexpr auto src_data_idx = [&]() {
                Index ordered_idx;

                static_for<0, nDim, 1>{}([&](auto i) {
                    ordered_idx(i) = forward_sweep[i]
                                         ? ordered_src_access_idx[i]
                                         : ordered_src_access_lengths[i] - 1 - ordered_src_access_idx[i];
                });

                return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
                       src_scalar_per_access;
            }();

            constexpr auto src_data_idx_seq = generate_sequence_v2(
                [&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});

            const bool is_src_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);

            using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
            using src_vector_t    = typename src_vector_type::type;

            // copy data from src_buf into src_vector_container
            auto src_vector_container =
                src_vector_type{src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};

            // copy data from src_vector_container into src_thread_scratch_
            src_thread_scratch_tuple_(thread_scratch_id)
                .template SetAsType<src_vector_t>(
                    src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);

            constexpr auto move_on_dim = [&]() constexpr
            {
                StaticallyIndexedArray<bool, nDim> move_on_dim_;

                static_for<0, nDim, 1>{}([&](auto i) {
                    move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
                        move_on_dim_(i) &= ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
                    });
                });

                return move_on_dim_;
            }
            ();

            // move src coord
            static_for<0, nDim, 1>{}([&](auto i) {
                if constexpr(move_on_dim[i])
                {
                    if constexpr(forward_sweep[i])
                    {
                        move_tensor_coordinate(
                            src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
                    }
                    else
                    {
                        move_tensor_coordinate(
                            src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
                    }
                }
            });
        });

        // move src coordinate back to slice origin (or not)
        if constexpr(SrcResetCoordinateAfterRun)
        {
            const auto src_reset_step =
                make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());

            move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
        }
    }
    template <typename ScaleBuffer>
    __device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf)
    {
        static_assert(ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
                          ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
                      "wrong!");

        static_assert(
            is_same<remove_cvref_t<typename ScaleBuffer::type>, remove_cvref_t<ScaleData>>::value,
            "wrong! ScaleBuffer and ScaleData data type are inconsistent");

        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto scale_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, ScaleScalarPerVector>{}, Number<nDim>{});

        constexpr auto scale_access_lengths = SliceLengths{} / scale_scalar_per_access;

        constexpr auto scale_dim_access_order = SrcDimAccessOrder{};

        constexpr auto ordered_scale_access_lengths =
            container_reorder_given_new2old(scale_access_lengths, scale_dim_access_order);

        // make forward steps
        const auto scale_forward_steps = generate_tuple(
            [&](auto i) {
                Index forward_step_idx;
                static_for<0, nDim, 1>{}([&](auto j) {
                    forward_step_idx(j) = (i.value == j.value) ? scale_scalar_per_access[i] : 0;
                });
                return make_tensor_coordinate_step(scale_desc, forward_step_idx);
            },
            Number<nDim>{});

        // make backward steps
        const auto scale_backward_steps = generate_tuple(
            [&](auto i) {
                Index backward_step_idx;
                static_for<0, nDim, 1>{}([&](auto j) {
                    backward_step_idx(j) = (i.value == j.value) ? -scale_scalar_per_access[i] : 0;
                });
                return make_tensor_coordinate_step(scale_desc, backward_step_idx);
            },
            Number<nDim>{});

        // loop over tensor and copy
        static_ford<decltype(ordered_scale_access_lengths)>{}([&](auto ordered_scale_access_idx) {
            // judge move forward or move backward
            constexpr auto forward_sweep = [&]() {
                StaticallyIndexedArray<bool, nDim> forward_sweep_;

                forward_sweep_(I0) = true;

                static_for<1, nDim, 1>{}([&](auto i) {
                    index_t tmp = ordered_scale_access_idx[I0];

                    static_for<1, i, 1>{}([&](auto j) {
                        tmp = tmp * ordered_scale_access_lengths[j] + ordered_scale_access_idx[j];
                    });

                    forward_sweep_(i) = tmp % 2 == 0;
                });

                return forward_sweep_;
            }();

            // calculate scale data index
            constexpr auto scale_data_idx = [&]() {
                Index ordered_idx;

                static_for<0, nDim, 1>{}([&](auto i) {
                    ordered_idx(i) = forward_sweep[i] ? ordered_scale_access_idx[i]
                                                      : ordered_scale_access_lengths[i] - 1 -
                                                            ordered_scale_access_idx[i];
                });

                return container_reorder_given_old2new(ordered_idx, scale_dim_access_order) *
                       scale_scalar_per_access;
            }();

            constexpr auto scale_data_idx_seq = generate_sequence_v2(
                [&](auto i) { return Number<scale_data_idx[i]>{}; }, Number<scale_data_idx.Size()>{});

            const bool is_scale_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(scale_desc, scale_coord_);

            using scale_vector_type = vector_type_maker_t<ScaleData, ScaleScalarPerVector>;
            using scale_vector_t    = typename scale_vector_type::type;

            // copy data from scale_buf into scale_vector_container
            auto scale_vector_container = scale_vector_type{
                scale_buf.template Get<scale_vector_t>(scale_coord_.GetOffset(), is_scale_valid)};

            // copy data from scale_vector_container into scale_thread_scratch_
            scale_thread_scratch_.template SetAsType<scale_vector_t>(
                scale_data_idx_seq, scale_vector_container.template AsType<scale_vector_t>()[I0]);

            constexpr auto move_on_dim = [&]() constexpr
            {
                StaticallyIndexedArray<bool, nDim> move_on_dim_;

                static_for<0, nDim, 1>{}([&](auto i) {
                    move_on_dim_(i) = ordered_scale_access_idx[i] < ordered_scale_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
                        move_on_dim_(i) &=
                            ordered_scale_access_idx[j] == ordered_scale_access_lengths[j] - 1;
                    });
                });

                return move_on_dim_;
            }
            ();

            // move scale coord
            static_for<0, nDim, 1>{}([&](auto i) {
                if constexpr(move_on_dim[i])
                {
                    if constexpr(forward_sweep[i])
                    {
                        move_tensor_coordinate(
                            scale_desc, scale_coord_, scale_forward_steps[scale_dim_access_order[i]]);
                    }
                    else
                    {
                        move_tensor_coordinate(
                            scale_desc, scale_coord_, scale_backward_steps[scale_dim_access_order[i]]);
                    }
                }
            });
        });

        // don't need to move scale coordinate back to slice origin
        /*
        if constexpr(SrcResetCoordinateAfterRun)
        {
            const auto scale_reset_step =
                make_tensor_coordinate_step(scale_desc, GetScaleCoordinateResetStep());

            move_tensor_coordinate(scale_desc, scale_coord_, scale_reset_step);
        }
        */
    }
    template <index_t ThreadScratchId>
    __device__ void
    TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id)
    {
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
        static_ford<SliceLengths>{}([&](auto idx) {
            // convert from SrcData to DstData here
            dst_thread_scratch_(idx) =
                type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
        });
#else
        // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
        // TODO make this logic more generic for more sub-dword datatype
        if constexpr(SrcVectorDim != DstVectorDim &&
                     ((is_same<half_t, remove_cvref_t<SrcData>>::value &&
                       is_same<half_t, remove_cvref_t<DstData>>::value &&
                       SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
                      (is_same<int8_t, remove_cvref_t<SrcData>>::value &&
                       is_same<int8_t, remove_cvref_t<DstData>>::value &&
                       SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
        {
            // each transpose does
            // DstScalarPerVector # of src vectors in src_thread_scratch_
            // SrcScalarPerVector # of dst vectors in dst_thread_scratch_
            constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
            constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};

            // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
            // TODO: make this logic generic for all scenario
            static_assert(SrcVectorDim != DstVectorDim, "wrong");

            constexpr auto src_scalar_step_in_vector =
                generate_sequence(detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});

            constexpr auto dst_scalar_step_in_vector =
                generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});

            constexpr auto scalar_per_access = generate_sequence(
                detail::lambda_scalar_per_access_for_src_and_dst_idle<SrcVectorDim,
                                                                      SrcScalarPerVector,
                                                                      DstVectorDim,
                                                                      DstScalarPerVector>{},
                Number<nDim>{});

            constexpr auto access_lengths = SliceLengths{} / scalar_per_access;

            static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
                constexpr auto data_idx = access_idx * scalar_per_access;

                constexpr auto data_idx_seq =
                    generate_sequence_v2([&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});

                using src_vector_t = vector_type_maker_t<SrcData, SrcScalarPerVector>;
                using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;

                // get DstScalarPerVector # of read-only references to src vectors from
                // src_thread_scratch_
                const auto src_vector_refs = generate_tie(
                    [&](auto i) -> const src_vector_t& {
                        // i increment corresponds to movement in DstVectorDim
                        return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference(
                            data_idx_seq + i * dst_scalar_step_in_vector);
                    },
                    Number<num_src_vector>{});

                // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_
                auto dst_vector_refs = generate_tie(
                    [&](auto i) -> dst_vector_t& {
                        // i increment corresponds to movement in SrcVectorDim
                        return dst_thread_scratch_.GetVectorTypeReference(
                            data_idx_seq + i * src_scalar_step_in_vector);
                    },
                    Number<num_dst_vector>{});

                // do data transpose
                transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
                    src_vector_refs, dst_vector_refs);
            });
        }

        // Do fast numeric convert
        constexpr auto scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access_for_src_and_dst_idle<SrcVectorDim,
                                                                  SrcScalarPerVector,
                                                                  DstVectorDim,
                                                                  DstScalarPerVector>{},
            Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / scalar_per_access;

        using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
        using src_vector_t    = typename src_vector_type::type;

        using src_converted_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
        using src_converted_vector_t    = typename src_converted_vector_type::type;

        // Vector-wise type convert
        static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
            auto src_vector_container = src_vector_type{
                src_thread_scratch_tuple_[thread_scratch_id].template GetAsType<src_vector_t>(access_idx)};

            auto src_converted_vector_container =
                src_converted_vector_type{fast_numeric_converter(src_vector_container)};

            src_converted_thread_scratch_.template SetAsType<src_converted_vector_t>(
                access_idx,
                src_converted_vector_container.template AsType<src_converted_vector_t>()[I0]);
        });

        // Element-scale operation, expect packed multiplication
        static_ford<SliceLengths>{}([&](auto idx) {
            DstData dst_v;

            constexpr auto scale_idx = Sequence<I0, idx.At(1), I0>{};
            // printf("Tid: %03d, scale: %04x\n", get_thread_local_1d_id(),
            // *(reinterpret_cast<const uint16_t*>(&scale_thread_scratch_[scale_idx])));
            src_element_op_(dst_v,
                            src_converted_thread_scratch_[idx] * scale_thread_scratch_[scale_idx]);
            dst_thread_scratch_(idx) = dst_v;
        });
#endif
    }
    template <typename DstBuffer, index_t ThreadScratchId = 0>
    __device__ void RunWrite(const DstDesc& dst_desc,
                             DstBuffer& dst_buf,
                             Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        // if there is transpose, it's done here
        // TODO move this elsewhere
        TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id);

        static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
                          DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
                      "wrong!");

        static_assert(is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
                      "wrong! SrcBuffer or DstBuffer data type is wrong");

        // src scalar per access on each dim
        // TODO: don't use this
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;

        constexpr auto dst_dim_access_order = DstDimAccessOrder{};

        constexpr auto ordered_dst_access_lengths =
            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);

        // make forward steps
        const auto dst_forward_steps = generate_tuple(
            [&](auto i) {
                Index forward_step_idx;
                static_for<0, nDim, 1>{}([&](auto j) {
                    forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
                });
                return make_tensor_coordinate_step(dst_desc, forward_step_idx);
            },
            Number<nDim>{});

        // make backward steps
        const auto dst_backward_steps = generate_tuple(
            [&](auto i) {
                Index backward_step_idx;
                static_for<0, nDim, 1>{}([&](auto j) {
                    backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
                });
                return make_tensor_coordinate_step(dst_desc, backward_step_idx);
            },
            Number<nDim>{});

        // loop over tensor and copy
        static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
            // judge move forward or move backward
            constexpr auto forward_sweep = [&]() {
                StaticallyIndexedArray<bool, nDim> forward_sweep_;

                forward_sweep_(I0) = true;

                static_for<1, nDim, 1>{}([&](auto i) {
                    index_t tmp = ordered_dst_access_idx[I0];

                    static_for<1, i, 1>{}([&](auto j) {
                        tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
                    });

                    forward_sweep_(i) = tmp % 2 == 0;
                });

                return forward_sweep_;
            }();

            // calculate dst data index
            constexpr auto dst_data_idx = [&]() {
                Index ordered_idx;

                static_for<0, nDim, 1>{}([&](auto i) {
                    ordered_idx(i) = forward_sweep[i]
                                         ? ordered_dst_access_idx[i]
                                         : ordered_dst_access_lengths[i] - 1 - ordered_dst_access_idx[i];
                });

                return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
                       dst_scalar_per_access;
            }();

            constexpr auto dst_data_idx_seq = generate_sequence_v2(
                [&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});

            const bool is_dst_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

            using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
            using dst_vector_t    = typename dst_vector_type::type;

            // copy data from dst_thread_scratch_ into dst_vector_container
            auto dst_vector_container =
                dst_vector_type{dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq)};

            static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
                DstData dst_v;

                // apply DstElementwiseOperation
                dst_element_op_(dst_v, dst_vector_container.template AsType<DstData>()[i]);

                dst_vector_container.template AsType<DstData>()(i) = dst_v;
            });

            // copy data from dst_vector_container to dst_buf
            dst_buf.template Set<dst_vector_t>(
                dst_coord_.GetOffset(),
                is_dst_valid,
                dst_vector_container.template AsType<dst_vector_t>()[I0]);

            constexpr auto move_on_dim = [&]() constexpr
            {
                StaticallyIndexedArray<bool, nDim> move_on_dim_;

                static_for<0, nDim, 1>{}([&](auto i) {
                    move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
                        move_on_dim_(i) &= ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
                    });
                });

                return move_on_dim_;
            }
            ();

            // move dst coord
            static_for<0, nDim, 1>{}([&](auto i) {
                if constexpr(move_on_dim[i])
                {
                    if constexpr(forward_sweep[i])
                    {
                        move_tensor_coordinate(
                            dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
                    }
                    else
                    {
                        move_tensor_coordinate(
                            dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
                    }
                }
            });
        });

        // move dst coordinate back to slice origin (or not)
        if constexpr(DstResetCoordinateAfterRun)
        {
            const auto dst_reset_step =
                make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());

            move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
        }
    }
    __device__ static constexpr auto GetSrcCoordinateResetStep()
    {
        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;

        constexpr auto src_dim_access_order = SrcDimAccessOrder{};

        constexpr auto ordered_src_access_lengths =
            container_reorder_given_new2old(src_access_lengths, src_dim_access_order);

        // judge move forward or move backward during the last iteration
        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep_;

            forward_sweep_(I0) = true;

            static_for<1, nDim, 1>{}([&](auto i) {
                index_t tmp = ordered_src_access_lengths[I0] - 1;

                static_for<1, i, 1>{}([&](auto j) {
                    tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
                });

                forward_sweep_(i) = tmp % 2 == 0;
            });

            return forward_sweep_;
        }();

        // calculate src data index after last iteration in RunRead(), if it has not being reset by
        // RunRead()
        constexpr auto src_data_idx = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto i) {
                ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
            });

            return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
                   src_scalar_per_access;
        }();

        //
        constexpr auto reset_src_data_step = [&]() {
            Index reset_src_data_step_;

            static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });

            return reset_src_data_step_;
        }();

        return reset_src_data_step;
    }
    __device__ static constexpr auto GetDstCoordinateResetStep()
    {
        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;

        constexpr auto dst_dim_access_order = DstDimAccessOrder{};

        constexpr auto ordered_dst_access_lengths =
            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);

        // judge move forward or move backward during the last iteration
        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep_;

            forward_sweep_(I0) = true;

            static_for<1, nDim, 1>{}([&](auto i) {
                index_t tmp = ordered_dst_access_lengths[I0] - 1;

                static_for<1, i, 1>{}([&](auto j) {
                    tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
                });

                forward_sweep_(i) = tmp % 2 == 0;
            });

            return forward_sweep_;
        }();

        // calculate dst data index after last iteration in RunWrite(), if it has not being reset by
        // RunWrite()
        constexpr auto dst_data_idx = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto i) {
                ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
            });

            return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
                   dst_scalar_per_access;
        }();

        //
        constexpr auto reset_dst_data_step = [&]() {
            Index reset_dst_data_step_;

            static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });

            return reset_dst_data_step_;
        }();

        return reset_dst_data_step;
    }
    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx)
    {
        // if src coord was not reset by RunRead(), then need to adjust the step here
        const auto adjusted_step_idx =
            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);

        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& dst_slice_origin_step_idx)
    {
        // if dst coord was not reset by RunWrite(), then need to adjust the step here
        const auto adjusted_step_idx =
            DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
                                       : dst_slice_origin_step_idx + GetDstCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);

        move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
    }
    __device__ static constexpr auto GetSrcThreadScratchDescriptor()
    {
        constexpr auto src_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;

        constexpr auto src_access_lengths_and_vector_length = container_push_back(
            sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});

        // 1st stage of transforms
        constexpr auto desc0 =
            make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);

        // 2nd stage of transforms
        constexpr auto transforms = generate_tuple(
            [&](auto i) {
                if constexpr(i == SrcVectorDim)
                {
                    return make_merge_transform_v3_division_mod(
                        make_tuple(src_access_lengths_and_vector_length[i],
                                   src_access_lengths_and_vector_length[Number<nDim>{}]));
                }
                else
                {
                    return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
                }
            },
            Number<nDim>{});

        constexpr auto low_dim_idss = generate_tuple(
            [&](auto i) {
                if constexpr(i == SrcVectorDim) { return Sequence<i.value, nDim>{}; }
                else { return Sequence<i.value>{}; }
            },
            Number<nDim>{});

        constexpr auto up_dim_idss =
            generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});

        return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
    }

    __device__ static constexpr auto GetScaleThreadScratchDescriptor()
    {
        constexpr auto scale_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<SrcVectorDim, ScaleScalarPerVector>{}, Number<nDim>{});

        constexpr auto scale_access_lengths = SliceLengths{} / scale_scalar_per_access;

        constexpr auto scale_access_lengths_and_vector_length = container_push_back(
            sequence_to_tuple_of_number(scale_access_lengths), Number<ScaleScalarPerVector>{});

        // 1st stage of transforms
        constexpr auto desc0 =
            make_naive_tensor_descriptor_packed(scale_access_lengths_and_vector_length);

        // 2nd stage of transforms
        constexpr auto transforms = generate_tuple(
            [&](auto i) {
                if constexpr(i == SrcVectorDim)
                {
                    return make_merge_transform_v3_division_mod(
                        make_tuple(scale_access_lengths_and_vector_length[i],
                                   scale_access_lengths_and_vector_length[Number<nDim>{}]));
                }
                else
                {
                    return make_pass_through_transform(scale_access_lengths_and_vector_length[i]);
                }
            },
            Number<nDim>{});

        constexpr auto low_dim_idss = generate_tuple(
            [&](auto i) {
                if constexpr(i == SrcVectorDim) { return Sequence<i.value, nDim>{}; }
                else { return Sequence<i.value>{}; }
            },
            Number<nDim>{});

        constexpr auto up_dim_idss =
            generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});

        return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
    }

    __device__ static constexpr auto GetDstThreadScratchDescriptor()
    {
        // 1st stage of transforms
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;

        constexpr auto dst_access_lengths_and_vector_length = container_push_back(
            sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});

        constexpr auto desc0 =
            make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);

        // 2nd stage of transforms
        constexpr auto transforms = generate_tuple(
            [&](auto i) {
                if constexpr(i == DstVectorDim)
                {
                    return make_merge_transform_v3_division_mod(
                        make_tuple(dst_access_lengths_and_vector_length[i],
                                   dst_access_lengths_and_vector_length[Number<nDim>{}]));
                }
                else
                {
                    return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
                }
            },
            Number<nDim>{});

        constexpr auto low_dim_idss = generate_tuple(
            [&](auto i) {
                if constexpr(i == DstVectorDim) { return Sequence<i.value, nDim>{}; }
                else { return Sequence<i.value>{}; }
            },
            Number<nDim>{});

        constexpr auto up_dim_idss =
            generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});

        return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
    }
    private:
    static constexpr auto src_thread_scratch_desc_   = decltype(GetSrcThreadScratchDescriptor()){};
    static constexpr auto scale_thread_scratch_desc_ = decltype(GetScaleThreadScratchDescriptor()){};
    static constexpr auto dst_thread_scratch_desc_   = decltype(GetDstThreadScratchDescriptor()){};

    /*
    template <bool kLastDim>
    struct ScaleThreadScratchDesc{};
    */

    // Registers, contain raw data loaded from global buffer
    using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                             SrcData,
                                                             SrcScalarPerVector,
                                                             decltype(src_thread_scratch_desc_),
                                                             true>;

    // Registers, contain fast converted data
    using SrcThreadConvertedScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                                      DstData,
                                                                      SrcScalarPerVector,
                                                                      decltype(src_thread_scratch_desc_),
                                                                      true>;

    // Registers, contain scale data
    using ScaleThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                               ScaleData,
                                                               ScaleScalarPerVector,
                                                               decltype(scale_thread_scratch_desc_),
                                                               true>;

    // Registers, contain dequantized data
    using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                             DstData,
                                                             DstScalarPerVector,
                                                             decltype(dst_thread_scratch_desc_),
                                                             true>;

    using FastTypeConverter = tensor_operation::element_wise::
        FastNumericArrayConverter<SrcData, DstData, SrcScalarPerVector>;

    StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
    SrcThreadConvertedScratch src_converted_thread_scratch_;
    ScaleThreadScratch scale_thread_scratch_;
    DstThreadScratch dst_thread_scratch_;
    FastTypeConverter fast_numeric_converter;

    SrcCoord src_coord_;
    ScaleCoord scale_coord_;
    DstCoord dst_coord_;
    const SrcElementwiseOperation src_element_op_;
    const ScaleElementwiseOperation scale_element_op_;
    const DstElementwiseOperation dst_element_op_;
};

} // namespace ck
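For readers skimming the new header: between RunRead and RunWrite the data path is convert-then-scale. A minimal standalone sketch of that arithmetic, assuming int8 sources, float scales and one scale per row of the slice; the types and granularity here are illustrative, while the header itself works on vectorized half_t data through FastNumericArrayConverter:

    #include <cstdint>
    #include <vector>

    std::vector<float> dequantize_rowwise(const std::vector<int8_t>& q,
                                          const std::vector<float>& scale,
                                          std::size_t row_len)
    {
        std::vector<float> out(q.size());
        for(std::size_t i = 0; i < q.size(); ++i)
            out[i] = static_cast<float>(q[i]) * scale[i / row_len]; // convert, then per-row scale
        return out;
    }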
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
...
@@ -89,6 +89,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
     static constexpr index_t src_a_data_size = 2;
     static constexpr index_t src_b_data_size = 2;
     static constexpr index_t acc_data_size   = 4;
+    static constexpr index_t acc_pack_number = 1;
     // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
     static constexpr index_t num_thread_per_subgroups = n_per_wmma;
...
@@ -100,7 +101,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
     // * num_acc_vgprs_per_wave alone M direction
     // * num_subgroups alone M direction
     static constexpr index_t num_acc_vgprs_per_wave =
-        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
+        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
     static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

     template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
...
@@ -129,6 +130,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
     static constexpr index_t src_a_data_size = 2;
     static constexpr index_t src_b_data_size = 2;
     static constexpr index_t acc_data_size   = 4;
+    static constexpr index_t acc_pack_number = 1;
     static constexpr index_t num_thread_per_subgroups = n_per_wmma;

     // Wave mode dependent propety
...
@@ -136,7 +138,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
     static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
     static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
     static constexpr index_t num_acc_vgprs_per_wave =
-        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
+        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
     static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

     template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
...
@@ -153,7 +155,6 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
     }
 };

-#ifdef CK_UNPACKED_ACC_DESC_LOGIC
 template <index_t WaveSize>
 struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
                  WaveSize,
...
@@ -166,6 +167,7 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
     static constexpr index_t src_a_data_size = 2;
     static constexpr index_t src_b_data_size = 2;
     static constexpr index_t acc_data_size   = 2;
+    static constexpr index_t acc_pack_number = 2;
     static constexpr index_t num_thread_per_subgroups = n_per_wmma;

     // Wave mode dependent propety
...
@@ -173,28 +175,22 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
     static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
     static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
     static constexpr index_t num_acc_vgprs_per_wave =
-        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
+        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
     static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

-    template <index_t MPerWmma,
-              index_t NPerWmma,
-              index_t Opsel,
-              class FloatA,
-              class FloatB,
-              class FloatC>
+    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
         if constexpr(wave_size == 32)
         {
-            intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
+            intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
         }
         else if constexpr(wave_size == 64)
         {
-            intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
+            intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
         }
     }
 };

 template <index_t WaveSize>
 struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
                  WaveSize,
...
@@ -207,6 +203,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
     static constexpr index_t src_a_data_size = 2;
     static constexpr index_t src_b_data_size = 2;
     static constexpr index_t acc_data_size   = 2;
+    static constexpr index_t acc_pack_number = 2;
     static constexpr index_t num_thread_per_subgroups = n_per_wmma;

     // Wave mode dependent propety
...
@@ -214,7 +211,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
     static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
     static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
     static constexpr index_t num_acc_vgprs_per_wave =
-        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
+        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
     static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

     template <index_t MPerWmma,
...
@@ -227,17 +224,15 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
     {
         if constexpr(wave_size == 32)
         {
-            intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
+            intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
         }
         else if constexpr(wave_size == 64)
         {
-            intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
+            intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
         }
     }
 };
-#endif

 template <index_t WaveSize>
 struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
                  WaveSize,
...
@@ -250,6 +245,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
     static constexpr index_t src_a_data_size = 2;
     static constexpr index_t src_b_data_size = 2;
     static constexpr index_t acc_data_size   = 4;
+    static constexpr index_t acc_pack_number = 1;
     static constexpr index_t num_thread_per_subgroups = n_per_wmma;

     // Wave mode dependent propety
...
@@ -257,7 +253,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
     static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
     static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
     static constexpr index_t num_acc_vgprs_per_wave =
-        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
+        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
     static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

     template <index_t MPerWmma,
...
@@ -346,7 +342,7 @@ struct WmmaSelector
         static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16");

         static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
-                              selected_wmma.acc_data_size ==
+                              selected_wmma.acc_data_size * selected_wmma.acc_pack_number ==
                           selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
                       "WRONG! Invalid Number of Accumulator Register");
     }
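A worked check of the adjusted assertion above for the newly enabled packed-f16 accumulator case on wave32; the constants are the per-instruction values quoted in the hunks, and the check itself is plain arithmetic:

    constexpr int wave_size = 32, m_per_wmma = 16, n_per_wmma = 16;
    constexpr int acc_data_size = 2, acc_pack_number = 2; // f16 accumulator, two values packed per VGPR
    constexpr int num_acc_vgprs_per_wave =
        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4; // = 8
    static_assert(wave_size * num_acc_vgprs_per_wave * acc_data_size * acc_pack_number ==
                      m_per_wmma * n_per_wmma * 4,
                  "8 VGPRs per wave still cover the 16x16 accumulator tile");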
...
@@ -358,7 +354,8 @@ template <typename src_type_a,
           index_t MPerWmma,
           index_t NPerWmma,
           index_t KPack,
-          bool TransposeC = false>
+          bool TransposeC = false,
+          bool AssemblyBackend = false>
 struct WmmaGemm
 {
     static constexpr auto I0 = Number<0>{};
...
@@ -369,14 +366,14 @@ struct WmmaGemm
     static constexpr auto I5 = Number<5>{};

     using CIndex   = MultiIndex<2>;
-    using CIndex4D = MultiIndex<4>;
+    using CIndex3D = MultiIndex<3>;

     __host__ __device__ constexpr WmmaGemm()
     {
         static_assert(NPerWmma == 16 && MPerWmma == 16,
                       "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");
-        static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma");
+        static_assert(KPack % wmma_instr.k_per_wmma == 0, "KPack should be multiple of k_per_wmma");
     }

     // WMMA output supporting C = A * B
...
@@ -421,9 +418,49 @@ struct WmmaGemm
                        Sequence<5>{}));
     }

+    // Transposed WMMA Output C' = B' * A'
+    template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
+    __host__ __device__ static constexpr auto
+    MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
+        const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
+    {
+        const auto MBlockxRepeat =
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
+        const auto NBlockxRepeat =
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
+        const auto MWave =
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
+        const auto NWave =
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
+
+        return transform_tensor_descriptor(
+            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
+            make_tuple(make_pass_through_transform(MBlockxRepeat),
+                       make_pass_through_transform(MWave),
+                       make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{}),
+                       make_pass_through_transform(NBlockxRepeat),
+                       make_pass_through_transform(NWave),
+                       make_unmerge_transform(
+                           make_tuple(Number<wmma_instr.num_subgroups>{},
+                                      Number<wmma_instr.num_acc_vgprs_per_wave>{}))),
+            make_tuple(Sequence<0>{},
+                       Sequence<1>{},
+                       Sequence<2>{},
+                       Sequence<3>{},
+                       Sequence<4>{},
+                       Sequence<5>{}),
+            make_tuple(Sequence<0>{},
+                       Sequence<1>{},
+                       Sequence<2>{},
+                       Sequence<3>{},
+                       Sequence<4>{},
+                       Sequence<5, 6>{}));
+    }
+
     __device__ static constexpr index_t GetRegSizePerWmma()
     {
-        return wmma_instr.num_acc_vgprs_per_wave;
+        return wmma_instr.num_acc_vgprs_per_wave * wmma_instr.acc_pack_number;
     }

     __device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
...
@@ -449,14 +486,16 @@ struct WmmaGemm
         ,
         "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
         "(int8, int32) or (int4, int32)!");
-        if constexpr(!TransposeC)
+        static_for<0, KPack
/
wmma_instr
.
k_per_wmma
,
1
>
{}([
&
](
auto
k
)
{
{
if
constexpr
(
!
TransposeC
)
wmma_instr
.
template
run
<
MPerWmma
,
NPerWmma
>(
p_a_wave
,
p_b_wave
,
p_c_thread
);
{
}
wmma_instr
.
template
run
<
MPerWmma
,
NPerWmma
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
p_c_thread
);
else
}
{
else
wmma_instr
.
template
run
<
MPerWmma
,
NPerWmma
>(
p_b_wave
,
p_a_wave
,
p_c_thread
);
{
}
wmma_instr
.
template
run
<
MPerWmma
,
NPerWmma
>(
p_b_wave
[
k
],
p_a_wave
[
k
],
p_c_thread
);
}
});
}
}
__device__
static
auto
GetLaneId
()
{
return
get_thread_local_1d_id
()
%
wmma_instr
.
wave_size
;
}
__device__
static
auto
GetLaneId
()
{
return
get_thread_local_1d_id
()
%
wmma_instr
.
wave_size
;
}
...
@@ -477,12 +516,12 @@ struct WmmaGemm
...
@@ -477,12 +516,12 @@ struct WmmaGemm
__host__
__device__
static
auto
CalculateAThreadOriginDataIndex
()
__host__
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
{
return
GetSwizzledLaneIdLow
();
return
TransposeC
?
GetLaneIdUnderSubGroup
()
:
GetSwizzledLaneIdLow
();
}
}
__host__
__device__
static
auto
CalculateBThreadOriginDataIndex
()
__host__
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
{
return
GetLaneIdUnderSubGroup
();
return
TransposeC
?
GetSwizzledLaneIdLow
()
:
GetLaneIdUnderSubGroup
();
}
}
__device__
static
CIndex
GetBeginOfThreadBlk
()
__device__
static
CIndex
GetBeginOfThreadBlk
()
...
@@ -493,6 +532,14 @@ struct WmmaGemm
...
@@ -493,6 +532,14 @@ struct WmmaGemm
return
TransposeC
?
CIndex
{
n_offset
,
m_offset
}
:
CIndex
{
m_offset
,
n_offset
};
return
TransposeC
?
CIndex
{
n_offset
,
m_offset
}
:
CIndex
{
m_offset
,
n_offset
};
}
}
__device__
static
CIndex3D
GetBeginOfThreadBlk3D
()
{
index_t
n_offset
=
GetLaneIdUnderSubGroup
();
index_t
m_offset
=
GetSubGroupId
();
return
TransposeC
?
CIndex3D
{
n_offset
,
m_offset
,
I0
}
:
CIndex3D
{
m_offset
,
n_offset
,
I0
};
}
static
constexpr
auto
wmma
=
static
constexpr
auto
wmma
=
WmmaSelector
<
src_type_a
,
src_type_b
,
dst_type
,
MPerWmma
,
NPerWmma
>
{};
WmmaSelector
<
src_type_a
,
src_type_b
,
dst_type
,
MPerWmma
,
NPerWmma
>
{};
static
constexpr
auto
wmma_instr
=
wmma
.
selected_wmma
;
static
constexpr
auto
wmma_instr
=
wmma
.
selected_wmma
;
...
@@ -500,7 +547,10 @@ struct WmmaGemm
...
@@ -500,7 +547,10 @@ struct WmmaGemm
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
()
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
()
{
{
return
make_tuple
(
I1
,
I1
,
Number
<
wmma_instr
.
num_acc_vgprs_per_wave
>
{});
return
make_tuple
(
I1
,
I1
,
Number
<
wmma_instr
.
num_acc_vgprs_per_wave
>
{},
Number
<
wmma_instr
.
acc_pack_number
>
{});
}
}
};
};
...
...
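Annotation (not part of the diff): the new acc_pack_number field threads through the accumulator bookkeeping above; for the wave32, f32-accumulator instructions it is 1, so the register count is unchanged. A minimal host-side sketch of that arithmetic, under those assumed values:

// Illustrative only: recomputes the formula from the diff on the host.
#include <cassert>

int main()
{
    constexpr int wave_size       = 32;
    constexpr int m_per_wmma      = 16;
    constexpr int n_per_wmma      = 16;
    constexpr int acc_data_size   = 4; // bytes per accumulator element
    constexpr int acc_pack_number = 1; // new field introduced by this commit

    constexpr int num_acc_vgprs_per_wave =
        m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;

    // Mirrors the WmmaSelector static_assert: the wave's accumulator VGPR bytes
    // must cover the 16x16 tile of 4-byte accumulators.
    static_assert(wave_size * num_acc_vgprs_per_wave * acc_data_size * acc_pack_number ==
                      m_per_wmma * n_per_wmma * 4,
                  "invalid accumulator register count");

    assert(num_acc_vgprs_per_wave == 8);
    return 0;
}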
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp
0 → 100644
View file @ 56de337f

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"

namespace ck {
namespace tensor_operation {

// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
template <index_t NumDimG, index_t NumDimM, index_t NumDimN, device::TensorSpecialization TensorSpec>
__host__ __device__ static auto
MakeGridDescriptorPair(const std::array<index_t, NumDimG + NumDimM + NumDimN>& gs_ms_ns_lengths_vec,
                       const std::array<index_t, NumDimG + NumDimM + NumDimN>& gs_ms_ns_strides_vec)
{
    // if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
    //      gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN))
    // {
    //     throw std::runtime_error("wrong! dimension must match input lengths");
    // }

    const auto to_tuple = [&](auto& vec, auto start, auto end) {
        return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
    };

    const auto gs_ms_ns_lengths =
        to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
    const auto gs_ms_ns_strides =
        to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});

    // dimension Ids for G0, G1, ...
    constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};

    // dimension Ids for M0, M1, ...
    constexpr auto mDimIds =
        typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};

    // dimension Ids for N0, N1, ...
    constexpr auto nDimIds =
        typename arithmetic_sequence_gen<NumDimG + NumDimM, NumDimG + NumDimM + NumDimN, 1>::type{};

    // lengths for G0, G1, ...
    const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds);

    // lengths for M0, M1, ...
    const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds);

    // lengths for N0, N1, ...
    const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds);

    if constexpr(TensorSpec == device::TensorSpecialization::Packed)
    {
        auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
        auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
        auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});

        const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
            make_tuple(G, M, N),
            make_tuple(gs_ms_ns_strides[Number<NumDimG - 1>{}],
                       gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
                       gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));

        const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor(
            make_tuple(M, N),
            make_tuple(gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
                       gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));

        return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
    }
    else
    {
        // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
        const auto grid_desc_gs_ms_ns =
            make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides);

        // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
        // N2 * ...]
        // Note: This does not require padding as it only provides G offset calculation. Technically
        // descriptor for only G is needed. Here we opt for backward compatibility purpose to return
        // G_M_N
        const auto grid_desc_g_mraw_nraw =
            transform_tensor_descriptor(grid_desc_gs_ms_ns,
                                        make_tuple(make_merge_transform(gLengths),
                                                   make_merge_transform(mLengths),
                                                   make_merge_transform(nLengths)),
                                        make_tuple(gDimIds, mDimIds, nDimIds),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        const auto c_ms_ns_lengths = to_tuple(
            gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
        const auto c_ms_ns_strides = to_tuple(
            gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});

        // transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
        // N2 * ...]
        const auto grid_desc_ms_ns = make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);

        const auto grid_desc_mraw_nraw = transform_tensor_descriptor(
            grid_desc_ms_ns,
            make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
            make_tuple(mDimIds - Number<NumDimG>{}, nDimIds - Number<NumDimG>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
    }
}

template <typename NumDims_G_M_N_K_O, // Sequence<>
          typename PerBlock_M_N_K_O,  // Sequence<>
          device::GemmSpecialization GemmSpec,
          device::TensorSpecialization ASpec,
          device::TensorSpecialization B0Spec,
          device::TensorSpecialization B1Spec,
          device::TensorSpecialization CSpec>
struct TransformBatchedContractionContractionToBatchedGemmGemm_Wmma
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};

    static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0);
    static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1);
    static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2);
    static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3);
    static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4);

    static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0);
    static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1);
    static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2);
    static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3);

    static constexpr auto matrix_padder =
        device::GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
            MPerBlock, NPerBlock, KPerBlock, OPerBlock};

    //
    // A
    //
    __host__ __device__ static auto MakeAGridDescriptorPair(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimK, ASpec>(a_gs_ms_ks_lengths_vec,
                                                                        a_gs_ms_ks_strides_vec);
    }

    // TODO: rename to G_MRaw_KRaw
    __host__ __device__ static auto MakeAGridDescriptor_G_M_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
    {
        return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first;
    }

    __host__ __device__ static auto MakeAGridDescriptor_M_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
    {
        return matrix_padder.PadADescriptor_M_K(
            MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second);
    }

    template <typename AGridDesc_M_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1)
    {
        const auto M   = a_grid_desc_m_k.GetLength(I0);
        const auto K   = a_grid_desc_m_k.GetLength(I1);
        const auto AK0 = K / AK1;

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                       make_pass_through_transform(M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    template <typename AGridDesc_M_K,
              typename WmmaK,
              typename MRepeat,
              typename MWaves,
              typename MPerWmma,
              typename AK1>
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(
        const AGridDesc_M_K& a_grid_desc_m_k,
        const WmmaK&,
        const MRepeat&,
        const MWaves&,
        const MPerWmma&,
        const AK1&)
    {
        const auto M0     = a_grid_desc_m_k.GetLength(I0) / MPerBlock;
        const auto K      = a_grid_desc_m_k.GetLength(I1);
        const auto AKWmma = K / WmmaK{};

        constexpr auto AKRow      = 2;
        constexpr auto AK0PerWmma = WmmaK{} / AKRow / AK1{};

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            make_tuple(make_unmerge_transform(
                           make_tuple(AKWmma, Number<AK0PerWmma>{}, Number<AKRow>{}, AK1{})),
                       make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
    }

    //
    // B (alias of B0)
    //
    __host__ __device__ static auto MakeB0GridDescriptorPair(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimN, NumDimK, B0Spec>(b0_gs_ns_ks_lengths_vec,
                                                                         b0_gs_ns_ks_strides_vec);
    }

    // TODO: rename to G_MRaw_NRaw
    __host__ __device__ static auto MakeB0GridDescriptor_G_N_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
    {
        return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first;
    }

    __host__ __device__ static auto MakeB0GridDescriptor_N_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
    {
        // alias of matrix_padder.PadB0Descriptor_N_K
        return matrix_padder.PadBDescriptor_N_K(
            MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second);
    }

    template <typename BGridDesc_N_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1)
    {
        const auto N   = b_grid_desc_n_k.GetLength(I0);
        const auto K   = b_grid_desc_n_k.GetLength(I1);
        const auto BK0 = K / BK1;

        return transform_tensor_descriptor(
            b_grid_desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                       make_pass_through_transform(N)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    template <typename BGridDesc_L_K,
              typename WmmaK,
              typename LRepeat,
              typename LWaves,
              typename LPerWmma,
              typename BK1>
    __host__ __device__ static constexpr auto
    MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(
        const BGridDesc_L_K& b_grid_desc_l_k,
        const WmmaK&,
        const LRepeat&,
        const LWaves&,
        const LPerWmma&,
        const BK1&)
    {
        const auto L0     = b_grid_desc_l_k.GetLength(I0) / NPerBlock;
        const auto K      = b_grid_desc_l_k.GetLength(I1);
        const auto BKWmma = K / WmmaK{};

        constexpr auto BKRow      = 2;
        constexpr auto BK0PerWmma = WmmaK{} / BKRow / BK1{};

        return transform_tensor_descriptor(
            b_grid_desc_l_k,
            make_tuple(make_unmerge_transform(
                           make_tuple(BKWmma, Number<BK0PerWmma>{}, Number<BKRow>{}, BK1{})),
                       make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
    }

    //
    // B1
    //
    __host__ __device__ static auto MakeB1GridDescriptorPair(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimO, NumDimN, B1Spec>(b1_gs_os_ns_lengths_vec,
                                                                         b1_gs_os_ns_strides_vec);
    }

    // TODO: rename to G_NRaw_KRaw
    __host__ __device__ static auto MakeB1GridDescriptor_G_N_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
    {
        return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first;
    }

    __host__ __device__ static auto MakeB1GridDescriptor_N_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
    {
        // alias of matrix_padder.PadB1Descriptor_O_N
        return matrix_padder.PadB1Descriptor_N_K(
            MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second);
    }

    template <typename B1GridDesc_N_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1)
    {
        const auto N    = b1_grid_desc_n_k.GetLength(I0);
        const auto K    = b1_grid_desc_n_k.GetLength(I1);
        const auto B1K0 = K / B1K1;

        return transform_tensor_descriptor(
            b1_grid_desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                       make_pass_through_transform(N)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    template <typename BGridDesc_N_L,
              typename WmmaL,
              typename NRepeat,
              typename NWaves,
              typename NPerWmma,
              typename BL1>
    __host__ __device__ static constexpr auto
    MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(
        const BGridDesc_N_L& b_grid_desc_n_l,
        const WmmaL&,
        const NRepeat&,
        const NWaves&,
        const NPerWmma&,
        const BL1&)
    {
        const auto N0     = b_grid_desc_n_l.GetLength(I0) / OPerBlock;
        const auto L      = b_grid_desc_n_l.GetLength(I1);
        const auto BLWmma = L / WmmaL{};

        constexpr auto BLRow      = 2;
        constexpr auto BL0PerWmma = WmmaL{} / BLRow / BL1{};

        return transform_tensor_descriptor(
            b_grid_desc_n_l,
            make_tuple(make_unmerge_transform(
                           make_tuple(BLWmma, Number<BL0PerWmma>{}, Number<BLRow>{}, BL1{})),
                       make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
    }

    //
    // C
    //
    __host__ __device__ static auto MakeCGridDescriptorPair(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimO, CSpec>(c_gs_ms_os_lengths_vec,
                                                                        c_gs_ms_os_strides_vec);
    }

    // TODO: rename to G_MRaw_NRaw
    __host__ __device__ static auto MakeCGridDescriptor_G_M_N(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
    {
        return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first;
    }

    __host__ __device__ static auto MakeCGridDescriptor_M_N(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
    {
        return matrix_padder.PadCDescriptor_M_N(
            MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second);
    }
};

} // namespace tensor_operation
} // namespace ck
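Annotation (not part of the diff): this new header is the std::array-based counterpart of the existing contraction-to-GEMM transforms and is intended to be instantiated by the WMMA batched gemm+gemm (attention-style) device operations. As a hedged sketch only, where the alias name, dimension counts, block sizes and specialization choices are illustrative assumptions rather than code from this commit, an instantiation could look roughly like:

// Sketch: collapse a G=(G0,G1), M=(M0,M1), K=(K0) layout of A into grid descriptors.
using Transform = ck::tensor_operation::TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
    ck::Sequence<2, 2, 1, 1, 1>,     // NumDims_G_M_N_K_O (assumed)
    ck::Sequence<128, 128, 32, 64>,  // PerBlock_M_N_K_O (assumed tile sizes)
    ck::tensor_operation::device::GemmSpecialization::MNKOPadding,
    ck::tensor_operation::device::TensorSpecialization::Default,
    ck::tensor_operation::device::TensorSpecialization::Default,
    ck::tensor_operation::device::TensorSpecialization::Default,
    ck::tensor_operation::device::TensorSpecialization::Default>;

// A stored as [G0, G1, M0, M1, K0] with packed row-major strides (illustrative numbers).
std::array<ck::index_t, 5> a_lengths{4, 8, 32, 16, 64};
std::array<ck::index_t, 5> a_strides{8 * 32 * 16 * 64, 32 * 16 * 64, 16 * 64, 64, 1};

// G_M_K descriptor for batch offsets, and an M_K descriptor padded up to the block tiles.
const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K(a_lengths, a_strides);
const auto a_grid_desc_m_k   = Transform::MakeAGridDescriptor_M_K(a_lengths, a_strides);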
include/ck/utility/amd_buffer_addressing.hpp
View file @ 56de337f

@@ -417,7 +417,8 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
             (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
             (is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
             (is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-            (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
+            (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+            (is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
         "wrong! not implemented");
 
     using r_t = typename vector_type<T, N>::type;
...
include/ck/utility/amd_inline_asm.hpp
View file @ 56de337f

@@ -220,8 +220,8 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0
                    "0"(c0),
                    "1"(c1));
 #else
     c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
     c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
 #endif
 }
...
@@ -257,10 +257,10 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
                    "2"(c2),
                    "3"(c3));
 #else
     c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
     c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
     c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
     c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
 #endif
 }
...
@@ -355,17 +355,5 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
                                     c3);
 }
 
-// Ranged input operand
-__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
-{
-#if defined(__gfx11__)
-    asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
-#else
-    ignore = a;
-    ignore = b;
-    ignore = c;
-#endif
-}
-
 } // namespace ck
 #endif
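Annotation (not part of the diff): the #else branches above rely on __builtin_amdgcn_sdot4(a, b, c, clamp), which accumulates the four packed signed-byte products of a and b into c. A host-side reference of that contract, for illustration only:

// Illustration: scalar reference for the dot4 fallback used above.
#include <cstdint>
#include <cassert>

int32_t sdot4_reference(int32_t a_packed, int32_t b_packed, int32_t c)
{
    for(int lane = 0; lane < 4; ++lane)
    {
        const int8_t a_byte = static_cast<int8_t>(a_packed >> (8 * lane));
        const int8_t b_byte = static_cast<int8_t>(b_packed >> (8 * lane));
        c += static_cast<int32_t>(a_byte) * static_cast<int32_t>(b_byte);
    }
    return c;
}

int main()
{
    // a = {1, -2, 3, 4}, b = {5, 6, -7, 8}: dot = 5 - 12 - 21 + 32 = 4
    const int32_t a = static_cast<int32_t>(0x0403FE01);
    const int32_t b = static_cast<int32_t>(0x08F90605);
    assert(sdot4_reference(a, b, 10) == 14);
    return 0;
}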
include/ck/utility/data_type.hpp
View file @ 56de337f

@@ -133,6 +133,13 @@ struct scalar_type<int8_t>
     static constexpr index_t vector_size = 1;
 };
 
+template <>
+struct scalar_type<uint8_t>
+{
+    using type                           = uint8_t;
+    static constexpr index_t vector_size = 1;
+};
+
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
 template <>
 struct scalar_type<int4_t>
...
@@ -1037,6 +1044,14 @@ using bf8x8_t = typename vector_type<bf8_t, 8>::type;
 using bf8x16_t = typename vector_type<bf8_t, 16>::type;
 using bf8x32_t = typename vector_type<bf8_t, 32>::type;
 using bf8x64_t = typename vector_type<bf8_t, 64>::type;
 
+// u8
+using uint8x2_t  = typename vector_type<uint8_t, 2>::type;
+using uint8x4_t  = typename vector_type<uint8_t, 4>::type;
+using uint8x8_t  = typename vector_type<uint8_t, 8>::type;
+using uint8x16_t = typename vector_type<uint8_t, 16>::type;
+using uint8x32_t = typename vector_type<uint8_t, 32>::type;
+using uint8x64_t = typename vector_type<uint8_t, 64>::type;
+
 // i8
 template <typename T>
 struct NumericLimits
...
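Annotation (not part of the diff): the new u8 aliases mirror the existing i8 ones. A small sketch, assuming the usual packed vector backing of vector_type, of what they provide:

// Illustration only: compile-time checks on the new uint8 support.
#include "ck/utility/data_type.hpp"

static_assert(sizeof(ck::uint8x4_t) == 4, "four packed u8 lanes expected");
static_assert(sizeof(ck::uint8x16_t) == 16, "sixteen packed u8 lanes expected");
static_assert(ck::scalar_type<uint8_t>::vector_size == 1, "scalar u8 has a single lane");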
include/ck/utility/type_convert.hpp
View file @ 56de337f

@@ -99,6 +99,63 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
     return type_convert<bhalf_t>(x_fp32);
 }
 
+// Convert X to Y
+template <typename Y, typename X>
+__host__ __device__ constexpr Y type_convert_sp(X x)
+{
+    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
+
+    return static_cast<Y>(x);
+}
+
+template <>
+inline __host__ __device__ constexpr int type_convert_sp<int, float>(float x)
+{
+    union
+    {
+        float fp32;
+        int int32;
+    } u = {x};
+    return u.int32;
+}
+
+template <>
+inline __host__ __device__ constexpr float type_convert_sp<float, int>(int x)
+{
+    union
+    {
+        int int32;
+        float fp32;
+    } u = {x};
+    return u.fp32;
+}
+
+template <>
+inline __host__ __device__ constexpr int type_convert_sp<int, half_t>(half_t x)
+{
+    union
+    {
+        half_t fp16;
+        int int32;
+    } u = {x};
+    return u.int32;
+}
+
+template <>
+inline __host__ __device__ constexpr half_t type_convert_sp<half_t, int>(int x)
+{
+    union
+    {
+        int int32;
+        half_t fp16;
+    } u = {x};
+    return u.fp16;
+}
+
 // Declare a template function for fp8 conversion using SR
 template <typename Y, typename X>
 __host__ __device__ constexpr Y f8_convert_sr(X x);
...
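Annotation (not part of the diff): unlike the generic static_cast fallback, the specializations above reinterpret the bit pattern rather than convert the value. A small host-side sketch of that difference, for illustration only:

// Illustration: bit reinterpretation vs. value conversion for float -> int.
#include <cstring>
#include <cassert>

int main()
{
    const float x = 1.0f;

    int bits = 0;
    std::memcpy(&bits, &x, sizeof(bits)); // what type_convert_sp<int, float>(x) returns
    assert(bits == 0x3f800000);

    const int value = static_cast<int>(x); // what the unspecialized type_convert_sp returns
    assert(value == 1);
    return 0;
}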
include/ck/wrapper/layout.hpp
View file @ 56de337f

@@ -5,8 +5,11 @@
 #include "ck/wrapper/utils/layout_utils.hpp"
 
+// Disable from doxygen docs generation
+/// @cond INTERNAL
 namespace ck {
 namespace wrapper {
+/// @endcond
 
 /**
  * \brief Layout wrapper that performs the tensor descriptor logic.
...
@@ -19,6 +22,8 @@ namespace wrapper {
 template <typename Shape, typename UnrolledDescriptorType>
 struct Layout
 {
+    // Disable from doxygen docs generation
+    /// @cond INTERNAL
     private:
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
...
@@ -246,6 +251,7 @@ struct Layout
     using Descriptor1dType =
         remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnrolledDescriptorType{}))>;
     using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
+    /// @endcond
 
     public:
     using LayoutShape = Shape;
...
@@ -457,6 +463,8 @@ struct Layout
         return unrolled_descriptor_;
     }
 
+    // Disable from doxygen docs generation
+    /// @cond INTERNAL
     private:
     // All dimensions are unrolled
     UnrolledDescriptorType unrolled_descriptor_;
...
@@ -469,6 +477,7 @@ struct Layout
     // Descriptor1dType lengths: (8)
     // MergedNestsDescriptorType lengths: (4, 2)
     const Shape shape_;
+    /// @endcond
 };
 
 } // namespace wrapper
...
include/ck/wrapper/operations/copy.hpp
View file @ 56de337f

@@ -12,8 +12,11 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/tensor_description/tensor_space_filling_curve.hpp"
 
+// Disable from doxygen docs generation
+/// @cond INTERNAL
 namespace ck {
 namespace wrapper {
+/// @endcond
 
 /**
  * \brief Perform optimized copy between two tensors partitions (threadwise copy).
...
include/ck/wrapper/operations/gemm.hpp
View file @ 56de337f

@@ -9,9 +9,14 @@
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 
+// Disable from doxygen docs generation
+/// @cond INTERNAL
 namespace ck {
 namespace wrapper {
+/// @endcond
 
+// Disable from doxygen docs generation
+/// @cond INTERNAL
 namespace {
 namespace detail {
 /**
...
@@ -45,6 +50,7 @@ __device__ constexpr auto GetBlockDescriptor()
 } // namespace detail
 } // namespace
+/// @endcond
 
 /**
  * \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be
...
include/ck/wrapper/tensor.hpp
View file @ 56de337f

@@ -7,9 +7,14 @@
 #include "utils/tensor_partition.hpp"
 #include "utils/layout_utils.hpp"
 
+// Disable from doxygen docs generation
+/// @cond INTERNAL
 namespace ck {
 namespace wrapper {
+/// @endcond
 
+// Disable from doxygen docs generation
+/// @cond INTERNAL
 namespace {
 namespace detail {
 /**
...
@@ -189,6 +194,7 @@ __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>&
 }
 } // namespace detail
 } // namespace
+/// @endcond
 
 /**
  * \brief Tensor wrapper that performs static and dynamic buffer logic.
...
@@ -394,6 +400,8 @@ struct Tensor
     }
 
     private:
+    // Disable from doxygen docs generation
+    /// @cond INTERNAL
     using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
                                             ElementType,
                                             ElementSpaceSize,
...
@@ -428,6 +436,7 @@ struct Tensor
     // tensor descriptor (thus all it's transforms) and is linear (1D).
     // We store base_offset_ to avoid multiple recalculations.
     index_t base_offset_;
+    /// @endcond
 };
 
 } // namespace wrapper
...
include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp
View file @ 56de337f

@@ -5,8 +5,11 @@
 #include "ck/ck.hpp"
 
+// Disable from doxygen docs generation
+/// @cond INTERNAL
 namespace ck {
 namespace wrapper {
+/// @endcond
 
 /**
  * \brief Traits for blockwise gemm xdl.
...
include/ck/wrapper/utils/kernel_utils.hpp
View file @ 56de337f

@@ -5,8 +5,11 @@
 #include "ck/ck.hpp"
 
+// Disable from doxygen docs generation
+/// @cond INTERNAL
 namespace ck {
 namespace wrapper {
+/// @endcond
 
 #define __CK_WRAPPER_LAUNCH_BOUNDS__ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...
include/ck/wrapper/utils/layout_utils.hpp
View file @ 56de337f

@@ -17,11 +17,14 @@
 #include "ck/tensor_description/multi_index_transform_helper.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 
+// Disable from doxygen docs generation
+/// @cond INTERNAL
 namespace ck {
 namespace wrapper {
+/// @endcond
 
 // Disable from doxygen docs generation
-/// @cond
+/// @cond INTERNAL
 // forward declaration
 template <typename Shape, typename UnrolledDescriptorType>
 struct Layout;
...
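Annotation (not part of the diff): the `/// @cond INTERNAL ... /// @endcond` pairs added throughout the wrapper headers hide the enclosed declarations from the generated documentation; Doxygen only emits such a block if the named section is whitelisted in the Doxyfile (for example `ENABLED_SECTIONS = INTERNAL`), so by default these namespace and private-member details stay out of the public docs.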