Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6e3cf8b0
Commit
6e3cf8b0
authored
May 24, 2022
by
Jing Zhang
Browse files
merge develop
parents
4ad62d7f
ba58a93f
Changes
177
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
677 additions
and
401 deletions
+677
-401
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+16
-51
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
+16
-52
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+16
-52
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
+19
-51
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
+20
-51
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
+19
-51
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+108
-0
include/ck/utility/common_header.hpp
include/ck/utility/common_header.hpp
+1
-1
include/ck/utility/dynamic_buffer.hpp
include/ck/utility/dynamic_buffer.hpp
+41
-1
include/ck/utility/generic_memory_space_atomic.hpp
include/ck/utility/generic_memory_space_atomic.hpp
+97
-0
include/ck/utility/get_id.hpp
include/ck/utility/get_id.hpp
+4
-0
include/ck/utility/statically_indexed_array_multi_index.hpp
include/ck/utility/statically_indexed_array_multi_index.hpp
+7
-0
include/ck/utility/type.hpp
include/ck/utility/type.hpp
+3
-0
library/include/ck/library/host/host_interface.hpp
library/include/ck/library/host/host_interface.hpp
+54
-0
library/include/ck/library/host_tensor/device.hpp
library/include/ck/library/host_tensor/device.hpp
+75
-36
library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp
...reference_tensor_operation/cpu/reference_batched_gemm.hpp
+2
-1
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp
...e_tensor_operation/cpu/reference_conv_backward_weight.hpp
+170
-48
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
...eference_tensor_operation/cpu/reference_conv_bwd_data.hpp
+2
-1
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
...ary/reference_tensor_operation/cpu/reference_conv_fwd.hpp
+5
-4
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp
...nsor_operation/cpu/reference_conv_fwd_bias_activation.hpp
+2
-1
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
6e3cf8b0
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include "multi_index_transform_helper.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
...
@@ -185,12 +186,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -185,12 +186,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
const
Block2CTileMap
&
block_2_ctile_map
)
index_t
N01
)
{
{
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
"wrong! K1 need to be known at compile-time"
);
...
@@ -219,31 +220,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -219,31 +220,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return
false
;
return
false
;
}
}
// check M01, N01
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_grid_desc_m_n
))
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
{
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
if
(
!
(
M0
%
M01
==
0
&&
N0
%
N01
==
0
))
return
false
;
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
return
true
;
}
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
return
grid_size
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
{
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1
);
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1
);
...
@@ -305,36 +290,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -305,36 +290,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
index_t
N01
)
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
index_t
N01
)
{
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
return
BlockToCTileMap_M00_N00_M01_N01
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
(
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
c_grid_desc_m_n
,
M01
,
N01
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
const
auto
M00
=
M0
/
M01
;
const
auto
N00
=
N0
/
N01
;
const
auto
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M00
,
M01
)),
make_unmerge_transform
(
make_tuple
(
N00
,
N01
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{}));
const
auto
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M00
,
N00
,
M01
,
N01
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
cblockid_to_m0_n0_block_cluster_adaptor
=
chain_tensor_adaptors
(
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
);
return
cblockid_to_m0_n0_block_cluster_adaptor
;
}
}
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
...
@@ -368,6 +325,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -368,6 +325,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const
auto
block_work_idx
=
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_ctile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I0
),
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I1
))))
{
return
;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
View file @
6e3cf8b0
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "multi_index_transform_helper.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
...
@@ -167,12 +168,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
...
@@ -167,12 +168,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
}
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
ABK0MK1GridDesc
&
a_b_k0_m_k1_grid_desc
,
CheckValidity
(
const
ABK0MK1GridDesc
&
a_b_k0_m_k1_grid_desc
,
const
BBK0NK1GridDesc
&
b_b_k0_n_k1_grid_desc
,
const
BBK0NK1GridDesc
&
b_b_k0_n_k1_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
M01
,
const
Block2CTileMap
&
block_2_ctile_map
)
index_t
N01
)
{
{
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
"wrong! K1 need to be known at compile-time"
);
...
@@ -196,31 +197,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
...
@@ -196,31 +197,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
K0PerBlock
==
0
))
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
K0PerBlock
==
0
))
return
false
;
return
false
;
// check M01, N01
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_m_n_grid_desc
))
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
{
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
if
(
!
(
M0
%
M01
==
0
&&
N0
%
N01
==
0
))
return
false
;
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
return
true
;
}
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
KBatch
)
{
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
)
*
KBatch
;
return
grid_size
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0
)
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0
)
{
{
const
bool
has_main_k0_block_loop
=
K0
>
K0PerBlock
;
const
bool
has_main_k0_block_loop
=
K0
>
K0PerBlock
;
...
@@ -282,37 +267,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
...
@@ -282,37 +267,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
__host__
__device__
static
constexpr
auto
MakeCBlockClusterAdaptor
(
__host__
__device__
static
constexpr
auto
MakeCBlockClusterAdaptor
(
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
M01
,
index_t
N01
,
index_t
KBatch
)
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
M01
,
index_t
N01
,
index_t
KBatch
)
{
{
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
return
BlockToCTileMap_KSplit_M00_N00_M01_N01
<
MPerBlock
,
NPerBlock
,
CMNGridDesc
>
(
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
c_m_n_grid_desc
,
M01
,
N01
,
KBatch
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
const
auto
M00
=
M0
/
M01
;
const
auto
N00
=
N0
/
N01
;
const
auto
kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_pass_through_transform
(
KBatch
),
make_unmerge_transform
(
make_tuple
(
M00
,
M01
)),
make_unmerge_transform
(
make_tuple
(
N00
,
N01
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
>
{},
Sequence
<
2
,
4
>
{}));
const
auto
cblockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
KBatch
,
M00
,
N00
,
M01
,
N01
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
cblockid_to_kbatch_m0_n0_block_cluster_adaptor
=
chain_tensor_adaptors
(
kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
,
cblockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor
);
return
cblockid_to_kbatch_m0_n0_block_cluster_adaptor
;
}
}
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
CMNGridDesc
{}));
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
CMNGridDesc
{}));
...
@@ -344,6 +300,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
...
@@ -344,6 +300,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
const
auto
block_work_idx
=
const
auto
block_work_idx
=
c_block_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
c_block_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
c_block_cluster_adaptor
.
ValidCTileIndex
(
make_tuple
(
block_work_idx
[
I1
],
block_work_idx
[
I2
]),
make_tuple
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
.
GetLength
(
I0
),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
.
GetLength
(
I1
))))
{
return
;
}
const
index_t
k_batch_id
=
block_work_idx
[
I0
];
const
index_t
k_batch_id
=
block_work_idx
[
I0
];
// HACK: this force m/n_block_data_idx_on_grid into SGPR
// HACK: this force m/n_block_data_idx_on_grid into SGPR
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
6e3cf8b0
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "multi_index_transform_helper.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
...
@@ -174,12 +175,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -174,12 +175,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
}
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_B_K0_M_K1
&
a_b_k0_m_k1_grid_desc
,
CheckValidity
(
const
AGridDesc_B_K0_M_K1
&
a_b_k0_m_k1_grid_desc
,
const
BGridDesc_B_K0_N_K1
&
b_b_k0_n_k1_grid_desc
,
const
BGridDesc_B_K0_N_K1
&
b_b_k0_n_k1_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
M01
,
const
Block2CTileMap
&
block_2_ctile_map
)
index_t
N01
)
{
{
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
"wrong! K1 need to be known at compile-time"
);
...
@@ -203,31 +204,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -203,31 +204,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
K0PerBlock
==
0
))
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
K0PerBlock
==
0
))
return
false
;
return
false
;
// check M01, N01
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_m_n_grid_desc
))
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
{
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
if
(
!
(
M0
%
M01
==
0
&&
N0
%
N01
==
0
))
return
false
;
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
return
true
;
}
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
KBatch
)
{
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
)
*
KBatch
;
return
grid_size
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0
)
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0
)
{
{
const
bool
has_main_k0_block_loop
=
K0
>
K0PerBlock
;
const
bool
has_main_k0_block_loop
=
K0
>
K0PerBlock
;
...
@@ -256,37 +241,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -256,37 +241,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__host__
__device__
static
constexpr
auto
MakeCBlockClusterAdaptor
(
__host__
__device__
static
constexpr
auto
MakeCBlockClusterAdaptor
(
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
M01
,
index_t
N01
,
index_t
KBatch
)
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
M01
,
index_t
N01
,
index_t
KBatch
)
{
{
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
return
BlockToCTileMap_KSplit_M00_N00_M01_N01
<
MPerBlock
,
NPerBlock
,
CMNGridDesc
>
(
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
c_m_n_grid_desc
,
M01
,
N01
,
KBatch
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
const
auto
M00
=
M0
/
M01
;
const
auto
N00
=
N0
/
N01
;
const
auto
kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_pass_through_transform
(
KBatch
),
make_unmerge_transform
(
make_tuple
(
M00
,
M01
)),
make_unmerge_transform
(
make_tuple
(
N00
,
N01
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
>
{},
Sequence
<
2
,
4
>
{}));
const
auto
c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
KBatch
,
M00
,
N00
,
M01
,
N01
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
c_blockid_to_kbatch_m0_n0_block_cluster_adaptor
=
chain_tensor_adaptors
(
kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
,
c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor
);
return
c_blockid_to_kbatch_m0_n0_block_cluster_adaptor
;
}
}
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
...
@@ -333,6 +289,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -333,6 +289,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
const
auto
block_work_idx
=
const
auto
block_work_idx
=
c_block_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
c_block_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
c_block_cluster_adaptor
.
ValidCTileIndex
(
make_tuple
(
block_work_idx
[
I1
],
block_work_idx
[
I2
]),
make_tuple
(
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I0
),
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I2
))))
{
return
;
}
const
index_t
k_batch_id
=
block_work_idx
[
I0
];
const
index_t
k_batch_id
=
block_work_idx
[
I0
];
// HACK: this force m/n_block_data_idx_on_grid into SGPR
// HACK: this force m/n_block_data_idx_on_grid into SGPR
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
View file @
6e3cf8b0
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include "multi_index_transform_helper.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
...
@@ -223,12 +224,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -223,12 +224,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
}
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
CheckValidity
(
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
const
Block2CTileMap
&
block_2_ctile_map
)
index_t
N01
)
{
{
// static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
// static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
// is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
// is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
...
@@ -256,31 +257,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -256,31 +257,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
return
false
;
return
false
;
}
}
// check M01, N01
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_grid_desc_m_n
))
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
{
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
if
(
!
(
M0
%
M01
==
0
&&
N0
%
N01
==
0
))
return
false
;
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
return
true
;
}
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
return
grid_size
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
{
const
index_t
num_loop
=
K
/
KPerBlock
;
const
index_t
num_loop
=
K
/
KPerBlock
;
...
@@ -318,36 +303,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -318,36 +303,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
index_t
N01
)
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
index_t
N01
)
{
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
return
BlockToCTileMap_M00_N00_M01_N01
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
(
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
c_grid_desc_m_n
,
M01
,
N01
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
const
auto
M00
=
M0
/
M01
;
const
auto
N00
=
N0
/
N01
;
const
auto
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M00
,
M01
)),
make_unmerge_transform
(
make_tuple
(
N00
,
N01
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{}));
const
auto
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M00
,
N00
,
M01
,
N01
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
cblockid_to_m0_n0_block_cluster_adaptor
=
chain_tensor_adaptors
(
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
);
return
cblockid_to_m0_n0_block_cluster_adaptor
;
}
}
using
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
using
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
remove_cvref_t
<
decltype
(
...
@@ -385,6 +342,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
...
@@ -385,6 +342,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
const
auto
block_work_idx
=
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_ctile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.
GetLength
(
I0
),
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.
GetLength
(
I3
))))
{
return
;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
View file @
6e3cf8b0
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "multi_index_transform_helper.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r2.hpp"
#include "thread_group_tensor_slice_transfer_v6r2.hpp"
...
@@ -230,12 +231,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
...
@@ -230,12 +231,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
}
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
const
Block2CTileMap
&
block_2_ctile_map
)
index_t
N01
)
{
{
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
"wrong! K1 need to be known at compile-time"
);
...
@@ -264,31 +265,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
...
@@ -264,31 +265,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
return
false
;
return
false
;
}
}
// check M01, N01
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_grid_desc_m_n
))
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
{
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
if
(
!
(
M0
%
M01
==
0
&&
N0
%
N01
==
0
))
return
false
;
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
return
true
;
}
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
return
grid_size
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
{
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1
);
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1
);
...
@@ -327,37 +312,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
...
@@ -327,37 +312,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
index_t
N01
)
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
index_t
N01
)
{
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
return
BlockToCTileMap_M00_N00_M01_N01
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
(
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
c_grid_desc_m_n
,
M01
,
N01
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
const
auto
M00
=
M0
/
M01
;
const
auto
N00
=
N0
/
N01
;
const
auto
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M00
,
M01
)),
make_unmerge_transform
(
make_tuple
(
N00
,
N01
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{}));
const
auto
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M00
,
N00
,
M01
,
N01
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
cblockid_to_m0_n0_block_cluster_adaptor
=
chain_tensor_adaptors
(
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
);
return
cblockid_to_m0_n0_block_cluster_adaptor
;
}
}
using
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
using
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
...
@@ -408,6 +366,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
...
@@ -408,6 +366,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
const
auto
block_work_idx
=
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_ctile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.
GetLength
(
I0
),
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.
GetLength
(
I3
))))
{
return
;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
View file @
6e3cf8b0
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include "multi_index_transform_helper.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r3.hpp"
#include "thread_group_tensor_slice_transfer_v6r3.hpp"
...
@@ -237,12 +238,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
...
@@ -237,12 +238,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
}
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
const
Block2CTileMap
&
block_2_ctile_map
)
index_t
N01
)
{
{
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
"wrong! K1 need to be known at compile-time"
);
...
@@ -271,31 +272,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
...
@@ -271,31 +272,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
return
false
;
return
false
;
}
}
// check M01, N01
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_grid_desc_m_n
))
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
{
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
if
(
!
(
M0
%
M01
==
0
&&
N0
%
N01
==
0
))
return
false
;
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
return
true
;
}
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
return
grid_size
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
{
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1
);
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1
);
...
@@ -334,36 +319,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
...
@@ -334,36 +319,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
index_t
N01
)
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
index_t
N01
)
{
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
return
BlockToCTileMap_M00_N00_M01_N01
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
(
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
c_grid_desc_m_n
,
M01
,
N01
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
const
auto
M00
=
M0
/
M01
;
const
auto
N00
=
N0
/
N01
;
const
auto
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M00
,
M01
)),
make_unmerge_transform
(
make_tuple
(
N00
,
N01
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{}));
const
auto
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M00
,
N00
,
M01
,
N01
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
cblockid_to_m0_n0_block_cluster_adaptor
=
chain_tensor_adaptors
(
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
);
return
cblockid_to_m0_n0_block_cluster_adaptor
;
}
}
using
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
using
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
remove_cvref_t
<
decltype
(
...
@@ -427,6 +384,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
...
@@ -427,6 +384,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
const
auto
block_work_idx
=
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_ctile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.
GetLength
(
I0
),
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.
GetLength
(
I3
))))
{
return
;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
6e3cf8b0
...
@@ -258,6 +258,14 @@ __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
...
@@ -258,6 +258,14 @@ __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
index_t
soffset
,
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.fadd.f32"
);
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.fadd.f32"
);
// buffer atomic-add fp32
__device__
double
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
double
vdata
,
int32x4_t
rsrc
,
// dst_wave_buffer_resource
int
voffset
,
// dst_thread_addr_offset
int
soffset
,
// dst_wave_addr_offset
int
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.fmax.f64"
);
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
__device__
typename
vector_type
<
T
,
N
>::
type
amd_buffer_load_impl
(
int32x4_t
src_wave_buffer_resource
,
__device__
typename
vector_type
<
T
,
N
>::
type
amd_buffer_load_impl
(
int32x4_t
src_wave_buffer_resource
,
index_t
src_thread_addr_offset
,
index_t
src_thread_addr_offset
,
...
@@ -915,6 +923,71 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::typ
...
@@ -915,6 +923,71 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::typ
}
}
}
}
template
<
typename
T
,
index_t
N
>
__device__
void
amd_buffer_atomic_max_impl
(
const
typename
vector_type
<
T
,
N
>::
type
src_thread_data
,
int32x4_t
dst_wave_buffer_resource
,
index_t
dst_thread_addr_offset
,
index_t
dst_wave_addr_offset
)
{
static_assert
((
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
)),
"wrong! not implemented"
);
if
constexpr
(
is_same
<
T
,
double
>::
value
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
2
)
{
vector_type
<
double
,
2
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
tmp
.
AsType
<
double
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
tmp
.
AsType
<
double
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
double
),
0
);
}
else
if
constexpr
(
N
==
4
)
{
vector_type
<
double
,
4
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
tmp
.
AsType
<
double
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
tmp
.
AsType
<
double
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
double
),
0
);
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
tmp
.
AsType
<
double
>
()[
Number
<
2
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
2
*
sizeof
(
double
),
0
);
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
tmp
.
AsType
<
double
>
()[
Number
<
3
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
3
*
sizeof
(
double
),
0
);
}
}
}
// buffer_load requires:
// buffer_load requires:
// 1) p_src_wave must point to global memory space
// 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer.
// 2) p_src_wave must be a wavewise pointer.
...
@@ -1046,4 +1119,39 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
...
@@ -1046,4 +1119,39 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
#endif
#endif
}
}
// buffer_atomic_max requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template
<
typename
T
,
index_t
N
>
__device__
void
amd_buffer_atomic_max
(
const
typename
vector_type_maker
<
T
,
N
>::
type
::
type
src_thread_data
,
T
*
p_dst_wave
,
const
index_t
dst_thread_element_offset
,
const
bool
dst_thread_element_valid
,
const
index_t
dst_element_space_size
)
{
const
int32x4_t
dst_wave_buffer_resource
=
make_wave_buffer_resource
(
p_dst_wave
,
dst_element_space_size
);
index_t
dst_thread_addr_offset
=
dst_thread_element_offset
*
sizeof
(
T
);
using
vector_t
=
typename
vector_type_maker
<
T
,
N
>::
type
::
type
;
using
scalar_t
=
typename
scalar_type
<
vector_t
>::
type
;
constexpr
index_t
vector_size
=
scalar_type
<
vector_t
>::
vector_size
;
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x7fffffff
;
amd_buffer_atomic_max_impl
<
scalar_t
,
vector_size
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
#else
if
(
dst_thread_element_valid
)
{
amd_buffer_atomic_max_impl
<
scalar_t
,
vector_size
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
}
#endif
}
}
// namespace ck
}
// namespace ck
include/ck/utility/common_header.hpp
View file @
6e3cf8b0
...
@@ -32,7 +32,7 @@
...
@@ -32,7 +32,7 @@
#include "debug.hpp"
#include "debug.hpp"
#include "amd_buffer_addressing.hpp"
#include "amd_buffer_addressing.hpp"
#include "generic_memory_space_atomic
_add
.hpp"
#include "generic_memory_space_atomic.hpp"
#include "get_id.hpp"
#include "get_id.hpp"
#include "synchronization.hpp"
#include "synchronization.hpp"
#include "amd_address_space.hpp"
#include "amd_address_space.hpp"
...
...
include/ck/utility/dynamic_buffer.hpp
View file @
6e3cf8b0
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
#include "enable_if.hpp"
#include "enable_if.hpp"
#include "c_style_pointer_cast.hpp"
#include "c_style_pointer_cast.hpp"
#include "amd_buffer_addressing.hpp"
#include "amd_buffer_addressing.hpp"
#include "generic_memory_space_atomic
_add
.hpp"
#include "generic_memory_space_atomic.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -125,6 +125,10 @@ struct DynamicBuffer
...
@@ -125,6 +125,10 @@ struct DynamicBuffer
{
{
this
->
template
AtomicAdd
<
X
>(
i
,
is_valid_element
,
x
);
this
->
template
AtomicAdd
<
X
>(
i
,
is_valid_element
,
x
);
}
}
else
if
constexpr
(
Op
==
InMemoryDataOperationEnum
::
AtomicMax
)
{
this
->
template
AtomicMax
<
X
>(
i
,
is_valid_element
,
x
);
}
else
if
constexpr
(
Op
==
InMemoryDataOperationEnum
::
Add
)
else
if
constexpr
(
Op
==
InMemoryDataOperationEnum
::
Add
)
{
{
auto
tmp
=
this
->
template
Get
<
X
>(
i
,
is_valid_element
);
auto
tmp
=
this
->
template
Get
<
X
>(
i
,
is_valid_element
);
...
@@ -326,6 +330,42 @@ struct DynamicBuffer
...
@@ -326,6 +330,42 @@ struct DynamicBuffer
}
}
}
}
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
void
AtomicMax
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
{
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
scalar_type
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_x_vector
=
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
;
static_assert
(
scalar_per_x_vector
%
scalar_per_t_vector
==
0
,
"wrong! X should contain multiple T"
);
static_assert
(
GetAddressSpace
()
==
AddressSpaceEnum
::
Global
,
"only support global mem"
);
#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
using
scalar_t
=
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
;
bool
constexpr
use_amd_buffer_addressing
=
is_same_v
<
remove_cvref_t
<
scalar_t
>
,
double
>
;
#else
bool
constexpr
use_amd_buffer_addressing
=
false
;
#endif
if
constexpr
(
use_amd_buffer_addressing
)
{
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_atomic_max
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
}
else
if
(
is_valid_element
)
{
atomic_max
<
X
>
(
c_style_pointer_cast
<
X
*>
(
&
p_data_
[
i
]),
x
);
}
}
__host__
__device__
static
constexpr
bool
IsStaticBuffer
()
{
return
false
;
}
__host__
__device__
static
constexpr
bool
IsStaticBuffer
()
{
return
false
;
}
__host__
__device__
static
constexpr
bool
IsDynamicBuffer
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsDynamicBuffer
()
{
return
true
;
}
...
...
include/ck/utility/generic_memory_space_atomic
_add
.hpp
→
include/ck/utility/generic_memory_space_atomic.hpp
View file @
6e3cf8b0
...
@@ -3,6 +3,10 @@
...
@@ -3,6 +3,10 @@
namespace
ck
{
namespace
ck
{
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to make the implementation of atomic_add explicit for
// each datatype.
template
<
typename
X
>
template
<
typename
X
>
__device__
X
atomic_add
(
X
*
p_dst
,
const
X
&
x
);
__device__
X
atomic_add
(
X
*
p_dst
,
const
X
&
x
);
...
@@ -41,4 +45,53 @@ __device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x)
...
@@ -41,4 +45,53 @@ __device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x)
return
vy
.
template
AsType
<
float2_t
>()[
I0
];
return
vy
.
template
AsType
<
float2_t
>()[
I0
];
}
}
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to make the implementation of atomic_max explicit for
// each datatype.
template
<
typename
X
>
__device__
X
atomic_max
(
X
*
p_dst
,
const
X
&
x
);
template
<
>
__device__
int32_t
atomic_max
<
int32_t
>
(
int32_t
*
p_dst
,
const
int32_t
&
x
)
{
return
atomicMax
(
p_dst
,
x
);
}
template
<
>
__device__
uint32_t
atomic_max
<
uint32_t
>
(
uint32_t
*
p_dst
,
const
uint32_t
&
x
)
{
return
atomicMax
(
p_dst
,
x
);
}
template
<
>
__device__
float
atomic_max
<
float
>
(
float
*
p_dst
,
const
float
&
x
)
{
return
atomicMax
(
p_dst
,
x
);
}
template
<
>
__device__
double
atomic_max
<
double
>
(
double
*
p_dst
,
const
double
&
x
)
{
return
atomicMax
(
p_dst
,
x
);
}
template
<
>
__device__
float2_t
atomic_max
<
float2_t
>
(
float2_t
*
p_dst
,
const
float2_t
&
x
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
const
vector_type
<
float
,
2
>
vx
{
x
};
vector_type
<
float
,
2
>
vy
{
0
};
vy
.
template
AsType
<
float
>()(
I0
)
=
atomicMax
(
c_style_pointer_cast
<
float
*>
(
p_dst
),
vx
.
template
AsType
<
float
>()[
I0
]);
vy
.
template
AsType
<
float
>()(
I1
)
=
atomicMax
(
c_style_pointer_cast
<
float
*>
(
p_dst
)
+
1
,
vx
.
template
AsType
<
float
>()[
I1
]);
return
vy
.
template
AsType
<
float2_t
>()[
I0
];
}
}
// namespace ck
}
// namespace ck
include/ck/utility/get_id.hpp
View file @
6e3cf8b0
...
@@ -11,10 +11,14 @@ __host__ __device__ constexpr index_t get_warp_size()
...
@@ -11,10 +11,14 @@ __host__ __device__ constexpr index_t get_warp_size()
__device__
index_t
get_thread_local_1d_id
()
{
return
threadIdx
.
x
;
}
__device__
index_t
get_thread_local_1d_id
()
{
return
threadIdx
.
x
;
}
__device__
index_t
get_thread_global_1d_id
()
{
return
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
}
__device__
index_t
get_warp_local_1d_id
()
{
return
threadIdx
.
x
/
get_warp_size
();
}
__device__
index_t
get_warp_local_1d_id
()
{
return
threadIdx
.
x
/
get_warp_size
();
}
__device__
index_t
get_block_1d_id
()
{
return
blockIdx
.
x
;
}
__device__
index_t
get_block_1d_id
()
{
return
blockIdx
.
x
;
}
__device__
index_t
get_grid_size
()
{
return
gridDim
.
x
;
}
__device__
index_t
get_grid_size
()
{
return
gridDim
.
x
;
}
__device__
index_t
get_block_size
()
{
return
blockDim
.
x
;
}
}
// namespace ck
}
// namespace ck
include/ck/utility/statically_indexed_array_multi_index.hpp
View file @
6e3cf8b0
...
@@ -93,6 +93,13 @@ __host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x)
...
@@ -93,6 +93,13 @@ __host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x)
return
r
;
return
r
;
}
}
// MultiIndex = MultiIndex * index_t
template
<
typename
...
Xs
>
__host__
__device__
constexpr
auto
operator
*
(
const
Tuple
<
Xs
...
>&
x
,
index_t
a
)
{
return
a
*
x
;
}
template
<
typename
...
Xs
>
template
<
typename
...
Xs
>
__host__
__device__
void
print_multi_index
(
const
Tuple
<
Xs
...
>&
x
)
__host__
__device__
void
print_multi_index
(
const
Tuple
<
Xs
...
>&
x
)
{
{
...
...
include/ck/utility/type.hpp
View file @
6e3cf8b0
...
@@ -29,6 +29,9 @@ using remove_cv_t = typename std::remove_cv<T>::type;
...
@@ -29,6 +29,9 @@ using remove_cv_t = typename std::remove_cv<T>::type;
template
<
typename
T
>
template
<
typename
T
>
using
remove_cvref_t
=
remove_cv_t
<
std
::
remove_reference_t
<
T
>>
;
using
remove_cvref_t
=
remove_cv_t
<
std
::
remove_reference_t
<
T
>>
;
template
<
typename
T
>
using
remove_pointer_t
=
typename
std
::
remove_pointer
<
T
>::
type
;
template
<
typename
T
>
template
<
typename
T
>
inline
constexpr
bool
is_pointer_v
=
std
::
is_pointer
<
T
>::
value
;
inline
constexpr
bool
is_pointer_v
=
std
::
is_pointer
<
T
>::
value
;
...
...
library/include/ck/library/host/host_interface.hpp
0 → 100644
View file @
6e3cf8b0
#pragma once
#include <memory>
#include <string>
#include "stream_config.hpp"
#include "config.hpp"
#include "device_base.hpp"
struct
DeviceConvFwdPtr_t
{
using
BaseArgument
=
ck
::
tensor_operation
::
device
::
BaseArgument
;
using
BaseInvoker
=
ck
::
tensor_operation
::
device
::
BaseInvoker
;
struct
DeviceConvFwdPtrImpl
;
std
::
unique_ptr
<
DeviceConvFwdPtrImpl
>
pImpl
;
DeviceConvFwdPtr_t
();
~
DeviceConvFwdPtr_t
();
DeviceConvFwdPtr_t
(
DeviceConvFwdPtr_t
&&
);
DeviceConvFwdPtr_t
(
DeviceConvFwdPtrImpl
&
);
DeviceConvFwdPtr_t
&
operator
=
(
DeviceConvFwdPtr_t
&
)
=
delete
;
DeviceConvFwdPtr_t
&
operator
=
(
const
DeviceConvFwdPtr_t
&
)
=
delete
;
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
void
*
in_ptr
,
void
*
wei_ptr
,
void
*
out_ptr
,
size_t
N
,
size_t
K
,
size_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
const
;
// in,wei and out element ops are ignored for now since even if we change them, they
// cant be linked
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
const
;
// requires including BaseInvoker headers
std
::
string
GetTypeString
();
bool
IsSupportedArgument
(
const
BaseArgument
*
arg_ptr
);
};
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
);
void
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
);
library/include/ck/library/host_tensor/device.hpp
View file @
6e3cf8b0
#ifndef DEVICE_HPP
#pragma once
#define DEVICE_HPP
#include <memory>
#include <memory>
#include <functional>
#include <functional>
#include <thread>
#include <thread>
#include <chrono>
#include <chrono>
#include "hip/hip_runtime.h"
#include <hip/hip_runtime.h>
#include "hip/hip_fp16.h"
#include <hip/hip_fp16.h>
#include "stream_config.hpp"
#include "ck/options.hpp"
template
<
typename
T
>
__global__
void
set_buffer_value
(
T
*
p
,
T
x
,
uint64_t
buffer_element_size
)
{
for
(
uint64_t
i
=
threadIdx
.
x
;
i
<
buffer_element_size
;
i
+=
blockDim
.
x
)
{
p
[
i
]
=
x
;
}
}
inline
void
hip_check_error
(
hipError_t
x
)
{
if
(
x
!=
hipSuccess
)
{
std
::
ostringstream
ss
;
ss
<<
"HIP runtime error: "
<<
hipGetErrorString
(
x
)
<<
". "
<<
__FILE__
<<
": "
<<
__LINE__
<<
"in function: "
<<
__func__
;
throw
std
::
runtime_error
(
ss
.
str
());
}
}
struct
DeviceMem
struct
DeviceMem
{
{
...
@@ -17,6 +39,16 @@ struct DeviceMem
...
@@ -17,6 +39,16 @@ struct DeviceMem
void
ToDevice
(
const
void
*
p
);
void
ToDevice
(
const
void
*
p
);
void
FromDevice
(
void
*
p
);
void
FromDevice
(
void
*
p
);
void
SetZero
();
void
SetZero
();
template
<
typename
T
>
void
SetValue
(
T
x
)
{
if
(
mMemSize
%
sizeof
(
T
)
!=
0
)
{
throw
std
::
runtime_error
(
"wrong! not entire DeviceMem will be set"
);
}
set_buffer_value
<
T
><<<
1
,
1024
>>>
(
static_cast
<
T
*>
(
mpDeviceBuf
),
x
,
mMemSize
/
sizeof
(
T
));
}
~
DeviceMem
();
~
DeviceMem
();
void
*
mpDeviceBuf
;
void
*
mpDeviceBuf
;
...
@@ -36,49 +68,56 @@ struct KernelTimer
...
@@ -36,49 +68,56 @@ struct KernelTimer
std
::
unique_ptr
<
KernelTimerImpl
>
impl
;
std
::
unique_ptr
<
KernelTimerImpl
>
impl
;
};
};
using
device_stream_t
=
hipStream_t
;
template
<
typename
...
Args
,
typename
F
>
template
<
typename
...
Args
,
typename
F
>
void
launch_kernel
(
F
kernel
,
dim3
grid_dim
,
dim3
block_dim
,
std
::
size_t
lds_byte
,
Args
...
args
)
float
launch_and_time_kernel
(
const
StreamConfig
&
stream_config
,
F
kernel
,
dim3
grid_dim
,
dim3
block_dim
,
std
::
size_t
lds_byte
,
Args
...
args
)
{
{
hipStream_t
stream_id
=
nullptr
;
#if CK_TIME_KERNEL
if
(
stream_config
.
time_kernel_
)
hipLaunchKernelGGL
(
kernel
,
grid_dim
,
block_dim
,
lds_byte
,
stream_id
,
args
...);
{
}
printf
(
"%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d}
\n
"
,
__func__
,
grid_dim
.
x
,
grid_dim
.
y
,
grid_dim
.
z
,
block_dim
.
x
,
block_dim
.
y
,
block_dim
.
z
);
template
<
typename
...
Args
,
typename
F
>
const
int
nrepeat
=
10
;
float
launch_and_time_kernel
(
F
kernel
,
int
nrepeat
,
dim3
grid_dim
,
dim3
block_dim
,
std
::
size_t
lds_byte
,
Args
...
args
)
{
KernelTimer
timer
;
printf
(
"%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d}
\n
"
,
printf
(
"Warm up 1 time
\n
"
);
__func__
,
grid_dim
.
x
,
grid_dim
.
y
,
grid_dim
.
z
,
block_dim
.
x
,
block_dim
.
y
,
block_dim
.
z
);
printf
(
"Warm up
\n
"
);
// warm up
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
hipStream_t
stream_id
=
nullptr
;
printf
(
"Start running %d times...
\n
"
,
nrepeat
)
;
// warm up
KernelTimer
timer
;
hipLaunchKernelGGL
(
kernel
,
grid_dim
,
block_dim
,
lds_byte
,
stream_id
,
args
...
);
timer
.
Start
(
);
printf
(
"Start running %d times...
\n
"
,
nrepeat
);
for
(
int
i
=
0
;
i
<
nrepeat
;
++
i
)
{
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
}
timer
.
Start
();
timer
.
End
();
for
(
int
i
=
0
;
i
<
nrepeat
;
++
i
)
return
timer
.
GetElapsedTime
()
/
nrepeat
;
{
hipLaunchKernelGGL
(
kernel
,
grid_dim
,
block_dim
,
lds_byte
,
stream_id
,
args
...);
}
}
else
{
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
timer
.
End
();
return
0
;
}
#else
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
return
timer
.
GetElapsedTime
()
/
nrepeat
;
return
0
;
}
#endif
#endif
}
library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp
View file @
6e3cf8b0
...
@@ -84,7 +84,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator
...
@@ -84,7 +84,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator
return
0
;
return
0
;
}
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
int
)
override
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/* stream_config */
=
StreamConfig
{})
override
{
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp
View file @
6e3cf8b0
#ifndef REFERENCE_CONV_WRW_HPP
#pragma once
#define REFERENCE_CONV_WRW_HPP
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
...
@@ -16,7 +15,9 @@ template <typename InDataType,
...
@@ -16,7 +15,9 @@ template <typename InDataType,
typename
OutDataType
,
typename
OutDataType
,
typename
InElementwiseOperation
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
typename
OutElementwiseOperation
,
ck
::
index_t
NumDimSpatial
=
2
,
typename
ck
::
enable_if
<
NumDimSpatial
>
=
1
&&
NumDimSpatial
<=
3
,
bool
>::
type
=
false
>
struct
ReferenceConvBwdWeight
:
public
device
::
BaseOperator
struct
ReferenceConvBwdWeight
:
public
device
::
BaseOperator
{
{
// Argument
// Argument
...
@@ -32,9 +33,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
...
@@ -32,9 +33,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
InElementwiseOperation
in_element_op
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
OutElementwiseOperation
out_element_op
)
:
in
_n_c_hi_wi
_
{
in_n_c_hi_wi
},
:
in
put
_
{
in_n_c_hi_wi
},
wei
_k_c_y_x
_
{
wei_k_c_y_x
},
wei
ght
_
{
wei_k_c_y_x
},
out
_n_k_ho_wo
_
{
out_n_k_ho_wo
},
out
put
_
{
out_n_k_ho_wo
},
conv_strides_
{
conv_filter_strides
},
conv_strides_
{
conv_filter_strides
},
conv_dilations_
{
conv_filter_dilations
},
conv_dilations_
{
conv_filter_dilations
},
in_left_pads_
{
input_left_pads
},
in_left_pads_
{
input_left_pads
},
...
@@ -45,9 +46,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
...
@@ -45,9 +46,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
{
{
}
}
const
Tensor
<
InDataType
>&
in
_n_c_hi_wi
_
;
const
Tensor
<
InDataType
>&
in
put
_
;
Tensor
<
WeiDataType
>&
wei
_k_c_y_x
_
;
Tensor
<
WeiDataType
>&
wei
ght
_
;
const
Tensor
<
OutDataType
>&
out
_n_k_ho_wo
_
;
const
Tensor
<
OutDataType
>&
out
put
_
;
std
::
vector
<
index_t
>
conv_strides_
;
std
::
vector
<
index_t
>
conv_strides_
;
std
::
vector
<
index_t
>
conv_dilations_
;
std
::
vector
<
index_t
>
conv_dilations_
;
...
@@ -66,62 +67,184 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
...
@@ -66,62 +67,184 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
float
Run
(
const
Argument
&
arg
)
float
Run
(
const
Argument
&
arg
)
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
if
constexpr
(
NumDimSpatial
==
1
)
constexpr
auto
I1
=
Number
<
1
>
{};
{
auto
f_kcyx
=
[
&
](
auto
k
,
auto
c
,
auto
y
,
auto
x
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
float
v_acc
=
0
;
auto
f_kcx
=
[
&
](
auto
k
,
auto
c
,
auto
x
)
{
for
(
std
::
size_t
n
=
0
;
n
<
arg
.
out_n_k_ho_wo_
.
mDesc
.
GetLengths
()[
0
];
++
n
)
float
v_acc
=
0
;
{
for
(
std
::
size_t
n
=
0
;
n
<
arg
.
output_
.
mDesc
.
GetLengths
()[
0
];
++
n
)
for
(
std
::
size_t
ho
=
0
;
ho
<
arg
.
out_n_k_ho_wo_
.
mDesc
.
GetLengths
()[
2
];
++
ho
)
{
{
auto
hi
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
I0
])
+
for
(
std
::
size_t
wo
=
0
;
wo
<
arg
.
output_
.
mDesc
.
GetLengths
()[
2
];
++
wo
)
ck
::
type_convert
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
I0
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I0
]);
for
(
std
::
size_t
wo
=
0
;
wo
<
arg
.
out_n_k_ho_wo_
.
mDesc
.
GetLengths
()[
3
];
++
wo
)
{
{
auto
wi
=
auto
wi
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
I1
])
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
I0
])
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
I1
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
I0
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I1
]);
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I0
]);
if
(
hi
>=
0
&&
if
(
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
hi
)
<
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
2
])
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
3
])
{
{
float
v_out
;
float
v_out
;
float
v_in
;
float
v_in
;
arg
.
out_element_op_
(
arg
.
out_element_op_
(
v_out
,
v_out
,
ck
::
type_convert
<
float
>
(
arg
.
output_
(
n
,
k
,
wo
)));
ck
::
type_convert
<
float
>
(
arg
.
out_n_k_ho_wo_
(
n
,
k
,
ho
,
wo
)));
arg
.
in_element_op_
(
v_in
,
arg
.
in_element_op_
(
ck
::
type_convert
<
float
>
(
arg
.
input_
(
n
,
c
,
wi
)));
v_in
,
ck
::
type_convert
<
float
>
(
arg
.
in_n_c_hi_wi_
(
n
,
c
,
hi
,
wi
)));
v_acc
+=
v_out
*
v_in
;
v_acc
+=
v_out
*
v_in
;
}
}
}
}
}
}
}
float
v_wei
;
float
v_wei
;
arg
.
wei_element_op_
(
v_wei
,
v_acc
);
arg
.
wei_element_op_
(
v_wei
,
v_acc
);
arg
.
wei
_k_c_y_x
_
(
k
,
c
,
y
,
x
)
=
ck
::
type_convert
<
Out
DataType
>
(
v_wei
);
arg
.
wei
ght
_
(
k
,
c
,
x
)
=
ck
::
type_convert
<
Wei
DataType
>
(
v_wei
);
};
};
make_ParallelTensorFunctor
(
f_kcyx
,
make_ParallelTensorFunctor
(
f_kcx
,
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
0
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
0
],
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
1
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
],
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
2
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
])(
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
std
::
thread
::
hardware_concurrency
());
return
0
;
return
0
;
}
else
if
constexpr
(
NumDimSpatial
==
2
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
auto
f_kcyx
=
[
&
](
auto
k
,
auto
c
,
auto
y
,
auto
x
)
{
float
v_acc
=
0
;
for
(
std
::
size_t
n
=
0
;
n
<
arg
.
output_
.
mDesc
.
GetLengths
()[
0
];
++
n
)
{
for
(
std
::
size_t
ho
=
0
;
ho
<
arg
.
output_
.
mDesc
.
GetLengths
()[
2
];
++
ho
)
{
auto
hi
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
I0
])
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
I0
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I0
]);
for
(
std
::
size_t
wo
=
0
;
wo
<
arg
.
output_
.
mDesc
.
GetLengths
()[
3
];
++
wo
)
{
auto
wi
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
I1
])
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
I1
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I1
]);
if
(
hi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
hi
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
3
])
{
float
v_out
;
float
v_in
;
arg
.
out_element_op_
(
v_out
,
ck
::
type_convert
<
float
>
(
arg
.
output_
(
n
,
k
,
ho
,
wo
)));
arg
.
in_element_op_
(
v_in
,
ck
::
type_convert
<
float
>
(
arg
.
input_
(
n
,
c
,
hi
,
wi
)));
v_acc
+=
v_out
*
v_in
;
}
}
}
}
float
v_wei
;
arg
.
wei_element_op_
(
v_wei
,
v_acc
);
arg
.
weight_
(
k
,
c
,
y
,
x
)
=
ck
::
type_convert
<
WeiDataType
>
(
v_wei
);
};
make_ParallelTensorFunctor
(
f_kcyx
,
arg
.
weight_
.
mDesc
.
GetLengths
()[
0
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
return
0
;
}
else
if
constexpr
(
NumDimSpatial
==
3
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
auto
f_kczyx
=
[
&
](
auto
k
,
auto
c
,
auto
z
,
auto
y
,
auto
x
)
{
float
v_acc
=
0
;
for
(
std
::
size_t
n
=
0
;
n
<
arg
.
output_
.
mDesc
.
GetLengths
()[
0
];
++
n
)
{
for
(
std
::
size_t
do_
=
0
;
do_
<
arg
.
output_
.
mDesc
.
GetLengths
()[
2
];
++
do_
)
{
auto
di
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
do_
*
arg
.
conv_strides_
[
I0
])
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
z
*
arg
.
conv_dilations_
[
I0
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I0
]);
for
(
std
::
size_t
ho
=
0
;
ho
<
arg
.
output_
.
mDesc
.
GetLengths
()[
3
];
++
ho
)
{
auto
hi
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
I1
])
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
I1
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I1
]);
for
(
std
::
size_t
wo
=
0
;
wo
<
arg
.
output_
.
mDesc
.
GetLengths
()[
4
];
++
wo
)
{
auto
wi
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
I2
])
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
I2
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I2
]);
if
(
di
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
di
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
2
]
&&
hi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
hi
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
3
]
&&
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
4
])
{
float
v_out
;
float
v_in
;
arg
.
out_element_op_
(
v_out
,
ck
::
type_convert
<
float
>
(
arg
.
output_
(
n
,
k
,
do_
,
ho
,
wo
)));
arg
.
in_element_op_
(
v_in
,
ck
::
type_convert
<
float
>
(
arg
.
input_
(
n
,
c
,
di
,
hi
,
wi
)));
v_acc
+=
v_out
*
v_in
;
}
}
}
}
}
float
v_wei
;
arg
.
wei_element_op_
(
v_wei
,
v_acc
);
arg
.
weight_
(
k
,
c
,
z
,
y
,
x
)
=
ck
::
type_convert
<
WeiDataType
>
(
v_wei
);
};
make_ParallelTensorFunctor
(
f_kczyx
,
arg
.
weight_
.
mDesc
.
GetLengths
()[
0
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
4
])(
std
::
thread
::
hardware_concurrency
());
return
0
;
}
}
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
int
)
override
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/*stream_config*/
=
StreamConfig
{})
override
{
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
...
@@ -181,4 +304,3 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
...
@@ -181,4 +304,3 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
}
// namespace host
}
// namespace host
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
View file @
6e3cf8b0
...
@@ -291,7 +291,8 @@ struct ReferenceConvBwdData : public device::BaseOperator
...
@@ -291,7 +291,8 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
}
}
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
int
)
override
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/* stream_config */
=
StreamConfig
{})
override
{
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
View file @
6e3cf8b0
#ifndef REFERENCE_CONV_FWD_HPP
#pragma once
#define REFERENCE_CONV_FWD_HPP
#include <iostream>
#include <iostream>
#include <type_traits>
#include <type_traits>
#include <sstream>
#include <sstream>
#include "stream_config.hpp"
#include "device_base.hpp"
#include "device_base.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
...
@@ -251,7 +252,8 @@ struct ReferenceConvFwd : public device::BaseOperator
...
@@ -251,7 +252,8 @@ struct ReferenceConvFwd : public device::BaseOperator
}
}
}
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
int
)
override
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/*stream_config*/
=
StreamConfig
{})
override
{
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
...
@@ -311,4 +313,3 @@ struct ReferenceConvFwd : public device::BaseOperator
...
@@ -311,4 +313,3 @@ struct ReferenceConvFwd : public device::BaseOperator
}
// namespace host
}
// namespace host
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp
View file @
6e3cf8b0
...
@@ -124,7 +124,8 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
...
@@ -124,7 +124,8 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
return
0
;
return
0
;
}
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
int
)
override
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/* stream_config */
=
StreamConfig
{})
override
{
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
...
...
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment