gaoqiong / composable_kernel_ROCM

Commit e2878e25, authored May 17, 2023 by Alan Turner

    Merge remote-tracking branch 'origin/develop' into migx-jit-lib

Parents: 1ec96717, 642d5e91
Changes: 105 files in this commit.
Showing 20 changed files on this page, with 844 additions and 46 deletions.
Files shown on this page:

include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp  (+2, -1)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp  (+1, -1)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp  (+2, -1)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp  (+2, -1)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp  (+2, -1)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp  (+3, -2)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp  (+2, -1)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp  (+3, -2)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp  (+2, -1)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp  (+80, -10)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp  (+2, -1)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp  (+2, -1)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp  (+2, -1)
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp  (+0, -0, moved)
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp  (+2, -2, moved)
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp  (+252, -0, new)
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp  (+418, -0, new)
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp  (+0, -0, moved)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp  (+8, -19)
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp  (+59, -1)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp

@@ -66,7 +66,8 @@ __global__ void
             const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
             const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
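Note: this guard, repeated across nearly all files in the commit, follows a common HIP idiom. The kernel body is compiled only for device targets the code path supports (gfx908/gfx90a, now also gfx940), while `!defined(__HIP_DEVICE_COMPILE__)` keeps the body visible to the host compilation pass so the kernel symbol and launch machinery still exist. A minimal, self-contained sketch of the same pattern (the axpy kernel is illustrative, not from this diff):

// Illustrative only: same guard shape as the kernels in this commit.
#include <hip/hip_runtime.h>

template <typename T>
__global__ void axpy_kernel(int n, T a, const T* x, T* y)
{
// The host pass (__HIP_DEVICE_COMPILE__ undefined) must still see the body so
// the kernel symbol is emitted; device passes compile it only for targets on
// the supported list.
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__))
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        y[i] = a * x[i] + y[i];
#else
    // Unsupported device target: keep parameters referenced to avoid warnings.
    (void)n; (void)a; (void)x; (void)y;
#endif
}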
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp

@@ -96,7 +96,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
     // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
     // when mfma is fixed, remove this section and update
     // ABDataTypeAdjusted -> ABDataType throughout this file
-#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
+#if CK_WORKAROUND_DENORM_FIX
     using ABDataTypeAdjusted =
         conditional_t<is_same_v<ABDataType, ck::half_t>, ck::bhalf_t, ABDataType>;
 #else
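Note: the denorm workaround swaps the A/B element type at compile time — on targets where the fp16 mfma flushes denormals, fp16 inputs are routed through the bf16 mfma instead, converting fp16->fp32->bf16 on the fly. The type selection itself is just an alias over `std::conditional`; a standalone sketch with placeholder types (the `half_t`/`bhalf_t` stand-ins below are not CK's real types):

// Illustrative sketch of the compile-time type swap, using placeholder types.
#include <type_traits>

struct half_t  {};  // stand-in for ck::half_t (fp16)
struct bhalf_t {};  // stand-in for ck::bhalf_t (bf16)

#define CK_WORKAROUND_DENORM_FIX 1

template <typename ABDataType>
struct GemmTraits
{
#if CK_WORKAROUND_DENORM_FIX
    // fp16 is adjusted to bf16 so the bf16 mfma path is used; every other
    // element type passes through unchanged.
    using ABDataTypeAdjusted =
        std::conditional_t<std::is_same_v<ABDataType, half_t>, bhalf_t, ABDataType>;
#else
    using ABDataTypeAdjusted = ABDataType;
#endif
};

static_assert(std::is_same_v<GemmTraits<half_t>::ABDataTypeAdjusted, bhalf_t>);
static_assert(std::is_same_v<GemmTraits<float>::ABDataTypeAdjusted, float>);

int main() {} // compile-time checks only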
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp

@@ -54,7 +54,8 @@ __global__ void
             const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
             const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp

@@ -44,7 +44,8 @@ __global__ void
             c_grid_desc_mblock_mperblock_nblock_nperblock,
             const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp

@@ -57,7 +57,8 @@ __global__ void
             const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock,
             const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     // TODO ANT: separate into MMA + Epilogue
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp

@@ -165,7 +165,8 @@ __global__ void
             const CElementwiseOperation c_element_op,
             const CBlockClusterAdaptor c_block_cluster_adaptor)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,

@@ -265,7 +266,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
     // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
     // when mfma is fixed, remove this section and update
     // FloatABAdjusted -> FloatAB throughout this file
-#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
+#if CK_WORKAROUND_DENORM_FIX
     using FloatABAdjusted =
         conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
 #else
     using FloatABAdjusted = FloatAB;
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp

@@ -44,7 +44,8 @@ __global__ void
             const CElementwiseOperation c_element_op,
             const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp

@@ -43,7 +43,8 @@ __global__ void
             const CElementwiseOperation c_element_op,
             const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,

@@ -135,7 +136,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
     // when mfma is fixed, remove this section and update
     // FloatABAdjusted -> FloatAB throughout this file
-#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
+#if CK_WORKAROUND_DENORM_FIX
     using FloatABAdjusted =
         conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
 #else
     using FloatABAdjusted = FloatAB;
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp

@@ -42,7 +42,8 @@ __global__ void
             const CElementwiseOperation c_element_op,
             const CBlockClusterAdaptor c_block_cluster_adaptor)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp

@@ -15,26 +15,32 @@
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"

 namespace ck {

 template <typename GridwiseGemm,
           bool HasMainKBlockLoop,
-          InMemoryDataOperationEnum CGlobalMemoryDataOperation>
+          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+          typename Block2CTileMap>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg)
+    kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg,
+                                         const Block2CTileMap& b2c_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();

     __shared__ uint8_t p_shared[shared_size];

-    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
-        karg, static_cast<void*>(p_shared));
+    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
+        karg, static_cast<void*>(p_shared), b2c_map);
 #else
     ignore = karg;
+    ignore = b2c_map;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

@@ -478,8 +484,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                        Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
     }

-    template <bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation>
-    __device__ static void Run(const Argument& karg, void* __restrict__ p_shared_block)
+    // return block_id to C matrix tile idx (m0, n0, k_split) mapping
+    __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
+    {
+        return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>();
+    }
+
+    using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>;
+    using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>;
+
+    template <bool HasMainKBlockLoop,
+              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+              typename Block2CTileMap>
+    __device__ static void Run(const Argument& karg,
+                               void* __restrict__ p_shared_block,
+                               const Block2CTileMap& block_2_ctile_map)
     {
         const FloatAB* p_a_grid = karg.p_a_grid;
         const FloatAB* p_b_grid = karg.p_b_grid;

@@ -504,11 +523,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

-        const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
-
-        const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y);
-        const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
-        const index_t k_batch_id = __builtin_amdgcn_readfirstlane(blockIdx.z);
+        // divide block work by [KBatch, M, N]
+        const auto block_work_idx =
+            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+
+        if(!block_2_ctile_map.ValidCTileIndex(
+               block_work_idx,
+               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
+                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
+        {
+            return;
+        }
+
+        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
+        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]);
+        const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);

         // HACK: this force m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =

@@ -651,6 +680,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         // register
         // sanity check
+#if 1
         auto blockwise_gemm =
             BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                 FloatAB,

@@ -662,6 +692,20 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                                                                 MRepeat,
                                                                 NRepeat,
                                                                 K1>{};
+#else
+        auto blockwise_gemm =
+            BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
+                BlockSize,
+                FloatAB,
+                FloatAcc,
+                decltype(a_k0_m_k1_block_desc),
+                decltype(b_k0_n_k1_block_desc),
+                MPerXDL,
+                NPerXDL,
+                MRepeat,
+                NRepeat,
+                K1>{};
+#endif

         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

@@ -680,6 +724,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());

+#if 0
         // preload data into LDS
         {
             a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);

@@ -725,6 +770,31 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
             blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
         }
+#else
+        // gridwise GEMM pipeline
+        const auto gridwise_gemm_pipeline =
+            GridwiseGemmPipeline_Selector<PipelineVersion::v2, 1, LoopScheduler::Default>();
+
+        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
+            (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) /
+            (K0PerBlock * K1));
+
+        gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
+                                                               a_b_k0_m_k1_block_desc,
+                                                               a_blockwise_copy,
+                                                               a_grid_buf,
+                                                               a_block_buf,
+                                                               a_block_slice_copy_step,
+                                                               b_b_k0_n_k1_grid_desc,
+                                                               b_b_k0_n_k1_block_desc,
+                                                               b_blockwise_copy,
+                                                               b_grid_buf,
+                                                               b_block_buf,
+                                                               b_block_slice_copy_step,
+                                                               blockwise_gemm,
+                                                               c_thread_buf,
+                                                               num_k_block_main_loop);
+#endif

         // output: register to global memory
         {
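Note: the substantive change in this file replaces the hard-coded blockIdx.{x,y,z} -> (n, m, k_batch) assignment with an explicit Block2CTileMap object passed into the kernel, so the block-to-tile assignment and its validity check become a swappable policy. A minimal host-side model of what such a 3D-grid k-split mapping computes (the struct and flattening order below are illustrative; the real BlockToCTileMap_3DGrid_KSplit is defined elsewhere in CK):

// Illustrative host-side model of a block-id -> (k_batch, m0, n0) tile mapping.
#include <array>
#include <cassert>
#include <cstdio>

struct Block2CTileMap3D
{
    int MBlocks, NBlocks, KBatch; // grid extents in tiles

    // Flatten order assumed here mirrors the old blockIdx use: x = n0, y = m0,
    // z = k_batch, one block per output tile.
    std::array<int, 3> CalculateBottomIndex(int block_1d_id) const
    {
        const int n0 = block_1d_id % NBlocks;
        const int m0 = (block_1d_id / NBlocks) % MBlocks;
        const int kb = block_1d_id / (NBlocks * MBlocks);
        return {kb, m0, n0}; // [I0]=k_batch, [I1]=m0, [I2]=n0, as in the diff
    }

    bool ValidCTileIndex(const std::array<int, 3>& idx) const
    {
        return idx[1] < MBlocks && idx[2] < NBlocks;
    }
};

int main()
{
    Block2CTileMap3D map{/*MBlocks=*/4, /*NBlocks=*/8, /*KBatch=*/2};
    const auto idx = map.CalculateBottomIndex(37); // -> k_batch=1, m0=0, n0=5
    assert(map.ValidCTileIndex(idx));
    std::printf("k_batch=%d m0=%d n0=%d\n", idx[0], idx[1], idx[2]);
}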
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp

@@ -46,7 +46,8 @@ __global__ void
             const CElementwiseOperation c_element_op,
             const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainK0BlockLoop>(
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp

@@ -49,7 +49,8 @@ __global__ void
             const CElementwiseOperation c_element_op,
             const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp

@@ -53,7 +53,8 @@ __global__ void
             const CElementwiseOperation c_element_op,
             const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(
include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp
→ include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp

File moved.
include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp
→ include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp

@@ -3,8 +3,8 @@
 #pragma once

-#include "ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp"
+#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp"
+#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp"

 namespace ck {

 template <typename GridwiseReduction,
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp
new file mode 100644, +252 lines:

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename XDataType,
          typename ComputeDataType,
          typename MeanVarDataType,
          typename XGridDesc_M_K,
          typename MeanVarGridDesc_M_KBlock,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XSrcVectorDim,
          index_t XSrcVectorSize>
struct GridwiseNormalizationSplitK1st
{
    static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                      (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;
    static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));

    using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
    static constexpr auto thread_buffer_desc_m_1 =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, I1));

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelford<ComputeDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;

    using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
                                              BlockSize,
                                              ThreadClusterLengths_M_K,
                                              ThreadClusterArrangeOrder,
                                              false>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr index_t M_BlockTileSize     = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize     = KThreadClusterSize * KThreadSliceSize;
    static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;

    static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};

    __device__ static int
    GetKPerThread(int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
    {
        bool is_rightmost_block = block_k_cluster_id == kGridSize - 1;

        if(is_rightmost_block)
        {
            int left_kPerBlock = math::integer_divide_ceil(kRaw, kGridSize);
            int kPerBlock = kRaw % kGridSize == 0 ? left_kPerBlock : kRaw % left_kPerBlock;
            int kPerThread =
                kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
            int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;

            if(kPerBlockTail > 0)
            {
                static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
                    int thread_max_len =
                        (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i;
                    int delta = thread_max_len - kPerBlockTail;
                    delta     = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize);
                    kPerThread += XSrcVectorSize - delta;
                });
            }

            return kPerThread;
        }
        else
        {
            int kPerBlock = math::integer_divide_ceil(kRaw, kGridSize);
            return KThreadSliceSize * (kPerBlock / K_BlockTileSize);
        }
    }

    // Calculate mean and variance by welford along k dimension
    __device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
                               const MeanVarGridDesc_M_KBlock& mean_var_grid_desc_m_kblock,
                               index_t num_k_block_tile_iteration,
                               const XDataType* const __restrict__ p_x_global,
                               MeanVarDataType* const p_mean_global,
                               MeanVarDataType* const p_variance_global,
                               int32_t* const p_welford_count_global)
    {
        auto x_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    ComputeDataType,
                                    MThreadSliceSize * XSrcVectorSize,
                                    true>{};
            },
            Number<ThreadBufferNumber>{});

        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            var_thread_buf;

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t k_grid_size     = mean_var_grid_desc_m_kblock.GetLength(I1);

        const index_t block_m_cluster_id = block_global_id / k_grid_size;
        const index_t block_k_cluster_id = block_global_id % k_grid_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;

        auto threadwise_x_load =
            ThreadwiseTensorSliceTransfer_v2<XDataType,
                                             ComputeDataType,
                                             XGridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             ThreadBufferDimAccessOrder,
                                             XSrcVectorDim,
                                             XSrcVectorSize,
                                             1,
                                             true>(
                x_grid_desc_m_k,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_k_cluster_id * reduceSizePerBlock +
                                     thread_k_cluster_id * XSrcVectorSize));

        auto mean_var_count_store_index = make_multi_index(
            block_m_cluster_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
            block_k_cluster_id);

        auto threadwise_welford_mean_var_store =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               MeanVarDataType,
                                               decltype(thread_buffer_desc_m_1),
                                               MeanVarGridDesc_M_KBlock,
                                               PassThroughOp,
                                               ThreadBufferLengths_M_1,
                                               Sequence<0, 1>,
                                               1,
                                               1,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                mean_var_grid_desc_m_kblock, mean_var_count_store_index, PassThroughOp{});

        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x_global, x_grid_desc_m_k.GetElementSpaceSize());

        auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_mean_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());

        auto var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());

        auto threadwise_welford = ThreadwiseWelford();

        int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
        threadwise_welford.max_count_ =
            GetKPerThread(kRaw, k_grid_size, block_k_cluster_id, thread_k_cluster_id);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
            var_thread_buf(I)  = type_convert<ComputeDataType>(0.0f);
        });

        for(index_t k = 0; k < num_k_block_tile_iteration; ++k)
        {
            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_x_load.Run(x_grid_desc_m_k,
                                      x_global_val_buf,
                                      thread_buffer_desc_m_k,
                                      make_tuple(I0, I0),
                                      x_thread_buf(i));
                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
                threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf);
            });
        }

        int welford_count = 0;
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            int count = threadwise_welford.cur_count_;
            BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);

            // The value of count is same for all I
            if constexpr(I == MThreadSliceSize - 1)
                welford_count = count;
        });

        if(thread_k_cluster_id == 0)
        {
            threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
                                                  make_tuple(I0, I0),
                                                  mean_thread_buf,
                                                  mean_var_grid_desc_m_kblock,
                                                  mean_global_val_buf);

            threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
                                                  make_tuple(I0, I0),
                                                  var_thread_buf,
                                                  mean_var_grid_desc_m_kblock,
                                                  var_global_val_buf);

            if(block_m_cluster_id == 0 && thread_m_cluster_id == 0)
                p_welford_count_global[block_k_cluster_id] = welford_count;
        }
    }
};
} // namespace ck
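Note: GridwiseNormalizationSplitK1st is the first pass of a split-K normalization. Each block reduces its slice of the K dimension with Welford's online algorithm and writes a per-(M, KBlock) partial mean and variance plus the element count to global memory; the second pass merges those partials. The numerically stable update and the pairwise merge it relies on, as a plain scalar sketch (standard Welford/Chan formulas, not CK's vectorized code):

// Scalar sketch of Welford's algorithm and the partial-result merge used by
// split-K normalization.
#include <cassert>
#include <cmath>

struct Welford
{
    int count   = 0;
    double mean = 0.0;
    double m2   = 0.0; // sum of squared deviations from the running mean

    void update(double x)
    {
        ++count;
        const double delta = x - mean;
        mean += delta / count;
        m2 += delta * (x - mean);
    }

    // Merge two partials (Chan et al.): this is what the 2nd pass does with
    // the per-KBlock mean/variance/count written by the 1st pass.
    static Welford merge(const Welford& a, const Welford& b)
    {
        Welford r;
        r.count = a.count + b.count;
        const double delta = b.mean - a.mean;
        r.mean = a.mean + delta * b.count / r.count;
        r.m2   = a.m2 + b.m2 + delta * delta * a.count * b.count / r.count;
        return r;
    }

    double variance() const { return count > 0 ? m2 / count : 0.0; }
};

int main()
{
    Welford left, right, whole;
    for(int i = 0; i < 8; ++i)  { left.update(i);  whole.update(i); }
    for(int i = 8; i < 16; ++i) { right.update(i); whole.update(i); }
    const Welford merged = Welford::merge(left, right);
    assert(std::abs(merged.mean - whole.mean) < 1e-12);
    assert(std::abs(merged.variance() - whole.variance()) < 1e-12);
}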
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp
new file mode 100644, +418 lines:

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename MeanVarDataType,
          typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename ComputeDataType,
          typename YElementwiseOperation,
          typename MeanVarGridDesc_M_KBlock,
          typename CountGridDesc_M_KBlock,
          typename XYGammaBetaGridDesc_M_K,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XSrcVectorDim,
          index_t XSrcVectorSize,
          index_t GammaSrcVectorDim,
          index_t GammaSrcVectorSize,
          index_t BetaSrcVectorDim,
          index_t BetaSrcVectorSize,
          index_t YDstVectorDim,
          index_t YDstVectorSize>
struct GridwiseNormalizationSplitK2nd
{
    static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                      (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
                      (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert(XSrcVectorSize == YDstVectorSize);
    static_assert(XSrcVectorSize == GammaSrcVectorSize);
    static_assert(XSrcVectorSize == BetaSrcVectorSize);

    static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;
    static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));

    using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
    static constexpr auto thread_buffer_desc_m_1 =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, I1));

    using ThreadWelfordSrcDesc_M_1 = decltype(thread_buffer_desc_m_1);
    using ThreadWelfordDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelfordMerge<ComputeDataType, ThreadWelfordSrcDesc_M_1, ThreadWelfordDstDesc_M>;

    using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
                                              BlockSize,
                                              ThreadClusterLengths_M_K,
                                              ThreadClusterArrangeOrder>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr index_t M_BlockTileSize     = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize     = KThreadClusterSize * KThreadSliceSize;
    static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;

    static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};

    __device__ static void Run(const MeanVarGridDesc_M_KBlock& mean_var_grid_desc_m_kblock,
                               const CountGridDesc_M_KBlock& count_grid_desc_m_kblock,
                               const XYGammaBetaGridDesc_M_K& x_grid_desc_m_k,
                               const XYGammaBetaGridDesc_M_K& gamma_grid_desc_m_k,
                               const XYGammaBetaGridDesc_M_K& beta_grid_desc_m_k,
                               const XYGammaBetaGridDesc_M_K& y_grid_desc_m_k,
                               index_t num_k_mean_var_count_iteration,
                               index_t num_k_block_tile_iteration,
                               index_t k_grid_size,
                               ComputeDataType epsilon,
                               const MeanVarDataType* const p_mean_global,
                               const MeanVarDataType* const p_variance_global,
                               const int32_t* const p_welford_count_global,
                               const XDataType* const __restrict__ p_x_global,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const BetaDataType* const __restrict__ p_beta_global,
                               YDataType* const __restrict__ p_y_global,
                               const YElementwiseOperation y_elementwise_op)
    {
        // Thread/Block id
        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();

        const index_t block_m_cluster_id = block_global_id / k_grid_size;
        const index_t block_k_cluster_id = block_global_id % k_grid_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        // Global Memory
        const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_mean_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());

        const auto var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());

        const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_count_global, count_grid_desc_m_kblock.GetElementSpaceSize());

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x_global, x_grid_desc_m_k.GetElementSpaceSize());

        const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());

        const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());

        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y_global, y_grid_desc_m_k.GetElementSpaceSize());

        // VGPR
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            in_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            in_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            in_welford_count_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
            var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            welford_count_thread_buf;

        auto x_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    ComputeDataType,
                                    MThreadSliceSize * XSrcVectorSize,
                                    true>{};
            },
            Number<ThreadBufferNumber>{});

        auto gamma_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    ComputeDataType,
                                    MThreadSliceSize * GammaSrcVectorSize,
                                    true>{};
            },
            Number<ThreadBufferNumber>{});

        auto& beta_thread_buf = gamma_thread_buf;
        auto& y_thread_buf    = x_thread_buf;

        // IO
        auto threadwise_mean_var_load_m_kblock =
            ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
                                             ComputeDataType,
                                             MeanVarGridDesc_M_KBlock,
                                             decltype(thread_buffer_desc_m_1),
                                             ThreadBufferLengths_M_1,
                                             Sequence<0, 1>,
                                             1,
                                             1,
                                             1,
                                             true>(
                mean_var_grid_desc_m_kblock,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id));

        auto threadwise_count_load_m_kblock =
            ThreadwiseTensorSliceTransfer_v2<int32_t,
                                             int32_t,
                                             CountGridDesc_M_KBlock,
                                             decltype(thread_buffer_desc_m_1),
                                             ThreadBufferLengths_M_1,
                                             Sequence<0, 1>,
                                             1,
                                             1,
                                             1,
                                             true>(
                count_grid_desc_m_kblock,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id));

        auto threadwise_x_load =
            ThreadwiseTensorSliceTransfer_v2<XDataType,
                                             ComputeDataType,
                                             XYGammaBetaGridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             ThreadBufferDimAccessOrder,
                                             XSrcVectorDim,
                                             XSrcVectorSize,
                                             1,
                                             true>(
                x_grid_desc_m_k,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration +
                                     thread_k_cluster_id * XSrcVectorSize));

        auto threadwise_gamma_load =
            ThreadwiseTensorSliceTransfer_v2<GammaDataType,
                                             ComputeDataType,
                                             XYGammaBetaGridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             ThreadBufferDimAccessOrder,
                                             GammaSrcVectorDim,
                                             GammaSrcVectorSize,
                                             1,
                                             true>(
                gamma_grid_desc_m_k,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration +
                                     thread_k_cluster_id * GammaSrcVectorSize));

        auto threadwise_beta_load =
            ThreadwiseTensorSliceTransfer_v2<BetaDataType,
                                             ComputeDataType,
                                             XYGammaBetaGridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             ThreadBufferDimAccessOrder,
                                             BetaSrcVectorDim,
                                             BetaSrcVectorSize,
                                             1,
                                             true>(
                beta_grid_desc_m_k,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration +
                                     thread_k_cluster_id * BetaSrcVectorSize));

        auto threadwise_y_store =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               YDataType,
                                               decltype(thread_buffer_desc_m_k),
                                               XYGammaBetaGridDesc_M_K,
                                               YElementwiseOperation,
                                               ThreadBufferLengths_M_K,
                                               ThreadBufferDimAccessOrder,
                                               YDstVectorDim,
                                               YDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                y_grid_desc_m_k,
                make_multi_index(block_m_cluster_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration +
                                     thread_k_cluster_id * YDstVectorSize),
                y_elementwise_op);

        // step1: Merge mean and variance
        constexpr auto mean_var_count_thread_copy_step_I0_k =
            make_multi_index(I0, KThreadClusterSize);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I)          = type_convert<ComputeDataType>(0.0f);
            var_thread_buf(I)           = type_convert<ComputeDataType>(0.0f);
            welford_count_thread_buf(I) = 0;
        });

        for(index_t k = 0; k < num_k_mean_var_count_iteration; ++k)
        {
            threadwise_mean_var_load_m_kblock.Run(mean_var_grid_desc_m_kblock,
                                                  mean_global_val_buf,
                                                  thread_buffer_desc_m_1,
                                                  make_tuple(I0, I0),
                                                  in_mean_thread_buf);

            threadwise_mean_var_load_m_kblock.Run(mean_var_grid_desc_m_kblock,
                                                  var_global_val_buf,
                                                  thread_buffer_desc_m_1,
                                                  make_tuple(I0, I0),
                                                  in_var_thread_buf);

            threadwise_count_load_m_kblock.Run(count_grid_desc_m_kblock,
                                               welford_count_global_val_buf,
                                               thread_buffer_desc_m_1,
                                               make_tuple(I0, I0),
                                               in_welford_count_thread_buf);

            ThreadwiseWelford::Run(in_mean_thread_buf,
                                   in_var_thread_buf,
                                   in_welford_count_thread_buf,
                                   mean_thread_buf,
                                   var_thread_buf,
                                   welford_count_thread_buf);

            threadwise_mean_var_load_m_kblock.MoveSrcSliceWindow(
                mean_var_grid_desc_m_kblock, mean_var_count_thread_copy_step_I0_k);
            threadwise_count_load_m_kblock.MoveSrcSliceWindow(
                count_grid_desc_m_kblock, mean_var_count_thread_copy_step_I0_k);
        }

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            BlockwiseWelford::Run(
                mean_thread_buf(I), var_thread_buf(I), welford_count_thread_buf(I));
        });

        // step2: normalization
        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);

        for(index_t k = 0; k < num_k_block_tile_iteration; ++k)
        {
            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_x_load.Run(x_grid_desc_m_k,
                                      x_global_val_buf,
                                      thread_buffer_desc_m_k,
                                      make_tuple(I0, I0),
                                      x_thread_buf(i));
                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
            });

            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                          gamma_global_val_buf,
                                          thread_buffer_desc_m_k,
                                          make_tuple(I0, I0),
                                          gamma_thread_buf(i));
                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                         thread_copy_fwd_step_m_k);
            });

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);

                static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                        constexpr auto offset_m_k =
                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

                        // normalize
                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
                            (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
                            divisor;

                        // gamma
                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
                            y_thread_buf(iK0)(Number<offset_m_k>{}) *
                            gamma_thread_buf(iK0)(Number<offset_m_k>{});
                    });
                });
            });

            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_beta_load.Run(beta_grid_desc_m_k,
                                         beta_global_val_buf,
                                         thread_buffer_desc_m_k,
                                         make_tuple(I0, I0),
                                         beta_thread_buf(i));
                threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
                                                        thread_copy_fwd_step_m_k);
            });

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                        constexpr auto offset_m_k =
                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

                        // beta
                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
                            y_thread_buf(iK0)(Number<offset_m_k>{}) +
                            beta_thread_buf(iK0)(Number<offset_m_k>{});
                    });
                });
            });

            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_y_store.Run(thread_buffer_desc_m_k,
                                       make_tuple(I0, I0),
                                       y_thread_buf(i),
                                       y_grid_desc_m_k,
                                       y_global_val_buf);
                threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
            });
        } // end for (normalization)
    }
};
} // namespace ck
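Note: the second pass first merges the per-KBlock (mean, variance, count) triples, then applies the usual layer-norm transform y = (x - mean) / sqrt(var + eps) * gamma + beta, staged exactly as in the kernel: normalize, scale by gamma, then (after beta is loaded into the reused gamma registers) add beta. A scalar reference for that arithmetic (illustrative, not CK code):

// Scalar reference for the normalization applied in step 2 of the kernel.
#include <cmath>
#include <cstdio>
#include <vector>

void normalize_row(const std::vector<float>& x,
                   const std::vector<float>& gamma,
                   const std::vector<float>& beta,
                   float mean, float var, float epsilon,
                   std::vector<float>& y)
{
    // Matches the kernel's staging: multiply by 1/sqrt(var + eps), then gamma,
    // then add beta in a separate sweep.
    const float divisor = 1.0f / std::sqrt(var + epsilon);
    for(size_t k = 0; k < x.size(); ++k)
        y[k] = (x[k] - mean) * divisor; // normalize
    for(size_t k = 0; k < x.size(); ++k)
        y[k] *= gamma[k];               // gamma
    for(size_t k = 0; k < x.size(); ++k)
        y[k] += beta[k];                // beta
}

int main()
{
    std::vector<float> x{1.f, 2.f, 3.f, 4.f}, g(4, 1.f), b(4, 0.f), y(4);
    normalize_row(x, g, b, /*mean=*/2.5f, /*var=*/1.25f, /*epsilon=*/1e-5f, y);
    for(float v : y)
        std::printf("%f\n", v);
}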
include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp
→ include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp

File moved.
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp

@@ -6,6 +6,7 @@
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor/static_tensor.hpp"

@@ -207,15 +208,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             auto src_vector_container = src_vector_type{
                 src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};

-            // apply SrcElementwiseOperation on src_vector_container
-            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
-                SrcData src_v;
-
-                src_element_op_(src_v, src_vector_container.template AsType<SrcData>()[i]);
-
-                src_vector_container.template AsType<SrcData>()(i) = src_v;
-            });
-
             // copy data from src_vector_container into src_thread_scratch_
             src_thread_scratch_tuple_(thread_scratch_id)
                 .template SetAsType<src_vector_t>(

@@ -318,7 +310,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         constexpr auto data_idx_seq = generate_sequence_v2(
             [&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});

-        // TODO type_convert is not used yet!!!!!
         using src_vector_t = vector_type_maker_t<SrcData, SrcScalarPerVector>;
         using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;

@@ -342,19 +333,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 Number<num_dst_vector>{});

             // do data transpose
-            // TODO type_convert is not used yet!!!!!
             transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
                 src_vector_refs, dst_vector_refs);
         });
     }
     else
     {
         static_ford<SliceLengths>{}([&](auto idx) {
-            // convert from SrcData to DstData here
-            dst_thread_scratch_(idx) =
-                type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
+            // apply the src elementwise op and convert to DstData under the hood if needed
+            DstData dst_v;
+            src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]);
+            dst_thread_scratch_(idx) = dst_v;
         });
     }
 #endif
 }
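Note: the net effect of this change is that the source elementwise operation is no longer applied scalar-by-scalar while reading into the source scratch; it is applied once in the scratch-to-scratch step, where it can fold the SrcData-to-DstData conversion as well (hence the new unary_element_wise_operation.hpp include). CK's unary ops use an out-parameter call shape, roughly as modeled below (a simplified stand-in, not the real ck::tensor_operation::element_wise definitions):

// Simplified model of the out-parameter elementwise-op call shape used above.
#include <cassert>

struct PassThrough
{
    template <typename Dst, typename Src>
    void operator()(Dst& dst, const Src& src) const
    {
        dst = static_cast<Dst>(src); // type conversion folded into the op
    }
};

struct Scale
{
    float scale;
    template <typename Dst, typename Src>
    void operator()(Dst& dst, const Src& src) const
    {
        dst = static_cast<Dst>(scale * static_cast<float>(src));
    }
};

int main()
{
    float dst = 0.f;
    Scale op{2.0f};
    op(dst, 3); // int source -> float destination, scaled on the way through
    assert(dst == 6.0f);
}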
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp

@@ -27,6 +27,8 @@ enum struct MfmaInstr
     mfma_f32_16x16x8bf16,
     mfma_i32_32x32x8i8,
     mfma_i32_16x16x16i8,
+    mfma_i32_32x32x16i8,
+    mfma_i32_16x16x32i8,
     mfma_f64_16x16x4f64
 };

@@ -386,6 +388,50 @@ struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
     }
 };

+template <>
+struct mfma_type<MfmaInstr::mfma_i32_32x32x16i8>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 4;
+    static constexpr index_t num_regs_per_blk    = 16;
+    static constexpr index_t num_threads_per_blk = 32;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 2;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 32;
+    static constexpr index_t n_per_blk           = 32;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_i32_32x32x16i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct mfma_type<MfmaInstr::mfma_i32_16x16x32i8>
+{
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 1;
+    static constexpr index_t num_regs_per_blk    = 4;
+    static constexpr index_t num_threads_per_blk = 16;
+    static constexpr index_t wave_size           = 64;
+    static constexpr index_t num_input_blks      = 4;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 8;
+    static constexpr bool is_k_reduction         = true;
+
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        intrin_mfma_i32_16x16x32i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
+    }
+};
+
 template <>
 struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
 {

@@ -524,17 +570,29 @@ struct MfmaSelector
 #endif
     }

+#if defined(CK_USE_AMD_MFMA_GFX940)
+    template <>
+    static constexpr auto GetMfma<int8_t, 32, 32>()
+    {
+        return MfmaInstr::mfma_i32_32x32x16i8;
+    }
+
+    template <>
+    static constexpr auto GetMfma<int8_t, 16, 16>()
+    {
+        return MfmaInstr::mfma_i32_16x16x32i8;
+    }
+#else
     template <>
     static constexpr auto GetMfma<int8_t, 32, 32>()
     {
         return MfmaInstr::mfma_i32_32x32x8i8;
     }

     template <>
     static constexpr auto GetMfma<int8_t, 16, 16>()
     {
         return MfmaInstr::mfma_i32_16x16x16i8;
     }
+#endif

     static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
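Note: the two new entries describe gfx940's double-rate int8 MFMAs (v_mfma_i32_32x32x16_i8 and v_mfma_i32_16x16x32_i8), selected for int8 when CK_USE_AMD_MFMA_GFX940 is defined. The table fields are self-consistent: with is_k_reduction, one instruction covers k_per_blk * num_input_blks of K (8 * 2 = 16 and 8 * 4 = 32), and each lane of the 64-wide wave holds m_per_blk * n_per_blk / wave_size accumulator registers. The relationships below are inferred from the table, not CK-stated invariants; a compile-time check sketch:

// Compile-time consistency checks for the new mfma_type entries (assumed
// relationships inferred from the table above, not CK-defined invariants).
constexpr int wave_size = 64;

struct Mfma32x32x16i8
{
    static constexpr int num_regs_per_blk    = 16;
    static constexpr int num_threads_per_blk = 32;
    static constexpr int num_input_blks      = 2;
    static constexpr int m_per_blk = 32, n_per_blk = 32, k_per_blk = 8;
};

struct Mfma16x16x32i8
{
    static constexpr int num_regs_per_blk    = 4;
    static constexpr int num_threads_per_blk = 16;
    static constexpr int num_input_blks      = 4;
    static constexpr int m_per_blk = 16, n_per_blk = 16, k_per_blk = 8;
};

template <typename T, int MfmaK>
constexpr bool consistent()
{
    // Lanes per input block times the number of input blocks fills the wave.
    static_assert(T::num_threads_per_blk * T::num_input_blks == wave_size);
    // The block's accumulators are split across the wave's lanes.
    static_assert(T::m_per_blk * T::n_per_blk / wave_size == T::num_regs_per_blk);
    // With K reduction, one instruction covers k_per_blk * num_input_blks of K.
    static_assert(T::k_per_blk * T::num_input_blks == MfmaK);
    return true;
}

static_assert(consistent<Mfma32x32x16i8, 16>());
static_assert(consistent<Mfma16x16x32i8, 32>());

int main() {}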