gaoqiong / composable_kernel / Commits

Commit 159c40c6 (Unverified)
Authored Mar 15, 2023 by zjing14, committed by GitHub on Mar 15, 2023

Merge branch 'develop' into lwpck-570

Parents: bb20f88f, 14b3504d
Changes: 35
Showing 15 changed files with 171 additions and 31 deletions (+171, -31)
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp  +5 -5
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp  +9 -1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp  +9 -1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp  +10 -1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp  +9 -1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp  +2 -1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp  +11 -1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp  +9 -1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp  +13 -1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp  +9 -1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp  +4 -0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp  +1 -2
include/ck/utility/amd_buffer_addressing.hpp  +4 -4
include/ck/utility/amd_inline_asm.hpp  +12 -6
include/ck/utility/amd_wmma.hpp  +64 -5
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp

@@ -77,8 +77,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
     static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
     {
-        assert(K % K1 == 0);
-
         const index_t K0 = K / K1;

         const auto a_grid_desc_m_k = [&]() {
@@ -116,8 +114,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
     static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
     {
-        assert(K % K1 == 0);
-
         const index_t K0 = K / K1;

         const auto b_grid_desc_k_n = [&]() {
@@ -551,7 +547,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
             << MPerXDL << ", "
             << NPerXDL << ", "
             << MXdlPerWave << ", "
-            << NXdlPerWave
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << ABlockTransferDstScalarPerVector_K1 << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferDstScalarPerVector_K1
             << ">"
             << " NumPrefetch: "
             << NumPrefetch << ", "
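Note: this hunk, and the device-op hunks that follow, all make the same kind of change: the GetTypeString()-style stream that names a kernel instance is extended so the block-transfer and C-shuffle tuning parameters appear in the instance string. A minimal stand-alone sketch of that idiom is below; the struct and parameter names are illustrative placeholders, not the library's real template signature.

    #include <iostream>
    #include <sstream>
    #include <string>

    // Hedged sketch of the type-string pattern touched by the hunks above.
    template <int MPerBlock, int NPerBlock, int ABlockTransferSrcScalarPerVector>
    struct DeviceOpSketch
    {
        static std::string GetTypeString()
        {
            std::ostringstream str;
            // Streaming every tuning parameter makes each instantiation's
            // string unique, so profiler/log output can tell instances apart.
            str << "DeviceOpSketch"
                << "<"
                << MPerBlock << ", "
                << NPerBlock << ", "
                << ABlockTransferSrcScalarPerVector
                << ">";
            return str.str();
        }
    };

    int main()
    {
        // Prints: DeviceOpSketch<128, 256, 8>
        std::cout << DeviceOpSketch<128, 256, 8>::GetTypeString() << "\n";
        return 0;
    }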
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp

@@ -682,7 +682,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
             << NPerBlock << ", "
             << KPerBlock << ", "
             << AK1 << ", "
-            << BK1
+            << BK1 << ", "
+            << MPerXDL << ", "
+            << NPerXDL << ", "
+            << MXdlPerWave << ", "
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << CShuffleMXdlPerWavePerShuffle << ", "
+            << CShuffleNXdlPerWavePerShuffle
             << ">"
             << " LoopScheduler: "
             << LoopSchedToString[LoopSched] << ", "
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp

@@ -760,7 +760,15 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
             << NPerBlock << ", "
             << KPerBlock << ", "
             << AK1 << ", "
-            << BK1
+            << BK1 << ", "
+            << MPerXDL << ", "
+            << NPerXDL << ", "
+            << MXdlPerWave << ", "
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << CShuffleMXdlPerWavePerShuffle << ", "
+            << CShuffleNXdlPerWavePerShuffle
             << ">";
         // clang-format on
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp

@@ -640,7 +640,16 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
             << BlockSize << ", "
             << MPerBlock << ", "
             << NPerBlock << ", "
-            << K0PerBlock
+            << K0PerBlock << ", "
+            << K1 << ", "
+            << MPerXDL << ", "
+            << NPerXDL << ", "
+            << MXdlPerWave << ", "
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << ABlockTransferDstScalarPerVector_K1 << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferDstScalarPerVector_K1
             << ">";
         // clang-format on
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp

@@ -1003,7 +1003,15 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
             << KPerBlock << ", "
             << AK1 << ", "
             << BK1 << ", "
-            << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization)
+            << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", "
+            << MPerXDL << ", "
+            << NPerXDL << ", "
+            << MXdlPerWave << ", "
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << CShuffleMXdlPerWavePerShuffle << ", "
+            << CShuffleNXdlPerWavePerShuffle
             << ">";

         return str.str();
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp

@@ -1203,7 +1203,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
             << MPerBlock << ", "
             << NPerBlock << ", "
             << K0PerBlock << ", "
-            << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization)
+            << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
+            << K1
             << ">";
         // clang-format on
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp

@@ -1231,7 +1231,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
             << MPerBlock << ", "
             << NPerBlock << ", "
             << K0PerBlock << ", "
-            << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization)
+            << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
+            << K1 << ", "
+            << MXdlPerWave << ", "
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << ABlockTransferDstScalarPerVector_K1 << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferDstScalarPerVector_K1 << ", "
+            << CShuffleMXdlPerWavePerShuffle << ", "
+            << CShuffleNXdlPerWavePerShuffle << ", "
+            << CBlockTransferScalarPerVector_NWaveNPerXdl
             << ">";
         // clang-format on
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp

@@ -1092,7 +1092,15 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
             << MPerBlock << ", "
             << NPerBlock << ", "
             << KPerBlock << ", "
-            << getConvForwardSpecializationString(ConvForwardSpecialization)
+            << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
+            << MPerXDL << ", "
+            << NPerXDL << ", "
+            << MXdlPerWave << ", "
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << CShuffleMXdlPerWavePerShuffle << ", "
+            << CShuffleNXdlPerWavePerShuffle
             << ">";
         // clang-format on
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp

@@ -838,7 +838,19 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
             << MPerBlock << ", "
             << NPerBlock << ", "
             << KPerBlock << ", "
-            << getConvForwardSpecializationString(ConvForwardSpecialization)
+            << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
+            << K1 << ", "
+            << MPerXDL << ", "
+            << NPerXDL << ", "
+            << MXdlPerWave << ", "
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << ABlockTransferDstScalarPerVector_K1 << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferDstScalarPerVector_K1 << ", "
+            << CShuffleMXdlPerWavePerShuffle << ", "
+            << CShuffleNXdlPerWavePerShuffle << ", "
+            << CBlockTransferScalarPerVector_NWaveNPerXdl
             << ">";
         // clang-format on
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp

@@ -939,7 +939,15 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
             << MPerBlock << ", "
             << NPerBlock << ", "
             << KPerBlock << ", "
-            << getConvForwardSpecializationString(ConvForwardSpecialization)
+            << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
+            << MPerXDL << ", "
+            << NPerXDL << ", "
+            << MXdlPerWave << ", "
+            << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << CShuffleMXdlPerWavePerShuffle << ", "
+            << CShuffleNXdlPerWavePerShuffle
             << ">";
         // clang-format on
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp

@@ -688,6 +688,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
             << NPerXDL << ", "
             << MXdlPerWave << ", "
             << NXdlPerWave << ", "
+            << ABlockTransferSrcScalarPerVector << ", "
+            << BBlockTransferSrcScalarPerVector << ", "
+            << CShuffleMXdlPerWavePerShuffle << ", "
+            << CShuffleNXdlPerWavePerShuffle << ", "
             << getGemmSpecializationString(GemmSpec)
             << ">";
         // clang-format on
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp

@@ -676,7 +676,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
         constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

         auto blockwise_gemm =
-            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
+            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
                                                                   ADataType,
                                                                   BDataType,
                                                                   AccDataType,
@@ -719,7 +719,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
                                                    c_thread_buf,
                                                    K0BlockMainLoop);
     /*******************************************************************************/
-        //printf("safe 1");
         // write out to C, implement shuffle
         {
             constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
include/ck/utility/amd_buffer_addressing.hpp

@@ -1030,7 +1030,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
     constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

 #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
-    uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff;
+    uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;

     return amd_buffer_load_impl<scalar_t, vector_size>(
         src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
@@ -1091,7 +1091,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
     constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

 #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
-    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
+    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

     amd_buffer_store_impl<scalar_t, vector_size>(
         src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
@@ -1126,7 +1126,7 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
     constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

 #if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
-    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
+    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

     amd_buffer_atomic_add_impl<scalar_t, vector_size>(
         src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
@@ -1161,7 +1161,7 @@ amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thr
     constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

 #if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
-    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
+    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

     amd_buffer_atomic_max_impl<scalar_t, vector_size>(
         src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
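All four hunks in amd_buffer_addressing.hpp change the out-of-bounds address shift from 0x7fffffff to 0x80000000. The offset trick avoids per-element branching: an invalid lane's offset is pushed far past the buffer's valid range, and the buffer instruction's bounds check then returns zero for the load (or drops the store/atomic), which matches the function name amd_buffer_load_invalid_element_return_zero. A hedged host-side sketch of just the offset selection follows; the real work happens in the amd_buffer_*_impl functions and is not reproduced here.

    #include <cstdint>
    #include <cstdio>

    // Illustrative model only: mirrors the offset selection in the hunks above.
    // Assumption (host-side stand-in): offsets pushed past the buffer's valid
    // range are rejected by the buffer instruction's bounds check on the GPU.
    uint32_t make_buffer_offset(bool element_valid, uint32_t thread_offset)
    {
        // Valid lanes keep their offset; invalid lanes are shifted by
        // 0x80000000 so the resulting address lands out of range.
        const uint32_t addr_shift = element_valid ? 0u : 0x80000000u;
        return addr_shift + thread_offset;
    }

    int main()
    {
        std::printf("valid lane:   0x%08x\n", make_buffer_offset(true, 64));   // 0x00000040
        std::printf("invalid lane: 0x%08x\n", make_buffer_offset(false, 64));  // 0x80000040
        return 0;
    }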
include/ck/utility/amd_inline_asm.hpp

@@ -358,7 +358,13 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
 // Ranged input operand
 __device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
 {
+#if defined(__gfx11__)
     asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
+#else
+    ignore = a;
+    ignore = b;
+    ignore = c;
+#endif
 }
 } // namespace ck
include/ck/utility/amd_wmma.hpp

@@ -23,11 +23,16 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
     {
         // * Inline assembly need to elimate the duplicated data load, compiler won't help you
         // delete them.
-        amd_assembly_wmma_f32_16x16x16_f16_w32(
-            reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
-        // reg_c.template AsType<float8_t>()(Number<0>{}) =
-        //     __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template
-        //     AsType<float8_t>()[Number<0>{}]);
+        // amd_assembly_wmma_f32_16x16x16_f16_w32(
+        //     reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
+        reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
+            reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
     }
 };
@@ -41,9 +46,15 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
     template <class FloatC>
     __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
     {
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
         reg_c.template AsType<float8_t>()(Number<0>{}) =
             __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
                 reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
     }
 };
@@ -60,8 +71,14 @@ struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
         // opsel usage
         // false: D0.[0:15] = result
         // true : D0.[16:31]= result
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
         reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
             reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
     }
 };
@@ -78,9 +95,15 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
         // opsel usage
         // false: D0.[0:15] = result
        // true : D0.[16:31]= result
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
         reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
             __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
                 reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
     }
 };
@@ -94,6 +117,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
     template <class FloatC>
     __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
     {
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
         reg_c.template AsType<int32x8_t>()(Number<0>{}) =
             __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
                 neg_a,
@@ -102,6 +126,11 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
                 bit_cast<int32x4_t>(reg_b),
                 reg_c.template AsType<int32x8_t>()[Number<0>{}],
                 clamp);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
     }
 };
@@ -116,8 +145,14 @@ struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
     template <class FloatC>
     __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
     {
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
         reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
             reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
     }
 };
@@ -131,9 +166,15 @@ struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
     template <class FloatC>
     __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
     {
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
         reg_c.template AsType<float4_t>()(Number<0>{}) =
             __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
                 reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
     }
 };
@@ -150,8 +191,14 @@ struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
         // opsel usage
         // false: D0.[0:15] = result
         // true : D0.[16:31]= result
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
         reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
             reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
     }
 };
@@ -168,9 +215,15 @@ struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
         // opsel usage
         // false: D0.[0:15] = result
         // true : D0.[16:31]= result
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
         reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
             __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
                 reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
     }
 };
@@ -184,6 +237,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
     template <class FloatC>
     __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
     {
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
         reg_c.template AsType<int32x4_t>()(Number<0>{}) =
             __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
                 neg_a,
@@ -192,6 +246,11 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
                 bit_cast<int32x4_t>(reg_b),
                 reg_c.template AsType<int32x4_t>()[Number<0>{}],
                 clamp);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+#endif
     }
 };
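Every hunk in amd_wmma.hpp wraps the WMMA builtin behind a gfx1100/gfx1101/gfx1102 guard and consumes the arguments otherwise, so the header still compiles for targets without WMMA. A minimal sketch of that guard-and-fallback shape is below; the vector type aliases are local stand-ins for CK's half16_t/float8_t (assumed to be 16 x fp16 and 8 x fp32 extended vectors), and the (void) casts stand in for CK's ignore helper.

    #include <hip/hip_runtime.h>

    // Local stand-ins for the library's vector typedefs (assumptions, not CK's definitions).
    typedef _Float16 half16_sketch_t __attribute__((ext_vector_type(16)));
    typedef float    float8_sketch_t __attribute__((ext_vector_type(8)));

    // Hedged sketch of the guard-and-fallback pattern used throughout the file.
    __device__ void wmma_f32_16x16x16_f16_sketch(const half16_sketch_t& reg_a,
                                                 const half16_sketch_t& reg_b,
                                                 float8_sketch_t& reg_c)
    {
    #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        // gfx11 path: accumulate the 16x16x16 fp16 tile product into reg_c.
        reg_c = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(reg_a, reg_b, reg_c);
    #else
        // Other targets: consume the arguments so the function still compiles cleanly.
        (void)reg_a;
        (void)reg_b;
        (void)reg_c;
    #endif
    }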