gaoqiong / composable_kernel · Commits · 0af93458

Commit 0af93458, authored Dec 15, 2021 by Chao Liu

    clean up

parent 50d7b4fc
Showing 2 changed files with 47 additions and 259 deletions (+47 -259):

    composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp           +15 -192
    device_operation/include/device_conv2d_fwd_xdl_output_shuffle_nhwc_kyxc_nhwk.hpp   +32 -67
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp
@@ -10,8 +10,6 @@
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "threadwise_tensor_slice_set.hpp"
 
-#define DEBUG_USE_C_SHUFFLE 1
-
 namespace ck {
 
 template <typename GridwiseGemm,
@@ -19,11 +17,7 @@ template <typename GridwiseGemm,
           typename FloatC,
           typename AGridDesc_K0_M_K1,
           typename BGridDesc_K0_N_K1,
-#if !DEBUG_USE_C_SHUFFLE
-          typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
-#else
           typename CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl,
-#endif
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
@@ -39,12 +33,8 @@ __global__ void
             FloatC* __restrict__ p_c_grid,
             const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
             const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
-#if !DEBUG_USE_C_SHUFFLE
-            const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-#else
             const CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
                 c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
-#endif
             const AElementwiseOperation a_element_op,
             const BElementwiseOperation b_element_op,
             const CElementwiseOperation c_element_op,
@@ -52,21 +42,18 @@ __global__ void
 {
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
                                                   p_shared,
                                                   a_grid_desc_k0_m_k1,
                                                   b_grid_desc_k0_n_k1,
-#if !DEBUG_USE_C_SHUFFLE
-                                                  c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-#else
                                                   c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
-#endif
                                                   a_element_op,
                                                   b_element_op,
                                                   c_element_op,
                                                   block_2_ctile_map);
 }
 
 template <index_t BlockSize,
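Note: the surviving kernel above is a thin wrapper; it only reserves static LDS and forwards everything to the grid-wise policy struct. A minimal, self-contained sketch of that pattern (hypothetical names and sizes, not the library's kernel):

#include <hip/hip_runtime.h>

// Stand-in policy struct; the real GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 carries
// the tile sizes, descriptors and the MFMA main loop.
struct GridwiseGemmSketch
{
    __host__ __device__ static constexpr unsigned GetSharedMemoryNumberOfByte() { return 4096; }

    template <bool HasMainKBlockLoop>
    __device__ static void Run(const float* p_a, const float* p_b, float* p_c, void* p_shared)
    {
        // real implementation: stage A/B tiles in p_shared, run the blockwise GEMM,
        // then shuffle and write C back to global memory
        (void)p_a; (void)p_b; (void)p_c; (void)p_shared;
    }
};

// The __global__ entry point stays trivial: allocate LDS once and delegate.
template <typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void kernel_gemm_sketch(const float* p_a, const float* p_b, float* p_c)
{
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a, p_b, p_c, p_shared);
}

// Possible instantiation: kernel_gemm_sketch<GridwiseGemmSketch, true><<<grid, block>>>(a, b, c);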
@@ -232,57 +219,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
         return has_main_k0_block_loop;
     }
 
-#if !DEBUG_USE_C_SHUFFLE
-    __host__ __device__ static constexpr auto
-    MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
-    {
-        constexpr auto max_lds_align = K1;
-
-        // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_block_desc_k0_m_k1 = [&]() {
-            if constexpr(ABlockLdsExtraM)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
-            }
-        }();
-
-        // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_k0_n_k1 = [&]() {
-            if constexpr(BBlockLdsExtraN)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
-                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
-            }
-        }();
-
-        using BlockwiseGemm =
-            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
-                                                                FloatAB,
-                                                                FloatAcc,
-                                                                decltype(a_block_desc_k0_m_k1),
-                                                                decltype(b_block_desc_k0_n_k1),
-                                                                MPerXdl,
-                                                                NPerXdl,
-                                                                MRepeat,
-                                                                NRepeat,
-                                                                K1>;
-
-        return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
-    }
-#else
     __host__ __device__ static constexpr auto
     MakeCGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl(
         const CGridDesc_M_N& c_grid_desc_m_n)
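Note: the helper deleted above builds its LDS descriptors with immediately invoked lambdas whose if-constexpr branches return different descriptor types (row-padded vs. alignment-rounded). A reduced, stand-alone illustration of that idiom, with stand-in types instead of make_naive_tensor_descriptor / make_naive_tensor_descriptor_aligned:

#include <type_traits>

// Stand-ins for the two descriptor flavours the removed code could return.
struct PaddedDesc  { int row_stride; }; // row padded by one element to dodge LDS bank conflicts
struct AlignedDesc { int alignment; };  // tightly packed, alignment-rounded layout

template <bool ExtraM>
constexpr auto make_a_block_desc(int m_per_block, int k1)
{
    // With an auto return type, each if-constexpr branch may return a different type;
    // the discarded branch is never instantiated.
    if constexpr(ExtraM)
    {
        return PaddedDesc{(m_per_block + 1) * k1};
    }
    else
    {
        return AlignedDesc{k1};
    }
}

static_assert(std::is_same_v<decltype(make_a_block_desc<true>(128, 8)), PaddedDesc>);
static_assert(std::is_same_v<decltype(make_a_block_desc<false>(128, 8)), AlignedDesc>);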
@@ -308,7 +244,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
 
         return c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl;
     }
-#endif
 
     // return block_id to C matrix tile idx (m0, n0) mapping
     __host__ __device__ static constexpr auto
@@ -345,15 +280,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
         return c_blockid_to_m0_n0_block_cluster_adaptor;
     }
 
-#if !DEBUG_USE_C_SHUFFLE
-    using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
-        decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
-#else
     using CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl =
         remove_cvref_t<decltype(
             MakeCGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl(
                 CGridDesc_M_N{}))>;
-#endif
 
     using Block2CTileMap = remove_cvref_t<decltype(MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
 
     template <bool HasMainKBlockLoop>
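Note: the surviving alias above uses a recurring idiom in this file: the descriptor type is derived from the constexpr factory itself via decltype plus remove_cvref_t, so the (very long) descriptor type never has to be spelled out by hand. A reduced sketch, with std::remove_cvref_t standing in for ck's remove_cvref_t and a placeholder transform:

#include <type_traits>

template <typename CGridDesc_M_N>
struct GridwiseSketch
{
    // Factory: in the real code this rewraps the M x N descriptor into the
    // (MBlock, MRepeat, MWave, MPerXdl, NBlock, ...) shape used by the C shuffle path.
    static constexpr auto MakeCDesc(const CGridDesc_M_N& desc)
    {
        return desc; // placeholder transform
    }

    // The alias is simply whatever the factory returns for a default-constructed input.
    using CDesc = std::remove_cvref_t<decltype(MakeCDesc(CGridDesc_M_N{}))>;
};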
@@ -364,12 +295,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
                                void* __restrict__ p_shared,
                                const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                                const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
-#if !DEBUG_USE_C_SHUFFLE
-                               const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-#else
                                const CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl&
                                    c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl,
-#endif
                                const AElementwiseOperation& a_element_op,
                                const BElementwiseOperation& b_element_op,
                                const CElementwiseOperation& c_element_op,
@@ -379,16 +306,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
             p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
-#if !DEBUG_USE_C_SHUFFLE
-        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
-            p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
-#else
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_c_grid,
             c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
                 .GetElementSpaceSize());
-#endif
 
         const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
@@ -526,23 +447,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
         constexpr auto a_block_space_size =
             math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
 
-#if !DEBUG_USE_C_SHUFFLE
-        FloatAB* p_a_block = static_cast<FloatAB*>(p_shared);
-        FloatAB* p_b_block = static_cast<FloatAB*>(p_shared) + a_block_space_size;
-
-        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
-            p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
-        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
-            p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());
-#else
         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
             static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
 
         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
             static_cast<FloatAB*>(p_shared) + a_block_space_size,
             b_block_desc_k0_n_k1.GetElementSpaceSize());
-#endif
 
         constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
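Note: both the deleted and the kept branch above carve one __shared__ allocation into the A tile followed by the B tile; only the wrapping into make_dynamic_buffer differs. A minimal sketch of the layout arithmetic (FloatAB assumed to be the element type, sizes given in elements):

#include <hip/hip_runtime.h>

// p_shared is the block's raw LDS allocation; a_block_space_size is the A tile's
// element count already rounded up to max_lds_align (see integer_least_multiple above).
template <typename FloatAB>
__device__ void carve_lds(void* __restrict__ p_shared,
                          unsigned a_block_space_size,
                          FloatAB*& p_a_block,
                          FloatAB*& p_b_block)
{
    p_a_block = static_cast<FloatAB*>(p_shared);                      // A tile at offset 0
    p_b_block = static_cast<FloatAB*>(p_shared) + a_block_space_size; // B tile right after A
}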
@@ -607,92 +517,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
             blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
         }
 
-#if !DEBUG_USE_C_SHUFFLE
-        // output: register to global memory
-        {
-            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
-                blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
-
-            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
-                blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
-
-            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
-            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
-            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
-            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
-            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
-            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
-            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
-            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
-
-            // calculate origin of thread output tensor on global memory
-            //     blockwise GEMM c matrix starting index
-            const auto c_thread_mtx_on_block =
-                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
-
-            const index_t m_thread_data_on_grid =
-                m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
-
-            const index_t n_thread_data_on_grid =
-                n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
-
-            constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};
-
-            const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
-                make_single_stage_tensor_adaptor(
-                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
-                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
-                    make_tuple(Sequence<0>{}));
-
-            const auto m_thread_data_on_grid_idx =
-                m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
-                    make_multi_index(m_thread_data_on_grid));
-
-            const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
-                make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
-                make_tuple(Sequence<0, 1, 2>{}),
-                make_tuple(Sequence<0>{}));
-
-            const auto n_thread_data_on_grid_idx =
-                n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
-                    make_multi_index(n_thread_data_on_grid));
-
-            auto c_thread_copy =
-                ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
-                                                   FloatC,
-                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
-                                                   decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
-#if 0
-                                                   CElementwiseOperation,
-#else
-                                                   ck::tensor_operation::element_wise::PassThrough,
-#endif
-                                                   Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
-                                                   CThreadTransferSrcDstAccessOrder,
-                                                   CThreadTransferSrcDstVectorDim,
-                                                   CThreadTransferDstScalarPerVector,
-                                                   CGlobalMemoryDataOperation,
-                                                   1,
-                                                   true>{
-                    c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                    make_multi_index(m_thread_data_on_grid_idx[I0],
-                                     n_thread_data_on_grid_idx[I0],
-                                     m_thread_data_on_grid_idx[I1],
-                                     n_thread_data_on_grid_idx[I1],
-                                     m_thread_data_on_grid_idx[I2],
-                                     m_thread_data_on_grid_idx[I3],
-                                     m_thread_data_on_grid_idx[I4],
-                                     n_thread_data_on_grid_idx[I2]),
-                    c_element_op};
-
-            c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                              c_thread_buf,
-                              c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                              c_grid_buf,
-                              c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
-        }
-#else
         // shuffle and write out
         {
 #if 1
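Note: the block deleted above rebuilt each thread's (m0, m1, m2, m3, m4) and (n0, n1, n2) coordinates from a flat grid index using make_merge_transform adaptors. Plain C++ showing the same index decomposition (illustration only, not the library's implementation):

#include <array>
#include <cstdint>

// Split a flat index m into coordinates for dimensions {M0, M1, M2, M3, M4},
// where the flat index is m0*(M1*M2*M3*M4) + m1*(M2*M3*M4) + ... + m4.
inline std::array<std::int64_t, 5> decompose_merged_index(std::int64_t m,
                                                          const std::array<std::int64_t, 5>& lengths)
{
    std::array<std::int64_t, 5> idx{};
    for(int d = 4; d >= 0; --d) // peel off the fastest-varying dimension first
    {
        idx[d] = m % lengths[d];
        m /= lengths[d];
    }
    return idx;
}

// e.g. decompose_merged_index(m_thread_data_on_grid, {M0, M1, M2, M3, M4}) mirrors the
// removed adaptor's CalculateBottomIndex(make_multi_index(m_thread_data_on_grid)).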
@@ -960,7 +784,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
                 }
             });
         }
-#endif
     }
 };
device_operation/include/device_conv2d_fwd_xdl_output_shuffle_nhwc_kyxc_nhwk.hpp
@@ -301,13 +301,6 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
                                                 ABlockLdsAddExtraM,
                                                 BBlockLdsAddExtraN>;
 
-#if !DEBUG_USE_C_SHUFFLE
-    using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
-        decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
-
-    using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
-#endif
-
     // Argument
     struct Argument : public BaseArgument
     {
@@ -335,11 +328,7 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
               a_grid_desc_k0_m_k1_{},
               b_grid_desc_k0_n_k1_{},
               c_grid_desc_m_n_{},
-#if !DEBUG_USE_C_SHUFFLE
-              c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
-#else
               c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_{},
-#endif
               block_2_ctile_map_{},
               M01_{M01},
               N01_{N01},
@@ -366,15 +355,10 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
...
@@ -366,15 +355,10 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
if
(
GridwiseGemm
::
CheckValidity
(
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
{
{
#if !DEBUG_USE_C_SHUFFLE
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n_
);
#else
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
=
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
=
GridwiseGemm
::
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
(
MakeCGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
(
c_grid_desc_m_n_
);
c_grid_desc_m_n_
);
#endif
block_2_ctile_map_
=
GridwiseGemm
::
MakeBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
}
}
@@ -387,14 +371,9 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
         AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
         BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
         CGridDesc_M_N c_grid_desc_m_n_;
-#if !DEBUG_USE_C_SHUFFLE
-        typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
-            c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
-#else
         typename GridwiseGemm::
             CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
                 c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_;
-#endif
         typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
         index_t M01_;
         index_t N01_;
@@ -470,38 +449,31 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
...
@@ -470,38 +449,31 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
CDataType
,
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
#if !DEBUG_USE_C_SHUFFLE
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
#else
remove_reference_t
<
remove_reference_t
<
typename
GridwiseGemm
::
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
>
,
CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
>
,
#endif
InElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
OutElementwiseOperation
,
remove_reference_t
<
typename
GridwiseGemm
::
Block2CTileMap
>
,
remove_reference_t
<
typename
GridwiseGemm
::
Block2CTileMap
>
,
true
>
;
true
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
ave_time
=
launch_and_time_kernel
(
nrepeat
,
kernel
,
dim3
(
grid_size
),
nrepeat
,
dim3
(
BlockSize
),
dim3
(
grid_size
),
0
,
dim3
(
BlockSize
),
arg
.
p_a_grid_
,
0
,
arg
.
p_b_grid_
,
arg
.
p_a_grid_
,
arg
.
p_c_grid_
,
arg
.
p_b_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
p_c_grid_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
a_grid_desc_k0_m_k1_
,
#if !DEBUG_USE_C_SHUFFLE
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
,
#else
arg
.
in_element_op_
,
arg
.
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
,
arg
.
wei_element_op_
,
#endif
arg
.
out_element_op_
,
arg
.
in_element_op_
,
arg
.
block_2_ctile_map_
);
arg
.
wei_element_op_
,
arg
.
out_element_op_
,
arg
.
block_2_ctile_map_
);
}
}
else
else
{
{
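Note: the call sites above and below hand launch_and_time_kernel the kernel, a repeat count, launch dims, the dynamic-LDS byte count and the kernel arguments. The sketch below is only a plausible shape for such a helper using HIP events; it is not the library's implementation, and the name launch_and_time_kernel_sketch is made up:

#include <hip/hip_runtime.h>
#include <cstddef>

template <typename Kernel, typename... Args>
float launch_and_time_kernel_sketch(Kernel kernel, int nrepeat, dim3 grid, dim3 block,
                                    std::size_t lds_bytes, Args... args)
{
    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    hipEventRecord(start, nullptr);
    for(int i = 0; i < nrepeat; ++i)
    {
        // forward the arguments to the kernel on the default stream
        hipLaunchKernelGGL(kernel, grid, block, lds_bytes, nullptr, args...);
    }
    hipEventRecord(stop, nullptr);
    hipEventSynchronize(stop);

    float total_ms = 0.f;
    hipEventElapsedTime(&total_ms, start, stop);

    hipEventDestroy(start);
    hipEventDestroy(stop);

    return total_ms / nrepeat; // average time per launch, in ms
}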
@@ -511,38 +483,31 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
                    CDataType,
                    remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
                    remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
-#if !DEBUG_USE_C_SHUFFLE
-                    remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
-#else
                    remove_reference_t<
                        typename GridwiseGemm::
                            CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl>,
-#endif
                    InElementwiseOperation,
                    WeiElementwiseOperation,
                    OutElementwiseOperation,
                    remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
                    false>;
 
                ave_time = launch_and_time_kernel(kernel,
                                                  nrepeat,
                                                  dim3(grid_size),
                                                  dim3(BlockSize),
                                                  0,
                                                  arg.p_a_grid_,
                                                  arg.p_b_grid_,
                                                  arg.p_c_grid_,
                                                  arg.a_grid_desc_k0_m_k1_,
                                                  arg.b_grid_desc_k0_n_k1_,
-#if !DEBUG_USE_C_SHUFFLE
-                                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-#else
                                                  arg.c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_,
-#endif
                                                  arg.in_element_op_,
                                                  arg.wei_element_op_,
                                                  arg.out_element_op_,
                                                  arg.block_2_ctile_map_);
            }
 
            return ave_time;