gaoqiong / composable_kernel / Commits

Commit bb7a8b28 authored Sep 08, 2021 by Jing Zhang

init

parent f3acd251

Showing 8 changed files with 470 additions and 350 deletions (+470 -350)
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp (+9 -11)
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp (+104 -93)
composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp (+13 -10)
host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp (+14 -38)
host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp (+316 -186)
host/driver_offline/src/conv_fwd_driver_offline.cpp (+6 -6)
script/cmake-rocm.sh (+3 -3)
script/run.sh (+5 -3)
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp

@@ -7,8 +7,7 @@
 namespace ck {

 template <index_t BlockSize,
-          typename FloatA,
-          typename FloatB,
+          typename FloatAB,
           typename FloatC,
           typename BlockMatrixA,
           typename BlockMatrixB,
@@ -40,8 +39,8 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
     static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
         Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));

-    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
-                                                         FloatA,
+    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
+                                                         FloatAB,
                                                          BlockMatrixA,
                                                          decltype(a_thread_mtx_),
                                                          Sequence<EPerThreadLoop, KPerThreadSubC>,
@@ -111,8 +110,8 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
                         CThreadBuffer& c_thread_buf) const
     {
         static_assert(
-            is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatA>>::value &&
-                is_same<remove_cvref_t<typename BThreadBuffer::type>, remove_cvref_t<FloatB>>::value &&
+            is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatAB>>::value &&
+                is_same<remove_cvref_t<typename BThreadBuffer::type>, remove_cvref_t<FloatAB>>::value &&
                 is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
                 "wrong! inconsistent type");
@@ -123,19 +122,18 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
         constexpr auto EPerBlock = a_block_mtx.GetLength(I0);

         // HACK: fix this @Jing Zhang
-        constexpr auto HoPerThreadSubC = 2;
-        constexpr auto WoPerThreadSubC = 2;
+        constexpr auto HoPerThreadSubC = HPerThread;
+        constexpr auto WoPerThreadSubC = WPerThread;

         static_assert(KPerThread % KPerThreadSubC == 0, "");
         static_assert(HPerThread % HoPerThreadSubC == 0, "");
         static_assert(WPerThread % WoPerThreadSubC == 0, "");

         // thread A buffer for GEMM
-        StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatAB, a_thread_mtx_.GetElementSpaceSize(), true>
             a_thread_buf;

-        constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
-                                                                         FloatB,
+        constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatAB,
                                                                          FloatC,
                                                                          decltype(a_thread_mtx_),
                                                                          decltype(b_thread_mtx_),
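The edits in this file collapse the separate FloatA/FloatB element-type parameters into a single FloatAB. A minimal standalone illustration (not CK code; the struct and type names are hypothetical) of why that shortens every instantiation:

#include <type_traits>

// Hypothetical sketch: once A and B share one element type, a single FloatAB
// template parameter replaces the old FloatA/FloatB pair.
template <typename FloatAB, typename FloatC>
struct BlockwiseGemmSketch
{
    static_assert(std::is_arithmetic<FloatAB>::value, "element type must be arithmetic");
    using AType = FloatAB; // A and B tiles now always share the same element type
    using BType = FloatAB;
    using CType = FloatC;
};

int main()
{
    BlockwiseGemmSketch<float, float> gemm; // previously spelled <float, float, float>
    (void)gemm;
    return 0;
}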
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp

@@ -11,6 +11,96 @@
 namespace ck {

+#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
+template <typename GridwiseGemm,
+          typename FloatAB,
+          typename FloatC,
+          typename AEKGridDesc,
+          typename BENHoWoGridDesc,
+          typename CKNHoWoGridDesc,
+          typename CBlockIdToKNHoWoBlockClusterAdaptor,
+          bool HasMainKBlockLoop,
+          bool HasDoubleTailKBlockLoop>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_gemm_dlops_v2(const FloatAB* __restrict__ p_a_grid,
+                             const FloatAB* __restrict__ p_b_grid,
+                             FloatC* __restrict__ p_c_grid,
+                             const AEKGridDesc a_e_k_grid_desc,
+                             const BENHoWoGridDesc b_e_n_ho_wo_grid_desc,
+                             const CKNHoWoGridDesc c_k_n_ho_wo_grid_desc,
+                             const CBlockIdToKNHoWoBlockClusterAdaptor
+                                 c_blockid_to_k_n_ho_wo_block_cluster_adaptor)
+{
+    constexpr index_t shared_block_size =
+        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
+
+    __shared__ FloatAB p_shared_block[shared_block_size];
+
+    GridwiseGemm::Run(p_a_grid,
+                      p_b_grid,
+                      p_c_grid,
+                      p_shared_block,
+                      a_e_k_grid_desc,
+                      b_e_n_ho_wo_grid_desc,
+                      c_k_n_ho_wo_grid_desc,
+                      integral_constant<bool, HasMainKBlockLoop>{},
+                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
+}
+#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
+// pass tensor descriptor by CONSTANT void pointer
+// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to
+// non-modifiable parameter address space, so compiler can enable corresponding optimization
+template <typename GridwiseGemm,
+          typename FloatAB,
+          typename FloatC,
+          typename AEKGridDesc,
+          typename BENHoWoGridDesc,
+          typename CKNHoWoGridDesc,
+          typename CBlockIdToKNHoWoBlockClusterAdaptor,
+          bool HasMainKBlockLoop,
+          bool HasDoubleTailKBlockLoop>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_gemm_dlops_v2(const FloatAB* __restrict__ p_a_grid,
+                             const FloatAB* __restrict__ p_b_grid,
+                             FloatC* __restrict__ p_c_grid,
+                             const void CONSTANT* p_a_e_k_grid_desc,
+                             const void CONSTANT* p_b_e_n_ho_wo_grid_desc,
+                             const void CONSTANT* p_c_k_n_ho_wo_grid_desc,
+                             const void CONSTANT* p_c_blockid_to_k_n_ho_wo_block_cluster_adaptor)
+{
+    // first cast void CONSTANT void* to void*
+    // second cast void* to Desc*
+    // the copy constructor of tensor descriptor doesn't take address_space(4)
+    const auto a_e_k_grid_desc = *reinterpret_cast<const AEKGridDesc*>(
+        cast_pointer_to_generic_address_space(p_a_e_k_grid_desc));
+    const auto b_e_n_ho_wo_grid_desc = *reinterpret_cast<const BENHoWoGridDesc*>(
+        cast_pointer_to_generic_address_space(p_b_e_n_ho_wo_grid_desc));
+    const auto c_k_n_ho_wo_grid_desc = *reinterpret_cast<const CKNHoWoGridDesc*>(
+        cast_pointer_to_generic_address_space(p_c_k_n_ho_wo_grid_desc));
+
+    constexpr index_t shared_block_size =
+        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
+
+    __shared__ FloatAB p_shared_block[shared_block_size];
+
+    GridwiseGemm::Run(p_a_grid,
+                      p_b_grid,
+                      p_c_grid,
+                      p_shared_block,
+                      a_e_k_grid_desc,
+                      b_e_n_ho_wo_grid_desc,
+                      c_k_n_ho_wo_grid_desc,
+                      integral_constant<bool, HasMainKBlockLoop>{},
+                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
+}
+#endif
+
 template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
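A side note on the second wrapper added above: the tensor descriptors arrive as opaque CONSTANT void pointers and are recovered inside the kernel by a double cast. A minimal host-only C++ sketch of that recovery step, with the GPU address-space qualifiers left out and a hypothetical TensorDescSketch standing in for the real descriptor types:

#include <cstdio>

// Hypothetical stand-in for a compile-time tensor descriptor.
struct TensorDescSketch
{
    int lengths[4];
    int strides[4];
};

// Mirrors the kernel-side recovery: the opaque pointer is treated as an ordinary
// void* (the real code uses cast_pointer_to_generic_address_space to drop the
// CONSTANT address-space qualifier first) and is then reinterpreted as the concrete
// descriptor type and copied by value, exactly as the kernel's *reinterpret_cast does.
void consume_descriptor(const void* p_opaque_desc)
{
    const auto desc = *reinterpret_cast<const TensorDescSketch*>(p_opaque_desc);
    std::printf("lengths[0] = %d\n", desc.lengths[0]);
}

int main()
{
    TensorDescSketch desc{{16, 1, 8, 8}, {64, 64, 8, 1}};
    consume_descriptor(&desc); // descriptor handed over as an opaque void pointer
    return 0;
}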
@@ -69,15 +159,15 @@ struct GridwiseGemmDlops_km_kn_mn_v3
     }

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_e_k_global_desc,
-                        const FloatAB* __restrict__ p_a_global,
-                        const BGlobalDesc& b_e_n_ho_wo_global_desc,
-                        const FloatAB* __restrict__ p_b_global,
-                        const CGlobalDesc& c_k_n_ho_wo_global_desc,
-                        FloatC* __restrict__ p_c_global,
-                        FloatAB* __restrict__ p_shared_block,
-                        integral_constant<bool, HasMainKBlockLoop>,
-                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
+    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
+                               const FloatAB* __restrict__ p_b_global,
+                               FloatC* __restrict__ p_c_global,
+                               FloatAB* __restrict__ p_shared_block,
+                               const AGlobalDesc& a_e_k_global_desc,
+                               const BGlobalDesc& b_e_n_ho_wo_global_desc,
+                               const CGlobalDesc& c_k_n_ho_wo_global_desc,
+                               integral_constant<bool, HasMainKBlockLoop>,
+                               integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
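The signature change above also makes Run a static member, which is what lets the free __global__ wrappers added earlier call it as GridwiseGemm::Run without constructing a GridwiseGemm object. A small CPU-only sketch of that calling pattern (hypothetical names, plain float instead of the kernel's template types):

#include <cstdio>

struct GridwiseGemmSketch
{
    // Static, so a wrapper can call it through the type alone.
    static void Run(const float* p_a, const float* p_b, float* p_c, float* p_workspace)
    {
        p_c[0] = p_a[0] * p_b[0] + p_workspace[0]; // placeholder body
    }
};

// Stands in for the __global__ kernel wrapper: it owns the scratch buffer
// (the real kernel's __shared__ array) and forwards everything to Run.
template <typename GridwiseGemm>
void kernel_wrapper(const float* p_a, const float* p_b, float* p_c)
{
    float workspace[1] = {0.0f};
    GridwiseGemm::Run(p_a, p_b, p_c, workspace);
}

int main()
{
    float a[1] = {2.0f}, b[1] = {3.0f}, c[1] = {0.0f};
    kernel_wrapper<GridwiseGemmSketch>(a, b, c);
    std::printf("c[0] = %f\n", c[0]);
    return 0;
}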
@@ -94,9 +184,9 @@ struct GridwiseGemmDlops_km_kn_mn_v3
         constexpr auto E = EPerBlock * 3 * 3;
         // const auto E = a_e_k_global_desc.GetLength(I0);

-        const auto K = a_e_k_global_desc.GetLength(I1);
-        const auto N = b_e_n_ho_wo_global_desc.GetLength(I1);
+        // const auto K = a_e_k_global_desc.GetLength(I1);
+        // const auto N = b_e_n_ho_wo_global_desc.GetLength(I1);
         const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2);
         const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3);
@@ -150,7 +240,6 @@ struct GridwiseGemmDlops_km_kn_mn_v3
         auto blockwise_gemm =
             BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
                                                  FloatAB,
-                                                 FloatAB,
                                                  FloatAcc,
                                                  decltype(a_e_k_block_desc),
@@ -245,9 +334,9 @@ struct GridwiseGemmDlops_km_kn_mn_v3
         // hack to control index calculation when move slice window for A and B matrix for
         // threadwise copy
-        constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{};
-        constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
-            BGlobalMoveSliceWindowStepHacks{};
+        // constexpr auto a_e_k_global_move_slice_window_step_hack =
+        // AGlobalMoveSliceWindowStepHacks{}; constexpr auto
+        // b_e_n_ho_wo_global_move_slice_window_step_hack = BGlobalMoveSliceWindowStepHacks{};

         // double regsiter buffer for b
         StaticBuffer<AddressSpaceEnum_t::Vgpr,
@@ -374,84 +463,6 @@ struct GridwiseGemmDlops_km_kn_mn_v3
                 c_k_n_ho_wo_global_tensor_step_hacks);
         }
     }
-
-    // pass tensor descriptor by reference
-    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_e_k_global_desc,
-                        const FloatAB* __restrict__ p_a_global,
-                        const BGlobalDesc& b_e_n_ho_wo_global_desc,
-                        const FloatAB* __restrict__ p_b_global,
-                        const CGlobalDesc& c_k_n_ho_wo_global_desc,
-                        FloatC* __restrict__ p_c_global,
-                        integral_constant<bool, HasMainKBlockLoop>,
-                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
-    {
-        constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
-
-        __shared__ FloatAB p_shared_block[shared_block_size];
-
-        Run(a_e_k_global_desc,
-            p_a_global,
-            b_e_n_ho_wo_global_desc,
-            p_b_global,
-            c_k_n_ho_wo_global_desc,
-            p_c_global,
-            p_shared_block,
-            integral_constant<bool, HasMainKBlockLoop>{},
-            integral_constant<bool, HasDoubleTailKBlockLoop>{});
-    }
-
-    // pass tensor descriptors by their pointers
-    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc* p_a_e_k_global_desc,
-                        const FloatAB* __restrict__ p_a_global,
-                        const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
-                        const FloatAB* __restrict__ p_b_global,
-                        const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
-                        FloatC* __restrict__ p_c_global,
-                        integral_constant<bool, HasMainKBlockLoop>,
-                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
-    {
-        const auto a_e_k_global_desc       = *p_a_e_k_global_desc;
-        const auto b_e_n_ho_wo_global_desc = *p_b_e_n_ho_wo_global_desc;
-        const auto c_k_n_ho_wo_global_desc = *p_c_k_n_ho_wo_global_desc;
-
-        Run(a_e_k_global_desc,
-            p_a_global,
-            b_e_n_ho_wo_global_desc,
-            p_b_global,
-            c_k_n_ho_wo_global_desc,
-            p_c_global,
-            integral_constant<bool, HasMainKBlockLoop>{},
-            integral_constant<bool, HasDoubleTailKBlockLoop>{});
-    }
-
-    // pass tensor descriptors by void*
-    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const void* p_a_e_k_global_desc,
-                        const FloatAB* __restrict__ p_a_global,
-                        const void* p_b_e_n_ho_wo_global_desc,
-                        const FloatAB* __restrict__ p_b_global,
-                        const void* p_c_k_n_ho_wo_global_desc,
-                        FloatC* __restrict__ p_c_global,
-                        integral_constant<bool, HasMainKBlockLoop>,
-                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
-    {
-        const auto a_e_k_global_desc =
-            *reinterpret_cast<const AGlobalDesc*>(p_a_e_k_global_desc);
-        const auto b_e_n_ho_wo_global_desc =
-            *reinterpret_cast<const BGlobalDesc*>(p_b_e_n_ho_wo_global_desc);
-        const auto c_k_n_ho_wo_global_desc =
-            *reinterpret_cast<const CGlobalDesc*>(p_c_k_n_ho_wo_global_desc);
-
-        Run(a_e_k_global_desc,
-            p_a_global,
-            b_e_n_ho_wo_global_desc,
-            p_b_global,
-            c_k_n_ho_wo_global_desc,
-            p_c_global,
-            integral_constant<bool, HasMainKBlockLoop>{},
-            integral_constant<bool, HasDoubleTailKBlockLoop>{});
-    }
 };

 } // namespace ck
composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp

@@ -11,8 +11,7 @@ namespace ck {
 // Assume:
 //   1. ADesc, BDesc, CDesc are known at compile-time
 //   2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
-template <typename FloatA,
-          typename FloatB,
+template <typename FloatAB,
           typename FloatC,
           typename ADesc,
           typename BDesc,
@@ -37,6 +36,7 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
                                CBuffer& c_buf,
                                COriginIdx)
     {
+
         static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                           CDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");
@@ -47,8 +47,8 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
             "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");

         static_assert(
-            is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
-                is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
+            is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatAB>>::value &&
+                is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatAB>>::value &&
                 is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
                 "wrong! inconsistent type");
@@ -67,6 +67,7 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
             constexpr index_t a_offset =
                 ADesc{}.CalculateOffset(a_origin_idx + make_tuple(e, k));

+#if 0
             if constexpr(H == 2 && W == 2)
             {
                 constexpr index_t b_offset_0 =
@@ -128,6 +129,7 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
                         c_buf(Number<c_offset_3>{}));
             }
             else
+#endif
             {
                 static_for<0, H, 1>{}([&](auto h) {
                     static_for<0, W, 1>{}([&](auto w) {
@@ -137,11 +139,12 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
                         constexpr index_t c_offset =
                             CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h, w));

-#if 0
-                        c_buf(Number<c_offset>{}) += inner_product_with_conversion<FloatC>{}(
-                            a_buf[Number<a_offset>{}], b_buf[Number<b_offset>{}]);
+#if 1
+                        // c_buf(Number<c_offset>{}) += inner_product_with_conversion<FloatC>{}(
+                        //     a_buf[Number<a_offset>{}], b_buf[Number<b_offset>{}]);
+                        c_buf(Number<c_offset>{}) = a_buf[Number<a_offset>{}];
 #else
-                        amd_assembly_inner_product(a_buf[Number<a_offset>{}],
-                                                   b_buf[Number<b_offset>{}],
-                                                   c_buf(Number<c_offset>{}));
+                        inner_product(a_buf[Number<a_offset>{}],
+                                      b_buf[Number<b_offset>{}],
+                                      c_buf(Number<c_offset>{}));
 #endif
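For orientation, the loop nest this threadwise GEMM implements (offsets built from (e, k) for A and (k, 0, h, w) for C, with static_for loops over h and w) amounts to C[k][h][w] += A[e][k] * B[e][h][w]. A CPU-only sketch with made-up tile sizes and plain float types, purely to illustrate the indexing and not the kernel's actual data layout:

#include <array>
#include <cstdio>

int main()
{
    // Illustrative tile sizes only; the real kernel takes these from the descriptors.
    constexpr int E = 2, K = 2, H = 2, W = 2;

    std::array<float, E * K> a{};     // A tile, indexed [e][k]
    std::array<float, E * H * W> b{}; // B tile, indexed [e][h][w]
    std::array<float, K * H * W> c{}; // C tile, indexed [k][h][w]

    for(int i = 0; i < E * K; ++i)     a[i] = 1.0f + i;
    for(int i = 0; i < E * H * W; ++i) b[i] = 0.5f * i;

    // Same accumulation pattern as the threadwise GEMM: every (e, k, h, w)
    // contributes A[e][k] * B[e][h][w] into C[k][h][w].
    for(int e = 0; e < E; ++e)
        for(int k = 0; k < K; ++k)
            for(int h = 0; h < H; ++h)
                for(int w = 0; w < W; ++w)
                    c[(k * H + h) * W + w] += a[e * K + k] * b[(e * H + h) * W + w];

    std::printf("c[0][0][0] = %f\n", c[0]);
    return 0;
}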
host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp

@@ -85,8 +85,10 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
     in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
     wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());

-    const auto in_n_c0_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi));
-    const auto wei_k_c0_y_x_desc  = make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X));
+    const auto in_n_c0_hi_wi_c1_desc =
+        make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, C1));
+    const auto wei_k_c0_y_x_c1_desc =
+        make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, C1));
     const auto out_n_k0_ho_wo_k1_desc =
         make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));
@@ -96,47 +98,23 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
     constexpr index_t KPerBlock  = 16;
     constexpr index_t HoPerBlock = 8;
-    constexpr index_t WoPerBlock = 32;
+    constexpr index_t WoPerBlock = 8;
     constexpr index_t EPerBlock  = 1;

     constexpr index_t KPerThread  = KPerBlock;
-    constexpr index_t HoPerThread = 2;
-    constexpr index_t WoPerThread = 2;
-    constexpr index_t EPerThread  = EPerBlock;
-
-    using ABlockTransferThreadSliceLengths_E_K   = Sequence<3, 1>;
-    using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>;
-
-    constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
-    constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
-
-    constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
-
-    constexpr index_t CThreadTransferDstScalarPerVector_W = 16;
-
-    static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
-#else
-    constexpr index_t BlockSize = 64;
-
-    constexpr index_t KPerBlock  = 16;
-    constexpr index_t HoPerBlock = 8;
-    constexpr index_t WoPerBlock = 32;
-    constexpr index_t EPerBlock  = 1;
-
-    constexpr index_t KPerThread  = 16;
-    constexpr index_t HoPerThread = 2;
-    constexpr index_t WoPerThread = 2;
+    constexpr index_t HoPerThread = 1;
+    constexpr index_t WoPerThread = 1;
     constexpr index_t EPerThread  = EPerBlock;

     using ABlockTransferThreadSliceLengths_E_K   = Sequence<9, 1>;
-    using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, 16>;
+    using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, KPerBlock>;

     constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
     constexpr index_t ABlockTransferDstScalarPerVector_K = 1;

     constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;

-    constexpr index_t CThreadTransferDstScalarPerVector_W = K1;
+    constexpr index_t CThreadTransferDstScalarPerVector_W = 1;

     static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
 #endif
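A quick sanity check on the new tile sizes above, under the assumption (not stated in this diff) that each thread of the workgroup owns a KPerThread x HoPerThread x WoPerThread sub-tile of the KPerBlock x HoPerBlock x WoPerBlock block tile:

// Hypothetical back-of-the-envelope check; the constants mirror the new values above.
constexpr int KPerBlock = 16, HoPerBlock = 8, WoPerBlock = 8;
constexpr int KPerThread = KPerBlock, HoPerThread = 1, WoPerThread = 1;

constexpr int threads_needed =
    (KPerBlock / KPerThread) * (HoPerBlock / HoPerThread) * (WoPerBlock / WoPerThread);

static_assert(threads_needed == 64, "one 64-thread workgroup would cover the block tile");

int main() { return 0; }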
@@ -148,7 +126,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
         DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
 #endif
         <BlockSize,
-         typename vector_type<TInWei, InWeiVectorSize>::type,
+         TInWei,
          TAcc,
          TOut,
          KPerBlock,
@@ -166,17 +144,15 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
                          BThreadTransferSrcScalarPerVector_W,
                          CThreadTransferDstScalarPerVector_W>{};

-    conv_driver.Run(wei_k_c0_y_x_desc,
-                    in_n_c0_hi_wi_desc,
+    conv_driver.Run(wei_k_c0_y_x_c1_desc,
+                    in_n_c0_hi_wi_c1_desc,
                     out_n_k0_ho_wo_k1_desc,
                     conv_strides,
                     conv_dilations,
                     in_left_pads,
                     in_right_pads,
-                    static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
-                        wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
-                    static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
-                        in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
+                    static_cast<TInWei*>(wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
+                    static_cast<TInWei*>(in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
                     static_cast<TOut*>(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()));

     out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp (+316 -186)

This diff is collapsed.
host/driver_offline/src/conv_fwd_driver_offline.cpp

@@ -3,7 +3,7 @@
 #include <initializer_list>
 #include <cstdlib>
 #include <stdlib.h>
-#include <half.hpp>
+// #include <half.hpp>
 #include "config.hpp"
 #include "print.hpp"
 #include "device.hpp"
@@ -23,9 +23,9 @@
 #define USE_CONV_FWD_V4R4_NCHW 0
 #define USE_CONV_FWD_V4R4R2_NHWC 0
 #define USE_CONV_FWD_V6R1_NCHW 0
-#define USE_CONV_FWD_V5R1_NCHW 0
-#define USE_CONV_FWD_V4R4R2_XDL_NCHW 1
-#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1
+#define USE_CONV_FWD_V5R1_NCHW 1
+#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
+#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0

 enum ConvForwardAlgo
 {
@@ -126,7 +126,7 @@ int main(int argc, char* argv[])
     constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
 #endif

-#if 0
+#if 1
     using in_data_t  = float;
     using acc_data_t = float;
     using out_data_t = float;
@@ -352,7 +352,7 @@ int main(int argc, char* argv[])
         const auto tmp = f_make_for_device_nchw();

         device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw<in_data_t,
+                                                                            16,
+                                                                            8,
                                                                             acc_data_t,
                                                                             out_data_t>(tmp[I0],
                                                                                         tmp[I1],
script/cmake-rocm.sh

@@ -3,15 +3,15 @@ rm -f CMakeCache.txt
 rm -f *.cmake
 rm -rf CMakeFiles

-MY_PROJECT_SOURCE=../../..
+MY_PROJECT_SOURCE=../
 MY_PROJECT_INSTALL=../install.dir

 cmake \
 -D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \
 -D HALF_INCLUDE_DIR="/root/workspace/external/half/include" \
--D BUILD_DEV=ON \
+-D BUILD_DEV=OFF \
 -D CMAKE_BUILD_TYPE=Release \
--D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 -O3 --amdgpu-target=gfx908 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \
+-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX1030 -O3 --amdgpu-target=gfx1030 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \
 -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
 -D CMAKE_PREFIX_PATH=/opt/rocm \
 -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
script/run.sh

@@ -15,12 +15,12 @@
 #rm -rf /root/_hip_binary_kernels_/
 #rm -rf /tmp/olCompile*

-#make -j conv_fwd_driver_offline
+make -j conv_fwd_driver_offline
 #make -j conv_bwd_driver_offline
 #make -j conv_wrw_driver_offline
 #make -j conv_fwd_driver_online
-make -j gemm_driver_offline
+#make -j gemm_driver_offline

 LAYOUT=$1
 ALGO=$2
@@ -51,8 +51,10 @@ REPEAT=$6
 #./host/driver_online/conv_fwd_driver_online  $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1

+./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 16 3 3 8 8 1 1 1 1 1 1 1 1
+
 ################################################ layout algo verify init log repeat M___ N___ K___
 #./host/driver_offline/gemm_driver_offline    $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT  960 1024 1024
 #./host/driver_offline/gemm_driver_offline    $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048
-./host/driver_offline/gemm_driver_offline     $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096
+#./host/driver_offline/gemm_driver_offline    $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096
 #./host/driver_offline/gemm_driver_offline    $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192