Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e6a23d8b
Commit
e6a23d8b
authored
Sep 12, 2021
by
Jing Zhang
Browse files
add e2
parent
a8169558
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
322 additions
and
287 deletions
+322
-287
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
...rnel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
+62
-53
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp
...ernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp
+126
-115
composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp
...nel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp
+26
-23
host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
...ution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
+18
-16
host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp
...orward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp
+89
-80
script/run.sh
script/run.sh
+1
-0
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
View file @
e6a23d8b
...
@@ -7,20 +7,21 @@
...
@@ -7,20 +7,21 @@
namespace
ck
{
namespace
ck
{
template
<
index_t
BlockSize
,
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatA
,
typename
FloatB
,
typename
FloatC
,
typename
FloatC
,
typename
Block
MatrixA
,
typename
A
Block
Desc_E1_K_E2
,
typename
Block
MatrixB
,
typename
B
Block
Desc_E1_N_Ho_Wo_E2
,
typename
Thread
MatrixC
,
typename
C
Thread
Desc_K_N_Ho_Wo
,
index_t
EPerThreadLoop
,
index_t
EPerThreadLoop
,
index_t
ThreadGemmADataPerRead_K
,
index_t
ThreadGemmADataPerRead_E2
>
index_t
ThreadGemmBDataPerRead_W
>
struct
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
struct
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
struct
MatrixIndex
struct
MatrixIndex
{
{
...
@@ -29,36 +30,48 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
...
@@ -29,36 +30,48 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
index_t
w
;
index_t
w
;
};
};
static
constexpr
index_t
KPerThreadLoop
=
4
;
static
constexpr
auto
E1
=
ABlockDesc_E1_K_E2
{}.
GetLength
(
I0
);
static
constexpr
auto
K
=
ABlockDesc_E1_K_E2
{}.
GetLength
(
I1
);
static
constexpr
auto
E2
=
ABlockDesc_E1_K_E2
{}.
GetLength
(
I2
);
static
constexpr
auto
KPerThread
=
ThreadMatrixC
{}.
GetLength
(
I0
);
static
constexpr
auto
H
=
BBlockDesc_E1_N_Ho_Wo_E2
{}.
GetLength
(
I2
);
static
constexpr
auto
HPerThread
=
ThreadMatrixC
{}.
GetLength
(
I2
);
static
constexpr
auto
W
=
BBlockDesc_E1_N_Ho_Wo_E2
{}.
GetLength
(
I3
);
static
constexpr
auto
WPerThread
=
ThreadMatrixC
{}.
GetLength
(
I3
);
static
constexpr
auto
KPerThread
=
CThreadDesc_K_N_Ho_Wo
{}.
GetLength
(
I0
);
static
constexpr
auto
HPerThread
=
CThreadDesc_K_N_Ho_Wo
{}.
GetLength
(
I2
);
static
constexpr
auto
WPerThread
=
CThreadDesc_K_N_Ho_Wo
{}.
GetLength
(
I3
);
static
constexpr
index_t
KPerThreadLoop
=
KPerThread
;
static
constexpr
auto
a_thread_mtx_
=
make_naive_tensor_descriptor_packed
(
static
constexpr
auto
a_thread_mtx_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
EPerThreadLoop
>
{},
Number
<
KPerThreadLoop
>
{}));
make_tuple
(
Number
<
EPerThreadLoop
>
{},
Number
<
KPerThreadLoop
>
{}
,
Number
<
E2
>
{}
));
static
constexpr
auto
b_thread_mtx_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
static
constexpr
auto
b_thread_mtx_
=
Number
<
EPerThreadLoop
>
{},
Number
<
1
>
{},
Number
<
HPerThread
>
{},
Number
<
WPerThread
>
{}));
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
EPerThreadLoop
>
{},
Number
<
1
>
{},
Number
<
HPerThread
>
{},
Number
<
WPerThread
>
{},
Number
<
E2
>
{}));
static
constexpr
auto
c_thread_mtx_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
static
constexpr
auto
c_thread_mtx_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KPerThreadLoop
>
{},
Number
<
1
>
{},
Number
<
HPerThread
>
{},
Number
<
WPerThread
>
{}));
Number
<
KPerThreadLoop
>
{},
Number
<
1
>
{},
Number
<
HPerThread
>
{},
Number
<
WPerThread
>
{}));
__device__
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
()
__device__
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
()
:
c_thread_begin_mtx_idx_
{
GetBeginOfThread
MatrixC
(
get_thread_local_1d_id
())},
:
c_thread_begin_mtx_idx_
{
GetBeginOf
C
Thread
Desc_K_N_Ho_Wo
(
get_thread_local_1d_id
())},
a_thread_copy_
{
make_tuple
(
0
,
c_thread_begin_mtx_idx_
.
k
*
KPerThread
)}
a_thread_copy_
{
make_tuple
(
0
,
c_thread_begin_mtx_idx_
.
k
*
KPerThread
,
0
)}
{
{
static_assert
(
Block
MatrixA
::
IsKnownAtCompileTime
()
&&
static_assert
(
A
Block
Desc_E1_K_E2
::
IsKnownAtCompileTime
()
&&
Block
MatrixB
::
IsKnownAtCompileTime
()
&&
B
Block
Desc_E1_N_Ho_Wo_E2
::
IsKnownAtCompileTime
()
&&
Thread
MatrixC
::
IsKnownAtCompileTime
(),
C
Thread
Desc_K_N_Ho_Wo
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
"wrong! Desc should be known at compile-time"
);
static_assert
(
BlockMatrixA
{}.
GetLength
(
I0
)
==
BlockMatrixB
{}.
GetLength
(
I0
),
static_assert
(
"wrong! K dimension not consistent
\n
"
);
ABlockDesc_E1_K_E2
{}.
GetLength
(
I0
)
==
BBlockDesc_E1_N_Ho_Wo_E2
{}.
GetLength
(
I0
)
&&
ABlockDesc_E1_K_E2
{}.
GetLength
(
I2
)
==
BBlockDesc_E1_N_Ho_Wo_E2
{}.
GetLength
(
I4
),
"wrong! E dimension not consistent
\n
"
);
constexpr
index_t
K
=
BlockMatrixA
{}.
GetLength
(
I1
);
// A is transposed
static_assert
(
E1
%
EPerThreadLoop
==
0
,
""
);
constexpr
index_t
H
=
BlockMatrixB
{}.
GetLength
(
I2
);
static_assert
(
KPerThread
%
KPerThreadLoop
==
0
,
""
);
constexpr
index_t
W
=
BlockMatrixB
{}.
GetLength
(
I3
);
static_assert
(
K
%
KPerThread
==
0
&&
H
%
HPerThread
==
0
&&
W
%
WPerThread
==
0
,
static_assert
(
K
%
KPerThread
==
0
&&
H
%
HPerThread
==
0
&&
W
%
WPerThread
==
0
,
"wrong! Cannot evenly divide work among
\n
"
);
"wrong! Cannot evenly divide work among
\n
"
);
...
@@ -71,15 +84,15 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
...
@@ -71,15 +84,15 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
"wrong! wrong blocksize
\n
"
);
"wrong! wrong blocksize
\n
"
);
}
}
__device__
static
constexpr
auto
GetThread
MatrixC
Lengths
()
__device__
static
constexpr
auto
Get
C
Thread
Desc_K_N_Ho_Wo
Lengths
()
{
{
return
Sequence
<
KPerThread
,
1
,
HPerThread
,
WPerThread
>
{};
return
Sequence
<
KPerThread
,
1
,
HPerThread
,
WPerThread
>
{};
}
}
__device__
static
MatrixIndex
GetBeginOfThread
MatrixC
(
index_t
thread_id
)
__device__
static
MatrixIndex
GetBeginOf
C
Thread
Desc_K_N_Ho_Wo
(
index_t
thread_id
)
{
{
constexpr
index_t
HPerBlock
=
Block
MatrixB
{}.
GetLength
(
Number
<
2
>
{}
);
constexpr
index_t
HPerBlock
=
B
Block
Desc_E1_N_Ho_Wo_E2
{}.
GetLength
(
I2
);
constexpr
index_t
WPerBlock
=
Block
MatrixB
{}.
GetLength
(
Number
<
3
>
{}
);
constexpr
index_t
WPerBlock
=
B
Block
Desc_E1_N_Ho_Wo_E2
{}.
GetLength
(
I3
);
constexpr
auto
num_w_threads
=
WPerBlock
/
WPerThread
;
constexpr
auto
num_w_threads
=
WPerBlock
/
WPerThread
;
constexpr
auto
num_h_threads
=
HPerBlock
/
HPerThread
;
constexpr
auto
num_h_threads
=
HPerBlock
/
HPerThread
;
...
@@ -100,42 +113,37 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
...
@@ -100,42 +113,37 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
CThreadBuffer
&
c_thread_buf
)
const
CThreadBuffer
&
c_thread_buf
)
const
{
{
static_assert
(
static_assert
(
is_same
<
remove_cvref_t
<
typename
ABlockBuffer
::
type
>
,
remove_cvref_t
<
FloatA
B
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
ABlockBuffer
::
type
>
,
remove_cvref_t
<
FloatA
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
BThreadBuffer
::
type
>
,
remove_cvref_t
<
Float
A
B
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
BThreadBuffer
::
type
>
,
remove_cvref_t
<
FloatB
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
CThreadBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
CThreadBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
"wrong! inconsistent type"
);
"wrong! inconsistent type"
);
constexpr
auto
a_block_mtx
=
BlockMatrixA
{};
constexpr
auto
a_block_mtx
=
ABlockDesc_E1_K_E2
{};
constexpr
auto
EPerBlock
=
a_block_mtx
.
GetLength
(
I0
);
static_assert
(
EPerBlock
%
EPerThreadLoop
==
0
,
""
);
static_assert
(
KPerThread
%
KPerThreadLoop
==
0
,
""
);
// thread A buffer for GEMM
// thread A buffer for GEMM
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
FloatA
B
,
a_thread_mtx_
.
GetElementSpaceSize
(),
true
>
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
FloatA
,
a_thread_mtx_
.
GetElementSpaceSize
(),
true
>
a_thread_buf
;
a_thread_buf
;
constexpr
auto
threadwise_gemm
=
ThreadwiseGemmDlops_km_kn_mn_v3
<
FloatA
B
,
constexpr
auto
threadwise_gemm
=
ThreadwiseGemmDlops_km_kn_mn_v3
<
FloatA
,
Float
A
B
,
FloatB
,
FloatC
,
FloatC
,
decltype
(
a_thread_mtx_
),
decltype
(
a_thread_mtx_
),
decltype
(
b_thread_mtx_
),
decltype
(
b_thread_mtx_
),
decltype
(
c_thread_mtx_
)
>
{};
decltype
(
c_thread_mtx_
)
>
{};
static_for
<
0
,
E
PerBlock
,
EPerThreadLoop
>
{}([
&
](
auto
e_begin
)
{
static_for
<
0
,
E
1
,
EPerThreadLoop
>
{}([
&
](
auto
e_begin
)
{
static_for
<
0
,
KPerThread
,
KPerThreadLoop
>
{}([
&
](
auto
k_begin
)
{
static_for
<
0
,
KPerThread
,
KPerThreadLoop
>
{}([
&
](
auto
k_begin
)
{
a_thread_copy_
.
Run
(
a_block_mtx
,
a_thread_copy_
.
Run
(
a_block_mtx
,
make_tuple
(
e_begin
,
k_begin
),
make_tuple
(
e_begin
,
k_begin
,
I0
),
a_block_buf
,
a_block_buf
,
a_thread_mtx_
,
a_thread_mtx_
,
make_tuple
(
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
),
a_thread_buf
);
a_thread_buf
);
threadwise_gemm
.
Run
(
a_thread_buf
,
threadwise_gemm
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
),
b_thread_buf
,
b_thread_buf
,
make_tuple
(
e_begin
,
I0
,
I0
,
I0
),
make_tuple
(
e_begin
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
c_thread_buf
,
make_tuple
(
k_begin
,
I0
,
I0
,
I0
));
make_tuple
(
k_begin
,
I0
,
I0
,
I0
));
});
});
...
@@ -145,21 +153,22 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
...
@@ -145,21 +153,22 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
template
<
typename
ABlockSliceMoveStepIdx
>
template
<
typename
ABlockSliceMoveStepIdx
>
__device__
void
MoveABlockSliceWindow
(
const
ABlockSliceMoveStepIdx
&
a_block_slice_move_step_idx
)
__device__
void
MoveABlockSliceWindow
(
const
ABlockSliceMoveStepIdx
&
a_block_slice_move_step_idx
)
{
{
a_thread_copy_
.
MoveSrcSliceWindow
(
Block
MatrixA
{},
a_block_slice_move_step_idx
);
a_thread_copy_
.
MoveSrcSliceWindow
(
A
Block
Desc_E1_K_E2
{},
a_block_slice_move_step_idx
);
}
}
private:
private:
MatrixIndex
c_thread_begin_mtx_idx_
;
MatrixIndex
c_thread_begin_mtx_idx_
;
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
using
AThreadCopy
=
FloatAB
,
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
BlockMatrixA
,
FloatB
,
decltype
(
a_thread_mtx_
),
ABlockDesc_E1_K_E2
,
Sequence
<
EPerThreadLoop
,
KPerThreadLoop
>
,
decltype
(
a_thread_mtx_
),
Sequence
<
0
,
1
>
,
Sequence
<
EPerThreadLoop
,
KPerThreadLoop
,
E2
>
,
1
,
Sequence
<
0
,
1
,
2
>
,
ThreadGemmADataPerRead_K
,
2
,
1
>
;
ThreadGemmADataPerRead_E2
,
ThreadGemmADataPerRead_E2
>
;
AThreadCopy
a_thread_copy_
;
AThreadCopy
a_thread_copy_
;
};
};
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp
View file @
e6a23d8b
...
@@ -15,8 +15,8 @@ namespace ck {
...
@@ -15,8 +15,8 @@ namespace ck {
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatAB
,
typename
FloatC
,
typename
FloatC
,
typename
AGridDesc_E0_E1_K
,
typename
AGridDesc_E0_E1_K
_E2
,
typename
BGridDesc_E_N_Ho_Wo
,
typename
BGridDesc_E
0_E1
_N_Ho_Wo
_E2
,
typename
CGridDesc_K_N_Ho_Wo
,
typename
CGridDesc_K_N_Ho_Wo
,
typename
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
,
typename
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
,
bool
HasMainKBlockLoop
,
bool
HasMainKBlockLoop
,
...
@@ -28,8 +28,8 @@ __global__ void
...
@@ -28,8 +28,8 @@ __global__ void
kernel_gemm_dlops_v2
(
const
FloatAB
*
__restrict__
p_a_grid
,
kernel_gemm_dlops_v2
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc_E0_E1_K
a_e0_e1_k_grid_desc
,
const
AGridDesc_E0_E1_K
_E2
a_e0_e1_k_
e2_
grid_desc
,
const
BGridDesc_E_N_Ho_Wo
b_e0_e1_n_ho_wo_grid_desc
,
const
BGridDesc_E
0_E1
_N_Ho_Wo
_E2
b_e0_e1_n_ho_wo_
e2_
grid_desc
,
const
CGridDesc_K_N_Ho_Wo
c_k_n_ho_wo_grid_desc
,
const
CGridDesc_K_N_Ho_Wo
c_k_n_ho_wo_grid_desc
,
const
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
const
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
c_blockid_to_k_n_ho_wo_block_cluster_adaptor
)
c_blockid_to_k_n_ho_wo_block_cluster_adaptor
)
...
@@ -43,8 +43,8 @@ __global__ void
...
@@ -43,8 +43,8 @@ __global__ void
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
p_shared_block
,
p_shared_block
,
a_e0_e1_k_grid_desc
,
a_e0_e1_k_
e2_
grid_desc
,
b_e0_e1_n_ho_wo_grid_desc
,
b_e0_e1_n_ho_wo_
e2_
grid_desc
,
c_k_n_ho_wo_grid_desc
,
c_k_n_ho_wo_grid_desc
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
...
@@ -56,8 +56,8 @@ __global__ void
...
@@ -56,8 +56,8 @@ __global__ void
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatAB
,
typename
FloatC
,
typename
FloatC
,
typename
AGridDesc_E0_E1_K
,
typename
AGridDesc_E0_E1_K
_E2
,
typename
BGridDesc_E_N_Ho_Wo
,
typename
BGridDesc_E
0_E1
_N_Ho_Wo
_E2
,
typename
CGridDesc_K_N_Ho_Wo
,
typename
CGridDesc_K_N_Ho_Wo
,
typename
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
,
typename
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
,
bool
HasMainKBlockLoop
,
bool
HasMainKBlockLoop
,
...
@@ -69,18 +69,18 @@ __global__ void
...
@@ -69,18 +69,18 @@ __global__ void
kernel_gemm_dlops_v2
(
const
FloatAB
*
__restrict__
p_a_grid
,
kernel_gemm_dlops_v2
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
void
CONSTANT
*
p_a_e0_e1_k_grid_desc
,
const
void
CONSTANT
*
p_a_e0_e1_k_
e2_
grid_desc
,
const
void
CONSTANT
*
p_b_e0_e1_n_ho_wo_grid_desc
,
const
void
CONSTANT
*
p_b_e0_e1_n_ho_wo_
e2_
grid_desc
,
const
void
CONSTANT
*
p_c_k_n_ho_wo_grid_desc
,
const
void
CONSTANT
*
p_c_k_n_ho_wo_grid_desc
,
const
void
CONSTANT
*
p_c_blockid_to_k_n_ho_wo_block_cluster_adaptor
)
const
void
CONSTANT
*
p_c_blockid_to_k_n_ho_wo_block_cluster_adaptor
)
{
{
// first cast void CONSTANT void* to void*
// first cast void CONSTANT void* to void*
// second cast void* to Desc*
// second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4)
// the copy constructor of tensor descriptor doesn't take address_space(4)
const
auto
a_e0_e1_k_grid_desc
=
*
reinterpret_cast
<
const
AGridDesc_E0_E1_K
*>
(
const
auto
a_e0_e1_k_
e2_
grid_desc
=
*
reinterpret_cast
<
const
AGridDesc_E0_E1_K
_E2
*>
(
cast_pointer_to_generic_address_space
(
p_a_e0_e1_k_grid_desc
));
cast_pointer_to_generic_address_space
(
p_a_e0_e1_k_
e2_
grid_desc
));
const
auto
b_e0_e1_n_ho_wo_grid_desc
=
*
reinterpret_cast
<
const
BGridDesc_E_N_Ho_Wo
*>
(
const
auto
b_e0_e1_n_ho_wo_
e2_
grid_desc
=
*
reinterpret_cast
<
const
BGridDesc_E
0_E1
_N_Ho_Wo
_E2
*>
(
cast_pointer_to_generic_address_space
(
p_b_e0_e1_n_ho_wo_grid_desc
));
cast_pointer_to_generic_address_space
(
p_b_e0_e1_n_ho_wo_
e2_
grid_desc
));
const
auto
c_k_n_ho_wo_grid_desc
=
*
reinterpret_cast
<
const
CGridDesc_K_N_Ho_Wo
*>
(
const
auto
c_k_n_ho_wo_grid_desc
=
*
reinterpret_cast
<
const
CGridDesc_K_N_Ho_Wo
*>
(
cast_pointer_to_generic_address_space
(
p_c_k_n_ho_wo_grid_desc
));
cast_pointer_to_generic_address_space
(
p_c_k_n_ho_wo_grid_desc
));
...
@@ -93,8 +93,8 @@ __global__ void
...
@@ -93,8 +93,8 @@ __global__ void
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
p_shared_block
,
p_shared_block
,
a_e0_e1_k_grid_desc
,
a_e0_e1_k_
e2_
grid_desc
,
b_e0_e1_n_ho_wo_grid_desc
,
b_e0_e1_n_ho_wo_
e2_
grid_desc
,
c_k_n_ho_wo_grid_desc
,
c_k_n_ho_wo_grid_desc
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
...
@@ -106,10 +106,11 @@ template <index_t BlockSize,
...
@@ -106,10 +106,11 @@ template <index_t BlockSize,
typename
FloatAcc
,
typename
FloatAcc
,
typename
FloatC
,
typename
FloatC
,
InMemoryDataOperationEnum_t
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum_t
CGlobalMemoryDataOperation
,
typename
AGlobalDesc_E0_E1_K
,
typename
AGlobalDesc_E0_E1_K
_E2
,
typename
BGlobalDesc_E0_E1_N_Ho_Wo
,
typename
BGlobalDesc_E0_E1_N_Ho_Wo
_E2
,
typename
CGlobalDesc_K_N_Ho_Wo
,
typename
CGlobalDesc_K_N_Ho_Wo
,
index_t
E1
,
index_t
E1
,
index_t
E2
,
index_t
KPerBlock
,
index_t
KPerBlock
,
index_t
HoPerBlock
,
index_t
HoPerBlock
,
index_t
WoPerBlock
,
index_t
WoPerBlock
,
...
@@ -118,13 +119,13 @@ template <index_t BlockSize,
...
@@ -118,13 +119,13 @@ template <index_t BlockSize,
index_t
HoPerThread
,
index_t
HoPerThread
,
index_t
WoPerThread
,
index_t
WoPerThread
,
index_t
EPerThread
,
index_t
EPerThread
,
typename
ABlockTransferThreadSliceLengths_E0_E1_K
,
typename
ABlockTransferThreadSliceLengths_E0_E1_K
_E2
,
typename
ABlockTransferThreadClusterLengths_E0_E1_K
,
typename
ABlockTransferThreadClusterLengths_E0_E1_K
_E2
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_
K
,
index_t
ABlockTransferDstScalarPerVector_
E2
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
typename
BBlockTransferSrcAccessOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcVectorDim
,
...
@@ -145,20 +146,20 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -145,20 +146,20 @@ struct GridwiseGemmDlops_km_kn_mn_v3
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
{
constexpr
auto
max_lds_align
=
constexpr
auto
max_lds_align
=
Number
<
ABlockTransferDstScalarPerVector_E2
>
{};
math
::
lcm
(
Number
<
ABlockTransferDstScalarPerVector_K
>
{},
Number
<
KPerBlock
>
{});
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
constexpr
auto
a_e0_e1_k_block_desc
=
make_naive_tensor_descriptor_aligned
(
constexpr
auto
a_e0_e1_k_
e2_
block_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
I1
,
Number
<
E1
>
{},
Number
<
KPerBlock
>
{}),
max_lds_align
);
make_tuple
(
I1
,
Number
<
E1
>
{},
Number
<
KPerBlock
>
{}
,
Number
<
E2
>
{}
),
max_lds_align
);
// LDS allocation for A and B: be careful of alignment
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size
=
constexpr
auto
a_block_space_size
=
math
::
integer_least_multiple
(
math
::
integer_least_multiple
(
a_e0_e1_k_block_desc
.
GetElementSpaceSize
(),
max_lds_align
);
a_e0_e1_k_
e2_
block_desc
.
GetElementSpaceSize
(),
max_lds_align
);
return
a_block_space_size
*
sizeof
(
FloatAB
);
return
a_block_space_size
*
sizeof
(
FloatAB
);
}
}
...
@@ -168,27 +169,27 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -168,27 +169,27 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const
FloatAB
*
__restrict__
p_b_global
,
const
FloatAB
*
__restrict__
p_b_global
,
FloatC
*
__restrict__
p_c_global
,
FloatC
*
__restrict__
p_c_global
,
FloatAB
*
__restrict__
p_shared_block
,
FloatAB
*
__restrict__
p_shared_block
,
const
AGlobalDesc_E0_E1_K
&
a_e0_e1_k_global_desc
,
const
AGlobalDesc_E0_E1_K
_E2
&
a_e0_e1_k_
e2_
global_desc
,
const
BGlobalDesc_E0_E1_N_Ho_Wo
&
b_e0_e1_n_ho_wo_global_desc
,
const
BGlobalDesc_E0_E1_N_Ho_Wo
_E2
&
b_e0_e1_n_ho_wo_
e2_
global_desc
,
const
CGlobalDesc_K_N_Ho_Wo
&
c_k_n_ho_wo_global_desc
,
const
CGlobalDesc_K_N_Ho_Wo
&
c_k_n_ho_wo_global_desc
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
{
{
const
auto
a_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
const
auto
a_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_a_global
,
a_e0_e1_k_global_desc
.
GetElementSpaceSize
());
p_a_global
,
a_e0_e1_k_
e2_
global_desc
.
GetElementSpaceSize
());
const
auto
b_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
const
auto
b_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_b_global
,
b_e0_e1_n_ho_wo_global_desc
.
GetElementSpaceSize
());
p_b_global
,
b_e0_e1_n_ho_wo_
e2_
global_desc
.
GetElementSpaceSize
());
auto
c_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
auto
c_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c_global
,
c_k_n_ho_wo_global_desc
.
GetElementSpaceSize
());
p_c_global
,
c_k_n_ho_wo_global_desc
.
GetElementSpaceSize
());
static_assert
(
E1
%
EPerBlock
==
0
,
""
);
static_assert
(
E1
%
EPerBlock
==
0
,
""
);
// const auto E = a_e0_e1_k_global_desc.GetLength(I0);
// const auto E = a_e0_e1_k_
e2_
global_desc.GetLength(I0);
// const auto K = a_e0_e1_k_global_desc.GetLength(I1);
// const auto K = a_e0_e1_k_
e2_
global_desc.GetLength(I1);
// const auto N = b_e0_e1_n_ho_wo_global_desc.GetLength(I1);
// const auto N = b_e0_e1_n_ho_wo_
e2_
global_desc.GetLength(I1);
const
auto
Ho
=
b_e0_e1_n_ho_wo_global_desc
.
GetLength
(
I3
);
const
auto
Ho
=
b_e0_e1_n_ho_wo_
e2_
global_desc
.
GetLength
(
I3
);
const
auto
Wo
=
b_e0_e1_n_ho_wo_global_desc
.
GetLength
(
I4
);
const
auto
Wo
=
b_e0_e1_n_ho_wo_
e2_
global_desc
.
GetLength
(
I4
);
// divide block work by [M, N]
// divide block work by [M, N]
#if 1
#if 1
...
@@ -217,39 +218,44 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -217,39 +218,44 @@ struct GridwiseGemmDlops_km_kn_mn_v3
#endif
#endif
// lds max alignment
// lds max alignment
constexpr
auto
max_lds_align
=
constexpr
auto
max_lds_align
=
Number
<
ABlockTransferDstScalarPerVector_E2
>
{};
math
::
lcm
(
Number
<
ABlockTransferDstScalarPerVector_K
>
{},
Number
<
KPerBlock
>
{});
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
constexpr
auto
a_e0_e1_k_block_desc
=
make_naive_tensor_descriptor_aligned
(
constexpr
auto
a_e0_e1_k_e2_block_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
I1
>
{},
Number
<
E1
>
{},
Number
<
KPerBlock
>
{}),
max_lds_align
);
make_tuple
(
Number
<
I1
>
{},
Number
<
E1
>
{},
Number
<
KPerBlock
>
{},
Number
<
E2
>
{}),
max_lds_align
);
constexpr
auto
a_e1_k_block_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
EPerBlock
>
{},
Number
<
KPerBlock
>
{}),
max_lds_align
);
// B matrix in LDS memory, dst of blockwise copy
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
constexpr
auto
b_e1_n_ho_wo_block_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
constexpr
auto
b_e1_n_ho_wo_e2_block_desc
=
Number
<
EPerBlock
>
{},
Number
<
1
>
{},
Number
<
HoPerBlock
>
{},
Number
<
WoPerBlock
>
{}));
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
EPerBlock
>
{},
Number
<
1
>
{},
Number
<
HoPerBlock
>
{},
Number
<
WoPerBlock
>
{},
Number
<
E2
>
{}));
// c_thread_mtx definition: this is a mess
// c_thread_mtx definition: this is a mess
// TODO:: more elegent way of defining c_thread_mtx
// TODO:: more elegent way of defining c_thread_mtx
constexpr
auto
c_k_n_ho_wo_thread_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
constexpr
auto
c_k_n_ho_wo_thread_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KPerThread
>
{},
Number
<
1
>
{},
Number
<
HoPerThread
>
{},
Number
<
WoPerThread
>
{}));
Number
<
KPerThread
>
{},
Number
<
1
>
{},
Number
<
HoPerThread
>
{},
Number
<
WoPerThread
>
{}));
constexpr
auto
a_e1_k_e2_block_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
EPerBlock
>
{},
Number
<
KPerBlock
>
{},
Number
<
E2
>
{}),
max_lds_align
);
auto
blockwise_gemm
=
auto
blockwise_gemm
=
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
<
BlockSize
,
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
<
BlockSize
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAcc
,
FloatAcc
,
decltype
(
a_e1_k_block_desc
),
decltype
(
a_e1_k_
e2_
block_desc
),
decltype
(
b_e1_n_ho_wo_block_desc
),
decltype
(
b_e1_n_ho_wo_
e2_
block_desc
),
decltype
(
c_k_n_ho_wo_thread_desc
),
decltype
(
c_k_n_ho_wo_thread_desc
),
EPerThread
,
EPerThread
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_E2
>
{};
ABlockTransferDstScalarPerVector_K
>
{};
auto
c_thread_mtx_index
=
blockwise_gemm
.
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
auto
c_thread_mtx_index
=
blockwise_gemm
.
GetBeginOfCThreadDesc_K_N_Ho_Wo
(
get_thread_local_1d_id
());
const
auto
k_thread_id
=
c_thread_mtx_index
.
k
;
const
auto
k_thread_id
=
c_thread_mtx_index
.
k
;
const
auto
ho_thread_id
=
c_thread_mtx_index
.
h
;
const
auto
ho_thread_id
=
c_thread_mtx_index
.
h
;
...
@@ -268,49 +274,53 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -268,49 +274,53 @@ struct GridwiseGemmDlops_km_kn_mn_v3
auto
a_blockwise_copy
=
auto
a_blockwise_copy
=
BlockwiseTensorSliceTransfer_v4
<
BlockSize
,
BlockwiseTensorSliceTransfer_v4
<
BlockSize
,
InMemoryDataOperationEnum_t
::
Set
,
InMemoryDataOperationEnum_t
::
Set
,
Sequence
<
I1
,
E1
,
KPerBlock
>
,
Sequence
<
I1
,
E1
,
KPerBlock
,
E2
>
,
ABlockTransferThreadSliceLengths_E0_E1_K
,
ABlockTransferThreadSliceLengths_E0_E1_K
_E2
,
ABlockTransferThreadClusterLengths_E0_E1_K
,
ABlockTransferThreadClusterLengths_E0_E1_K
_E2
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
,
decltype
(
a_e0_e1_k_global_desc
),
decltype
(
a_e0_e1_k_
e2_
global_desc
),
decltype
(
a_e0_e1_k_block_desc
),
decltype
(
a_e0_e1_k_
e2_
block_desc
),
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
,
2
>
,
// ABlockTransferDstAccessOrder
Sequence
<
0
,
1
,
2
,
3
>
,
// ABlockTransferDstAccessOrder
ABlockTransferSrcVectorDim
,
ABlockTransferSrcVectorDim
,
2
,
// ABlockTransferDstVectorDim
3
,
// ABlockTransferDstVectorDim
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_
K
,
ABlockTransferDstScalarPerVector_
E2
,
1
,
1
,
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
a_e0_e1_k_global_desc
,
true
>
(
a_e0_e1_k_e2_global_desc
,
make_multi_index
(
0
,
0
,
k_block_data_on_global
),
make_multi_index
(
0
,
0
,
k_block_data_on_global
,
0
),
a_e0_e1_k_block_desc
,
a_e0_e1_k_e2_block_desc
,
make_multi_index
(
0
,
0
,
0
));
make_multi_index
(
0
,
0
,
0
,
0
));
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
I1
,
0
,
0
);
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
I1
,
0
,
0
,
0
);
constexpr
auto
b_e0_e1_n_ho_wo_thread_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
constexpr
auto
b_e0_e1_n_ho_wo_e2_thread_desc
=
I1
,
Number
<
EPerBlock
>
{},
Number
<
1
>
{},
Number
<
HoPerThread
>
{},
Number
<
WoPerThread
>
{}));
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
EPerBlock
>
{},
auto
b_threadwise_transfer
=
Number
<
1
>
{},
ThreadwiseTensorSliceTransfer_v2
<
FloatAB
,
Number
<
HoPerThread
>
{},
FloatAB
,
Number
<
WoPerThread
>
{},
decltype
(
b_e0_e1_n_ho_wo_global_desc
),
Number
<
E2
>
{}));
decltype
(
b_e0_e1_n_ho_wo_thread_desc
),
Sequence
<
I1
,
EPerBlock
,
1
,
HoPerThread
,
WoPerThread
>
,
auto
b_threadwise_transfer
=
ThreadwiseTensorSliceTransfer_v2
<
BBlockTransferSrcAccessOrder
,
FloatAB
,
BBlockTransferSrcVectorDim
,
FloatAB
,
BBlockTransferSrcScalarPerVector
,
decltype
(
b_e0_e1_n_ho_wo_e2_global_desc
),
1
,
decltype
(
b_e0_e1_n_ho_wo_e2_thread_desc
),
true
>
(
Sequence
<
I1
,
EPerBlock
,
1
,
HoPerThread
,
WoPerThread
,
E2
>
,
b_e0_e1_n_ho_wo_global_desc
,
BBlockTransferSrcAccessOrder
,
make_multi_index
(
0
,
0
,
0
,
ho_thread_data_on_global
,
wo_thread_data_on_global
));
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
1
,
true
>
(
b_e0_e1_n_ho_wo_e2_global_desc
,
make_multi_index
(
0
,
0
,
0
,
ho_thread_data_on_global
,
wo_thread_data_on_global
,
0
));
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
p_shared_block
,
a_e0_e1_k_block_desc
.
GetElementSpaceSize
());
p_shared_block
,
a_e0_e1_k_
e2_
block_desc
.
GetElementSpaceSize
());
// register allocation for output
// register allocation for output
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
...
@@ -325,20 +335,20 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -325,20 +335,20 @@ struct GridwiseGemmDlops_km_kn_mn_v3
Sequence
<
KPerThread
,
1
,
HoPerThread
,
WoPerThread
>>
{}
Sequence
<
KPerThread
,
1
,
HoPerThread
,
WoPerThread
>>
{}
.
Run
(
c_k_n_ho_wo_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
FloatAcc
{
0
});
.
Run
(
c_k_n_ho_wo_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
FloatAcc
{
0
});
constexpr
auto
b_thread_slice_copy_step
=
make_multi_index
(
0
,
EPerBlock
,
0
,
0
,
0
);
constexpr
auto
b_thread_slice_copy_step
=
make_multi_index
(
0
,
EPerBlock
,
0
,
0
,
0
,
0
);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr
auto
a_e0_e1_k_global_step_hacks
=
AGlobalStepHacks
{};
constexpr
auto
a_e0_e1_k_
e2_
global_step_hacks
=
AGlobalStepHacks
{};
constexpr
auto
b_e0_e1_n_ho_wo_global_step_hacks
=
BGlobalStepHacks
{};
constexpr
auto
b_e0_e1_n_ho_wo_
e2_
global_step_hacks
=
BGlobalStepHacks
{};
// double regsiter buffer for b
// double regsiter buffer for b
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
FloatAB
,
FloatAB
,
b_e0_e1_n_ho_wo_thread_desc
.
GetElementSpaceSize
(),
b_e0_e1_n_ho_wo_
e2_
thread_desc
.
GetElementSpaceSize
(),
true
>
true
>
b_thread_even_buf
,
b_thread_odd_buf
;
b_thread_even_buf
,
b_thread_odd_buf
;
const
auto
E0
=
b_e0_e1_n_ho_wo_global_desc
.
GetLength
(
I0
);
const
auto
E0
=
b_e0_e1_n_ho_wo_
e2_
global_desc
.
GetLength
(
I0
);
index_t
e0_block_data_begin
=
0
;
index_t
e0_block_data_begin
=
0
;
...
@@ -347,16 +357,16 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -347,16 +357,16 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// LDS double buffer: preload data
// LDS double buffer: preload data
{
{
a_blockwise_copy
.
RunRead
(
a_blockwise_copy
.
RunRead
(
a_e0_e1_k_global_desc
,
a_global_buf
,
a_e0_e1_k_global_step_hacks
);
a_e0_e1_k_
e2_
global_desc
,
a_global_buf
,
a_e0_e1_k_
e2_
global_step_hacks
);
b_threadwise_transfer
.
Run
(
b_e0_e1_n_ho_wo_global_desc
,
b_threadwise_transfer
.
Run
(
b_e0_e1_n_ho_wo_
e2_
global_desc
,
b_global_buf
,
b_global_buf
,
b_e0_e1_n_ho_wo_thread_desc
,
b_e0_e1_n_ho_wo_
e2_
thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_even_buf
,
b_thread_even_buf
,
b_e0_e1_n_ho_wo_global_step_hacks
);
b_e0_e1_n_ho_wo_
e2_
global_step_hacks
);
a_blockwise_copy
.
RunWrite
(
a_e0_e1_k_block_desc
,
a_block_buf
);
a_blockwise_copy
.
RunWrite
(
a_e0_e1_k_
e2_
block_desc
,
a_block_buf
);
}
}
__syncthreads
();
__syncthreads
();
...
@@ -370,36 +380,36 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -370,36 +380,36 @@ struct GridwiseGemmDlops_km_kn_mn_v3
do
do
{
{
// even iteration
// even iteration
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_ho_wo_global_desc
,
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_ho_wo_
e2_
global_desc
,
b_thread_slice_copy_step
);
b_thread_slice_copy_step
);
b_threadwise_transfer
.
Run
(
b_e0_e1_n_ho_wo_global_desc
,
b_threadwise_transfer
.
Run
(
b_e0_e1_n_ho_wo_
e2_
global_desc
,
b_global_buf
,
b_global_buf
,
b_e0_e1_n_ho_wo_thread_desc
,
b_e0_e1_n_ho_wo_
e2_
thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_odd_buf
,
b_thread_odd_buf
,
b_e0_e1_n_ho_wo_global_step_hacks
);
b_e0_e1_n_ho_wo_
e2_
global_step_hacks
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
(
make_tuple
(
EPerBlock
,
0
));
blockwise_gemm
.
MoveABlockSliceWindow
(
make_tuple
(
EPerBlock
,
0
,
0
));
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_ho_wo_global_desc
,
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_ho_wo_
e2_
global_desc
,
b_thread_slice_copy_step
);
b_thread_slice_copy_step
);
b_threadwise_transfer
.
Run
(
b_e0_e1_n_ho_wo_global_desc
,
b_threadwise_transfer
.
Run
(
b_e0_e1_n_ho_wo_
e2_
global_desc
,
b_global_buf
,
b_global_buf
,
b_e0_e1_n_ho_wo_thread_desc
,
b_e0_e1_n_ho_wo_
e2_
thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_even_buf
,
b_thread_even_buf
,
b_e0_e1_n_ho_wo_global_step_hacks
);
b_e0_e1_n_ho_wo_
e2_
global_step_hacks
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_odd_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_odd_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
(
make_tuple
(
EPerBlock
,
0
));
blockwise_gemm
.
MoveABlockSliceWindow
(
make_tuple
(
EPerBlock
,
0
,
0
));
e1_block_data_begin
+=
2
*
EPerBlock
;
e1_block_data_begin
+=
2
*
EPerBlock
;
...
@@ -409,20 +419,20 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -409,20 +419,20 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// LDS double buffer: tail
// LDS double buffer: tail
if
constexpr
(
HasDoubleTailKBlockLoop
)
// if has 2 iteration left
if
constexpr
(
HasDoubleTailKBlockLoop
)
// if has 2 iteration left
{
{
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_ho_wo_global_desc
,
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_ho_wo_
e2_
global_desc
,
b_thread_slice_copy_step
);
b_thread_slice_copy_step
);
b_threadwise_transfer
.
Run
(
b_e0_e1_n_ho_wo_global_desc
,
b_threadwise_transfer
.
Run
(
b_e0_e1_n_ho_wo_
e2_
global_desc
,
b_global_buf
,
b_global_buf
,
b_e0_e1_n_ho_wo_thread_desc
,
b_e0_e1_n_ho_wo_
e2_
thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
b_thread_odd_buf
,
b_thread_odd_buf
,
b_e0_e1_n_ho_wo_global_step_hacks
);
b_e0_e1_n_ho_wo_
e2_
global_step_hacks
);
// LDS double buffer: GEMM on 2nd-last data
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
(
make_tuple
(
EPerBlock
,
0
));
blockwise_gemm
.
MoveABlockSliceWindow
(
make_tuple
(
EPerBlock
,
0
,
0
));
// LDS double buffer: GEMM on last data
// LDS double buffer: GEMM on last data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_odd_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_odd_buf
,
c_thread_buf
);
...
@@ -433,12 +443,13 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -433,12 +443,13 @@ struct GridwiseGemmDlops_km_kn_mn_v3
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
}
}
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_e0_e1_k_e2_global_desc
,
a_e0_e1_k_global_desc
,
a_block_slice_copy_step
,
AGlobalMoveSliceWindowStepHacks
{});
a_block_slice_copy_step
,
AGlobalMoveSliceWindowStepHacks
{});
blockwise_gemm
.
MoveABlockSliceWindow
(
make_tuple
(
-
(
E1
-
EPerBlock
),
0
));
blockwise_gemm
.
MoveABlockSliceWindow
(
make_tuple
(
-
(
E1
-
EPerBlock
),
0
,
0
));
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_ho_wo_global_desc
,
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e0_e1_n_ho_wo_
e2_
global_desc
,
b_thread_slice_copy_step
);
b_thread_slice_copy_step
);
e0_block_data_begin
+=
1
;
e0_block_data_begin
+=
1
;
...
...
composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp
View file @
e6a23d8b
...
@@ -9,16 +9,17 @@ namespace ck {
...
@@ -9,16 +9,17 @@ namespace ck {
// C[M, N] += transpose(A[K, M]) * B[K, N]
// C[M, N] += transpose(A[K, M]) * B[K, N]
// Element of matrix can be vectorized data
// Element of matrix can be vectorized data
// Assume:
// Assume:
// 1. AThreadDesc_E_K, BThreadDesc_E_N_Ho_Wo, CThreadDesc_K_N_Ho_Wo are known at compile-time
// 1. AThreadDesc_E1_K_E2, BThreadDesc_E1_N_Ho_Wo_E2, CThreadDesc_K_N_Ho_Wo are known at
// compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template
<
typename
FloatA
,
template
<
typename
FloatA
,
typename
FloatB
,
typename
FloatB
,
typename
FloatC
,
typename
FloatC
,
typename
AThreadDesc_E_K
,
typename
AThreadDesc_E
1
_K
_E2
,
typename
BThreadDesc_E_N_Ho_Wo
,
typename
BThreadDesc_E
1
_N_Ho_Wo
_E2
,
typename
CThreadDesc_K_N_Ho_Wo
,
typename
CThreadDesc_K_N_Ho_Wo
,
typename
enable_if
<
AThreadDesc_E_K
::
IsKnownAtCompileTime
()
&&
typename
enable_if
<
AThreadDesc_E
1
_K
_E2
::
IsKnownAtCompileTime
()
&&
BThreadDesc_E_N_Ho_Wo
::
IsKnownAtCompileTime
()
&&
BThreadDesc_E
1
_N_Ho_Wo
_E2
::
IsKnownAtCompileTime
()
&&
CThreadDesc_K_N_Ho_Wo
::
IsKnownAtCompileTime
(),
CThreadDesc_K_N_Ho_Wo
::
IsKnownAtCompileTime
(),
bool
>
::
type
=
false
>
bool
>
::
type
=
false
>
struct
ThreadwiseGemmDlops_km_kn_mn_v3
struct
ThreadwiseGemmDlops_km_kn_mn_v3
...
@@ -38,8 +39,8 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
...
@@ -38,8 +39,8 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
COriginIdx
)
COriginIdx
)
{
{
static_assert
(
AThreadDesc_E_K
::
IsKnownAtCompileTime
()
&&
static_assert
(
AThreadDesc_E
1
_K
_E2
::
IsKnownAtCompileTime
()
&&
BThreadDesc_E_N_Ho_Wo
::
IsKnownAtCompileTime
()
&&
BThreadDesc_E
1
_N_Ho_Wo
_E2
::
IsKnownAtCompileTime
()
&&
CThreadDesc_K_N_Ho_Wo
::
IsKnownAtCompileTime
(),
CThreadDesc_K_N_Ho_Wo
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
"wrong! Desc should be known at compile-time"
);
...
@@ -54,18 +55,19 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
...
@@ -54,18 +55,19 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
is_same
<
remove_cvref_t
<
typename
CBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
CBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
"wrong! inconsistent type"
);
"wrong! inconsistent type"
);
constexpr
index_t
Vec
=
2
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
E
=
AThreadDesc_E_K
{}.
GetLength
(
I0
);
constexpr
auto
E1
=
AThreadDesc_E1_K_E2
{}.
GetLength
(
I0
);
constexpr
auto
K
=
AThreadDesc_E_K
{}.
GetLength
(
I1
);
constexpr
auto
K
=
AThreadDesc_E1_K_E2
{}.
GetLength
(
I1
);
constexpr
auto
E2
=
AThreadDesc_E1_K_E2
{}.
GetLength
(
I2
);
static_assert
(
E1
==
4
&&
E2
==
4
,
""
);
constexpr
auto
H
=
BThreadDesc_E_N_Ho_Wo
{}.
GetLength
(
I2
);
constexpr
auto
H
=
BThreadDesc_E
1
_N_Ho_Wo
_E2
{}.
GetLength
(
I2
);
constexpr
auto
W
=
BThreadDesc_E_N_Ho_Wo
{}.
GetLength
(
I3
);
constexpr
auto
W
=
BThreadDesc_E
1
_N_Ho_Wo
_E2
{}.
GetLength
(
I3
);
constexpr
auto
a_origin_idx
=
to_multi_index
(
AOriginIdx
{});
constexpr
auto
a_origin_idx
=
to_multi_index
(
AOriginIdx
{});
constexpr
auto
b_origin_idx
=
to_multi_index
(
BOriginIdx
{});
constexpr
auto
b_origin_idx
=
to_multi_index
(
BOriginIdx
{});
...
@@ -74,22 +76,23 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
...
@@ -74,22 +76,23 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
static_for
<
0
,
K
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
K
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
H
,
1
>
{}([
&
](
auto
h
)
{
static_for
<
0
,
H
,
1
>
{}([
&
](
auto
h
)
{
static_for
<
0
,
W
,
1
>
{}([
&
](
auto
w
)
{
static_for
<
0
,
W
,
1
>
{}([
&
](
auto
w
)
{
static_for
<
0
,
E
,
Vec
>
{}([
&
](
auto
e
)
{
static_for
<
0
,
E
1
,
1
>
{}([
&
](
auto
e
)
{
vector_type
<
FloatA
,
Vec
>
a_vec
;
vector_type
<
FloatA
,
E2
>
a_vec
;
vector_type
<
FloatB
,
Vec
>
b_vec
;
vector_type
<
FloatB
,
E2
>
b_vec
;
static_for
<
0
,
Vec
,
1
>
{}([
&
](
auto
v
)
{
static_for
<
0
,
E2
,
1
>
{}([
&
](
auto
v
)
{
constexpr
index_t
a_offset
=
AThreadDesc_E_K
{}.
CalculateOffset
(
constexpr
index_t
a_offset
=
AThreadDesc_E1_K_E2
{}.
CalculateOffset
(
a_origin_idx
+
make_tuple
(
e
+
v
,
k
));
a_origin_idx
+
make_tuple
(
e
,
k
,
v
));
constexpr
index_t
b_offset
=
BThreadDesc_E_N_Ho_Wo
{}.
CalculateOffset
(
constexpr
index_t
b_offset
=
b_origin_idx
+
make_tuple
(
e
+
v
,
0
,
h
,
w
));
BThreadDesc_E1_N_Ho_Wo_E2
{}.
CalculateOffset
(
b_origin_idx
+
make_tuple
(
e
,
0
,
h
,
w
,
v
));
a_vec
.
template
AsType
<
FloatA
>()(
v
)
=
a_buf
[
Number
<
a_offset
>
{}];
a_vec
.
template
AsType
<
FloatA
>()(
v
)
=
a_buf
[
Number
<
a_offset
>
{}];
b_vec
.
template
AsType
<
FloatB
>()(
v
)
=
b_buf
[
Number
<
b_offset
>
{}];
b_vec
.
template
AsType
<
FloatB
>()(
v
)
=
b_buf
[
Number
<
b_offset
>
{}];
});
});
using
a_vector_t
=
typename
vector_type
<
FloatA
,
Vec
>::
type
;
using
a_vector_t
=
typename
vector_type
<
FloatA
,
E2
>::
type
;
using
b_vector_t
=
typename
vector_type
<
FloatB
,
Vec
>::
type
;
using
b_vector_t
=
typename
vector_type
<
FloatB
,
E2
>::
type
;
constexpr
index_t
c_offset
=
CThreadDesc_K_N_Ho_Wo
{}.
CalculateOffset
(
constexpr
index_t
c_offset
=
CThreadDesc_K_N_Ho_Wo
{}.
CalculateOffset
(
c_origin_idx
+
make_tuple
(
k
,
0
,
h
,
w
));
c_origin_idx
+
make_tuple
(
k
,
0
,
h
,
w
));
...
...
host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
View file @
e6a23d8b
...
@@ -102,26 +102,27 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
...
@@ -102,26 +102,27 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
KPerBlock
=
16
;
constexpr
index_t
KPerBlock
=
16
;
constexpr
index_t
HoPerBlock
=
32
;
constexpr
index_t
HoPerBlock
=
8
;
constexpr
index_t
WoPerBlock
=
8
;
constexpr
index_t
WoPerBlock
=
8
;
constexpr
index_t
E1
=
16
;
constexpr
index_t
E1
=
4
;
constexpr
index_t
EPerBlock
=
16
;
constexpr
index_t
E2
=
4
;
constexpr
index_t
EPerBlock
=
4
;
constexpr
index_t
KPerThread
=
KPerBlock
;
constexpr
index_t
KPerThread
=
4
;
constexpr
index_t
HoPerThread
=
2
;
constexpr
index_t
HoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
EPerThread
=
EPerBlock
;
constexpr
index_t
EPerThread
=
4
;
using
ABlockTransferThreadSliceLengths_E0_E1_K
=
Sequence
<
1
,
4
,
1
>
;
using
ABlockTransferThreadSliceLengths_E0_E1_K
_E2
=
Sequence
<
1
,
1
,
1
,
4
>
;
using
ABlockTransferThreadClusterLengths_E0_E1_K
=
Sequence
<
1
,
4
,
16
>
;
using
ABlockTransferThreadClusterLengths_E0_E1_K
_E2
=
Sequence
<
1
,
4
,
16
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_E
=
4
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_E
2
=
1
;
constexpr
index_t
ABlockTransferDstScalarPerVector_
K
=
1
;
constexpr
index_t
ABlockTransferDstScalarPerVector_
E2
=
1
;
constexpr
index_t
BThreadTransferSrcScalarPerVector_E
=
4
;
constexpr
index_t
BThreadTransferSrcScalarPerVector_E
2
=
1
;
constexpr
index_t
CThreadTransferDstScalarPerVector_K
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector_K
=
1
;
#endif
#endif
constexpr
auto
conv_driver
=
constexpr
auto
conv_driver
=
...
@@ -131,6 +132,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
...
@@ -131,6 +132,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
TAcc
,
TAcc
,
TOut
,
TOut
,
E1
,
E1
,
E2
,
KPerBlock
,
KPerBlock
,
HoPerBlock
,
HoPerBlock
,
WoPerBlock
,
WoPerBlock
,
...
@@ -139,11 +141,11 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
...
@@ -139,11 +141,11 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
HoPerThread
,
HoPerThread
,
WoPerThread
,
WoPerThread
,
EPerThread
,
EPerThread
,
ABlockTransferThreadSliceLengths_E0_E1_K
,
ABlockTransferThreadSliceLengths_E0_E1_K
_E2
,
ABlockTransferThreadClusterLengths_E0_E1_K
,
ABlockTransferThreadClusterLengths_E0_E1_K
_E2
,
ABlockTransferSrcScalarPerVector_E
,
ABlockTransferSrcScalarPerVector_E
2
,
ABlockTransferDstScalarPerVector_
K
,
ABlockTransferDstScalarPerVector_
E2
,
BThreadTransferSrcScalarPerVector_E
,
BThreadTransferSrcScalarPerVector_E
2
,
CThreadTransferDstScalarPerVector_K
>
{};
CThreadTransferDstScalarPerVector_K
>
{};
const
auto
ave_time
=
const
auto
ave_time
=
...
...
host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp
View file @
e6a23d8b
...
@@ -11,6 +11,7 @@ template <ck::index_t BlockSize,
...
@@ -11,6 +11,7 @@ template <ck::index_t BlockSize,
typename
FloatAcc
,
typename
FloatAcc
,
typename
FloatC
,
typename
FloatC
,
ck
::
index_t
E1
,
ck
::
index_t
E1
,
ck
::
index_t
E2
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
HoPerBlock
,
ck
::
index_t
HoPerBlock
,
ck
::
index_t
WoPerBlock
,
ck
::
index_t
WoPerBlock
,
...
@@ -19,11 +20,11 @@ template <ck::index_t BlockSize,
...
@@ -19,11 +20,11 @@ template <ck::index_t BlockSize,
ck
::
index_t
HoPerThread
,
ck
::
index_t
HoPerThread
,
ck
::
index_t
WoPerThread
,
ck
::
index_t
WoPerThread
,
ck
::
index_t
EPerThread
,
ck
::
index_t
EPerThread
,
typename
ABlockTransferThreadSliceLengths_E
_K
,
typename
ABlockTransferThreadSliceLengths_E
0_E1_K_E2
,
typename
ABlockTransferThreadClusterLengths_E
_K
,
typename
ABlockTransferThreadClusterLengths_E
0_E1_K_E2
,
ck
::
index_t
ABlockTransferSrcScalarPerVector_E
,
ck
::
index_t
ABlockTransferSrcScalarPerVector_E
2
,
ck
::
index_t
ABlockTransferDstScalarPerVector_
K
,
ck
::
index_t
ABlockTransferDstScalarPerVector_
E2
,
ck
::
index_t
BThreadTransferSrcScalarPerVector_E
,
ck
::
index_t
BThreadTransferSrcScalarPerVector_E
2
,
ck
::
index_t
CThreadTransferDstScalarPerVector_K
>
ck
::
index_t
CThreadTransferDstScalarPerVector_K
>
struct
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
struct
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
{
{
...
@@ -93,7 +94,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -93,7 +94,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
<<
std
::
endl
;
<<
std
::
endl
;
const
auto
E
=
C0
*
Y
*
X
*
C1
;
const
auto
E
=
C0
*
Y
*
X
*
C1
;
const
auto
E0
=
E
/
E1
;
const
auto
E0
=
E
/
(
E1
*
E2
)
;
// weight tensor
// weight tensor
const
auto
a_e_k_grid_desc
=
transform_tensor_descriptor
(
const
auto
a_e_k_grid_desc
=
transform_tensor_descriptor
(
...
@@ -103,11 +104,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -103,11 +104,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
const
auto
a_e0_e1_k_grid_desc
=
transform_tensor_descriptor
(
const
auto
a_e0_e1_k_e2_grid_desc
=
a_e_k_grid_desc
,
transform_tensor_descriptor
(
a_e_k_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
E0
,
E1
)),
make_pass_through_transform
(
K
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
E0
,
E1
,
E2
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_pass_through_transform
(
K
)),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// input tensor
// input tensor
const
auto
in_n_c0_hip_wip_c1_global_desc
=
transform_tensor_descriptor
(
const
auto
in_n_c0_hip_wip_c1_global_desc
=
transform_tensor_descriptor
(
...
@@ -141,14 +143,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -141,14 +143,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
make_tuple
(
Sequence
<
1
,
2
,
4
,
6
>
{},
Sequence
<
0
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
,
6
>
{},
Sequence
<
0
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
b_e0_e1_n_ho_wo_grid_desc
=
transform_tensor_descriptor
(
const
auto
b_e0_e1_n_ho_wo_
e2_
grid_desc
=
transform_tensor_descriptor
(
b_e_n_ho_wo_grid_desc
,
b_e_n_ho_wo_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
E0
,
E1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
E0
,
E1
,
E2
)),
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
),
make_pass_through_transform
(
Hop
),
make_pass_through_transform
(
Hop
),
make_pass_through_transform
(
Wop
)),
make_pass_through_transform
(
Wop
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
make_tuple
(
Sequence
<
0
,
1
,
5
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
// output tensor
// output tensor
const
auto
c_k_n_hop_wop_grid_desc
=
transform_tensor_descriptor
(
const
auto
c_k_n_hop_wop_grid_desc
=
transform_tensor_descriptor
(
...
@@ -169,27 +171,33 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -169,27 +171,33 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
}
}
// hack to control index calculation when iterating over a_k_m_global tensor
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr
auto
a_e0_e1_k_global_step_hacks
=
make_tuple
(
constexpr
auto
a_e0_e1_k_e2_global_step_hacks
=
make_tuple
(
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
>
{}));
constexpr
auto
a_e0_e1_k_global_move_slice_window_step_hack
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
constexpr
auto
a_e0_e1_k_
e2_
global_move_slice_window_step_hack
=
Sequence
<
0
,
0
,
0
,
0
,
0
>
{};
constexpr
auto
b_e0_e1_n_ho_wo_global_step_hacks
=
make_tuple
(
constexpr
auto
b_e0_e1_n_ho_wo_
e2_
global_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
constexpr
auto
b_e0_e1_n_ho_wo_global_move_slice_window_step_hack
=
constexpr
auto
b_e0_e1_n_ho_wo_
e2_
global_move_slice_window_step_hack
=
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{};
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
...
@@ -211,10 +219,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -211,10 +219,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
FloatAcc
,
FloatAcc
,
FloatC
,
FloatC
,
InMemoryDataOperationEnum_t
::
Set
,
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
a_e0_e1_k_grid_desc
),
decltype
(
a_e0_e1_k_
e2_
grid_desc
),
decltype
(
b_e0_e1_n_ho_wo_grid_desc
),
decltype
(
b_e0_e1_n_ho_wo_
e2_
grid_desc
),
decltype
(
c_k_n_hop_wop_grid_desc
),
decltype
(
c_k_n_hop_wop_grid_desc
),
E1
,
E1
,
E2
,
KPerBlock
,
KPerBlock
,
HoPerBlock
,
HoPerBlock
,
WoPerBlock
,
WoPerBlock
,
...
@@ -223,31 +232,31 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -223,31 +232,31 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
HoPerThread
,
HoPerThread
,
WoPerThread
,
WoPerThread
,
EPerThread
,
EPerThread
,
ABlockTransferThreadSliceLengths_E
_K
,
ABlockTransferThreadSliceLengths_E
0_E1_K_E2
,
ABlockTransferThreadClusterLengths_E
_K
,
ABlockTransferThreadClusterLengths_E
0_E1_K_E2
,
Sequence
<
2
,
0
,
1
>
,
Sequence
<
2
,
0
,
1
,
3
>
,
Sequence
<
2
,
0
,
1
>
,
Sequence
<
2
,
0
,
1
,
3
>
,
1
,
3
,
ABlockTransferSrcScalarPerVector_E
,
ABlockTransferSrcScalarPerVector_E
2
,
ABlockTransferDstScalarPerVector_
K
,
ABlockTransferDstScalarPerVector_
E2
,
false
,
// don't move back src coordinate after threadwise copy
false
,
// don't move back src coordinate after threadwise copy
Sequence
<
0
,
2
,
3
,
4
,
1
>
,
Sequence
<
0
,
2
,
3
,
4
,
1
,
5
>
,
1
,
5
,
BThreadTransferSrcScalarPerVector_E
,
BThreadTransferSrcScalarPerVector_E
2
,
false
,
// don't move back src coordinate after threadwise copy, which will be fused with
false
,
// don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
// MoveSrcSliceWindow() to save addr computation
Sequence
<
0
,
2
,
3
,
1
>
,
Sequence
<
2
,
3
,
1
,
0
>
,
0
,
0
,
CThreadTransferDstScalarPerVector_K
,
CThreadTransferDstScalarPerVector_K
,
decltype
(
a_e0_e1_k_global_step_hacks
),
decltype
(
a_e0_e1_k_
e2_
global_step_hacks
),
decltype
(
b_e0_e1_n_ho_wo_global_step_hacks
),
decltype
(
b_e0_e1_n_ho_wo_
e2_
global_step_hacks
),
decltype
(
c_k_n_ho_wo_global_tensor_step_hacks
),
decltype
(
c_k_n_ho_wo_global_tensor_step_hacks
),
decltype
(
a_e0_e1_k_global_move_slice_window_step_hack
),
decltype
(
a_e0_e1_k_
e2_
global_move_slice_window_step_hack
),
decltype
(
b_e0_e1_n_ho_wo_global_move_slice_window_step_hack
)
>
;
decltype
(
b_e0_e1_n_ho_wo_
e2_
global_move_slice_window_step_hack
)
>
;
using
AGridDesc_E0_E1_K
=
decltype
(
a_e0_e1_k_grid_desc
);
using
AGridDesc_E0_E1_K
_E2
=
decltype
(
a_e0_e1_k_
e2_
grid_desc
);
using
BGridDesc_E0_E1_N_Ho_Wo
=
decltype
(
b_e0_e1_n_ho_wo_grid_desc
);
using
BGridDesc_E0_E1_N_Ho_Wo
_E2
=
decltype
(
b_e0_e1_n_ho_wo_
e2_
grid_desc
);
using
CGridDesc_K_N_Ho_Wo
=
decltype
(
c_k_n_hop_wop_grid_desc
);
using
CGridDesc_K_N_Ho_Wo
=
decltype
(
c_k_n_hop_wop_grid_desc
);
const
auto
grid_size
=
(
K
/
KPerBlock
)
*
(
Hop
/
HoPerBlock
)
*
(
Wop
/
WoPerBlock
)
*
N
;
const
auto
grid_size
=
(
K
/
KPerBlock
)
*
(
Hop
/
HoPerBlock
)
*
(
Wop
/
WoPerBlock
)
*
N
;
...
@@ -276,8 +285,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -276,8 +285,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2
<
GridwiseGemm
,
kernel_gemm_dlops_v2
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AGridDesc_E0_E1_K
>
,
remove_reference_t
<
AGridDesc_E0_E1_K
_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_Ho_Wo
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_Ho_Wo
_E2
>
,
remove_reference_t
<
CGridDesc_K_N_Ho_Wo
>
,
remove_reference_t
<
CGridDesc_K_N_Ho_Wo
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
>
,
true
,
true
,
...
@@ -291,8 +300,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -291,8 +300,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
a_e0_e1_k_grid_desc
,
a_e0_e1_k_
e2_
grid_desc
,
b_e0_e1_n_ho_wo_grid_desc
,
b_e0_e1_n_ho_wo_
e2_
grid_desc
,
c_k_n_hop_wop_grid_desc
,
c_k_n_hop_wop_grid_desc
,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor
);
c_blockid_to_k_n_ho_wo_block_cluster_adaptor
);
}
}
...
@@ -302,8 +311,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -302,8 +311,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2
<
GridwiseGemm
,
kernel_gemm_dlops_v2
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AGridDesc_E0_E1_K
>
,
remove_reference_t
<
AGridDesc_E0_E1_K
_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_Ho_Wo
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_Ho_Wo
_E2
>
,
remove_reference_t
<
CGridDesc_K_N_Ho_Wo
>
,
remove_reference_t
<
CGridDesc_K_N_Ho_Wo
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
>
,
true
,
true
,
...
@@ -317,8 +326,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -317,8 +326,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
a_e0_e1_k_grid_desc
,
a_e0_e1_k_
e2_
grid_desc
,
b_e0_e1_n_ho_wo_grid_desc
,
b_e0_e1_n_ho_wo_
e2_
grid_desc
,
c_k_n_hop_wop_grid_desc
,
c_k_n_hop_wop_grid_desc
,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor
);
c_blockid_to_k_n_ho_wo_block_cluster_adaptor
);
}
}
...
@@ -328,8 +337,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -328,8 +337,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2
<
GridwiseGemm
,
kernel_gemm_dlops_v2
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AGridDesc_E0_E1_K
>
,
remove_reference_t
<
AGridDesc_E0_E1_K
_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_Ho_Wo
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_Ho_Wo
_E2
>
,
remove_reference_t
<
CGridDesc_K_N_Ho_Wo
>
,
remove_reference_t
<
CGridDesc_K_N_Ho_Wo
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
>
,
false
,
false
,
...
@@ -343,8 +352,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -343,8 +352,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
a_e0_e1_k_grid_desc
,
a_e0_e1_k_
e2_
grid_desc
,
b_e0_e1_n_ho_wo_grid_desc
,
b_e0_e1_n_ho_wo_
e2_
grid_desc
,
c_k_n_hop_wop_grid_desc
,
c_k_n_hop_wop_grid_desc
,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor
);
c_blockid_to_k_n_ho_wo_block_cluster_adaptor
);
}
}
...
@@ -354,8 +363,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -354,8 +363,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2
<
GridwiseGemm
,
kernel_gemm_dlops_v2
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AGridDesc_E0_E1_K
>
,
remove_reference_t
<
AGridDesc_E0_E1_K
_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_Ho_Wo
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_Ho_Wo
_E2
>
,
remove_reference_t
<
CGridDesc_K_N_Ho_Wo
>
,
remove_reference_t
<
CGridDesc_K_N_Ho_Wo
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
>
,
false
,
false
,
...
@@ -369,22 +378,22 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -369,22 +378,22 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_a_grid
,
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
a_e0_e1_k_grid_desc
,
a_e0_e1_k_
e2_
grid_desc
,
b_e0_e1_n_ho_wo_grid_desc
,
b_e0_e1_n_ho_wo_
e2_
grid_desc
,
c_k_n_hop_wop_grid_desc
,
c_k_n_hop_wop_grid_desc
,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor
);
c_blockid_to_k_n_ho_wo_block_cluster_adaptor
);
}
}
return
ave_time
;
return
ave_time
;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem
a_e0_e1_k_grid_desc_dev_buf
(
sizeof
(
AGridDesc_E0_E1_K
));
DeviceMem
a_e0_e1_k_
e2_
grid_desc_dev_buf
(
sizeof
(
AGridDesc_E0_E1_K
_E2
));
DeviceMem
b_e0_e1_n_ho_wo_grid_desc_dev_buf
(
sizeof
(
BGridDesc_E0_E1_N_Ho_Wo
));
DeviceMem
b_e0_e1_n_ho_wo_
e2_
grid_desc_dev_buf
(
sizeof
(
BGridDesc_E0_E1_N_Ho_Wo
_E2
));
DeviceMem
c_k_n_hop_wop_grid_desc_dev_buf
(
sizeof
(
CGridDesc_K_N_Ho_Wo
));
DeviceMem
c_k_n_hop_wop_grid_desc_dev_buf
(
sizeof
(
CGridDesc_K_N_Ho_Wo
));
DeviceMem
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf
(
DeviceMem
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf
(
sizeof
(
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
));
sizeof
(
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
));
a_e0_e1_k_grid_desc_dev_buf
.
ToDevice
(
&
a_e0_e1_k_grid_desc
);
a_e0_e1_k_
e2_
grid_desc_dev_buf
.
ToDevice
(
&
a_e0_e1_k_
e2_
grid_desc
);
b_e0_e1_n_ho_wo_grid_desc_dev_buf
.
ToDevice
(
&
b_e0_e1_n_ho_wo_grid_desc
);
b_e0_e1_n_ho_wo_
e2_
grid_desc_dev_buf
.
ToDevice
(
&
b_e0_e1_n_ho_wo_
e2_
grid_desc
);
c_k_n_hop_wop_grid_desc_dev_buf
.
ToDevice
(
&
c_k_n_hop_wop_grid_desc
);
c_k_n_hop_wop_grid_desc_dev_buf
.
ToDevice
(
&
c_k_n_hop_wop_grid_desc
);
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf
.
ToDevice
(
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf
.
ToDevice
(
&
c_blockid_to_k_n_ho_wo_block_cluster_adaptor
);
&
c_blockid_to_k_n_ho_wo_block_cluster_adaptor
);
...
@@ -397,8 +406,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -397,8 +406,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2
<
GridwiseGemm
,
kernel_gemm_dlops_v2
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AGridDesc_E0_E1_K
>
,
remove_reference_t
<
AGridDesc_E0_E1_K
_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_Ho_Wo
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_Ho_Wo
_E2
>
,
remove_reference_t
<
CGridDesc_K_N_Ho_Wo
>
,
remove_reference_t
<
CGridDesc_K_N_Ho_Wo
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
>
,
true
,
true
,
...
@@ -414,9 +423,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -414,9 +423,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
a_e0_e1_k_grid_desc_dev_buf
.
GetDeviceBuffer
()),
a_e0_e1_k_
e2_
grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
b_e0_e1_n_ho_wo_grid_desc_dev_buf
.
GetDeviceBuffer
()),
b_e0_e1_n_ho_wo_
e2_
grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
c_k_n_hop_wop_grid_desc_dev_buf
.
GetDeviceBuffer
()),
c_k_n_hop_wop_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
...
@@ -428,8 +437,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -428,8 +437,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2
<
GridwiseGemm
,
kernel_gemm_dlops_v2
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AGridDesc_E0_E1_K
>
,
remove_reference_t
<
AGridDesc_E0_E1_K
_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_Ho_Wo
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_Ho_Wo
_E2
>
,
remove_reference_t
<
CGridDesc_K_N_Ho_Wo
>
,
remove_reference_t
<
CGridDesc_K_N_Ho_Wo
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
>
,
true
,
true
,
...
@@ -445,9 +454,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -445,9 +454,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
a_e0_e1_k_grid_desc_dev_buf
.
GetDeviceBuffer
()),
a_e0_e1_k_
e2_
grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
b_e0_e1_n_ho_wo_grid_desc_dev_buf
.
GetDeviceBuffer
()),
b_e0_e1_n_ho_wo_
e2_
grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
c_k_n_hop_wop_grid_desc_dev_buf
.
GetDeviceBuffer
()),
c_k_n_hop_wop_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
...
@@ -459,8 +468,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -459,8 +468,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2
<
GridwiseGemm
,
kernel_gemm_dlops_v2
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AGridDesc_E0_E1_K
>
,
remove_reference_t
<
AGridDesc_E0_E1_K
_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_Ho_Wo
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_Ho_Wo
_E2
>
,
remove_reference_t
<
CGridDesc_K_N_Ho_Wo
>
,
remove_reference_t
<
CGridDesc_K_N_Ho_Wo
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
>
,
false
,
false
,
...
@@ -476,9 +485,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -476,9 +485,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
a_e0_e1_k_grid_desc_dev_buf
.
GetDeviceBuffer
()),
a_e0_e1_k_
e2_
grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
b_e0_e1_n_ho_wo_grid_desc_dev_buf
.
GetDeviceBuffer
()),
b_e0_e1_n_ho_wo_
e2_
grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
c_k_n_hop_wop_grid_desc_dev_buf
.
GetDeviceBuffer
()),
c_k_n_hop_wop_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
...
@@ -490,8 +499,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -490,8 +499,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
kernel_gemm_dlops_v2
<
GridwiseGemm
,
kernel_gemm_dlops_v2
<
GridwiseGemm
,
FloatAB
,
FloatAB
,
FloatC
,
FloatC
,
remove_reference_t
<
AGridDesc_E0_E1_K
>
,
remove_reference_t
<
AGridDesc_E0_E1_K
_E2
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_Ho_Wo
>
,
remove_reference_t
<
BGridDesc_E0_E1_N_Ho_Wo
_E2
>
,
remove_reference_t
<
CGridDesc_K_N_Ho_Wo
>
,
remove_reference_t
<
CGridDesc_K_N_Ho_Wo
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
>
,
remove_reference_t
<
CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo
>
,
false
,
false
,
...
@@ -507,9 +516,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
...
@@ -507,9 +516,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
a_e0_e1_k_grid_desc_dev_buf
.
GetDeviceBuffer
()),
a_e0_e1_k_
e2_
grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
b_e0_e1_n_ho_wo_grid_desc_dev_buf
.
GetDeviceBuffer
()),
b_e0_e1_n_ho_wo_
e2_
grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
c_k_n_hop_wop_grid_desc_dev_buf
.
GetDeviceBuffer
()),
c_k_n_hop_wop_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
cast_pointer_to_constant_address_space
(
...
...
script/run.sh
View file @
e6a23d8b
...
@@ -52,6 +52,7 @@ REPEAT=$6
...
@@ -52,6 +52,7 @@ REPEAT=$6
#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1
#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1
./host/driver_offline/conv_fwd_driver_offline
$LAYOUT
$ALGO
$VERIFY
$INIT
$LOG
$REPEAT
1 16 16 3 3 1080 1920 1 1 1 1 1 1 1 1
./host/driver_offline/conv_fwd_driver_offline
$LAYOUT
$ALGO
$VERIFY
$INIT
$LOG
$REPEAT
1 16 16 3 3 1080 1920 1 1 1 1 1 1 1 1
#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 16 1 1 8 8 1 1 1 1 1 1 1 1
################################################ layout algo verify init log repeat M___ N___ K___
################################################ layout algo verify init log repeat M___ N___ K___
#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024
#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment