gaoqiong / composable_kernel · Commits

Commit 646fcc26 (unverified)
Authored Oct 27, 2021 by Chao Liu; committed by GitHub, Oct 27, 2021

Merge pull request #47 from ROCmSoftwarePlatform/develop

Merge develop into master

parents 38a90b6e 6014185a
Showing 8 changed files with 365 additions and 200 deletions
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp  +68 -54
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp  +75 -61
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp  +3 -5
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp  +3 -5
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp  +5 -7
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp  +2 -2
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp  +107 -33
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp  +102 -33
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp

@@ -19,7 +19,8 @@ template <typename GridwiseGemm,
           typename AK0MK1GridDesc,
           typename BK0NK1GridDesc,
           typename CM0N0M1N1M2M3M4N2GridDesc,
-          typename CBlockClusterAdaptor>
+          typename CBlockClusterAdaptor,
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -37,14 +38,14 @@ __global__ void
     __shared__ FloatAB p_shared_block[shared_block_size];

-    GridwiseGemm::Run(p_a_grid,
-                      p_b_grid,
-                      p_c_grid,
-                      p_shared_block,
-                      a_k0_m_k1_grid_desc,
-                      b_k0_n_k1_grid_desc,
-                      c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
-                      c_block_cluster_adaptor);
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
+                                                  p_b_grid,
+                                                  p_c_grid,
+                                                  p_shared_block,
+                                                  a_k0_m_k1_grid_desc,
+                                                  b_k0_n_k1_grid_desc,
+                                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                                                  c_block_cluster_adaptor);
 }
 #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
 template <typename GridwiseGemm,
@@ -81,14 +82,14 @@ __global__ void
     __shared__ FloatAB p_shared_block[shared_block_size];

-    GridwiseGemm::Run(p_a_grid,
-                      p_b_grid,
-                      p_c_grid,
-                      p_shared_block,
-                      a_k0_m_k1_grid_desc,
-                      b_k0_n_k1_grid_desc,
-                      c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
-                      c_block_cluster_adaptor);
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
+                                                  p_b_grid,
+                                                  p_c_grid,
+                                                  p_shared_block,
+                                                  a_k0_m_k1_grid_desc,
+                                                  b_k0_n_k1_grid_desc,
+                                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                                                  c_block_cluster_adaptor);
 }
 #endif
@@ -102,7 +103,7 @@ template <index_t BlockSize,
           typename CMNGridDesc,
           index_t MPerBlock,
           index_t NPerBlock,
-          index_t KPerBlock,
+          index_t K0PerBlock,
           index_t MPerXDL,
           index_t NPerXDL,
           index_t K1Value,
@@ -158,13 +159,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             if constexpr(ABlockLdsExtraM)
             {
                 return make_naive_tensor_descriptor(
-                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
+                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                     make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
             }
             else
             {
                 return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
             }
         }();

@@ -173,13 +174,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             if constexpr(BBlockLdsExtraN)
             {
                 return make_naive_tensor_descriptor(
-                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                     make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
             }
             else
             {
                 return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
             }
         }();
@@ -217,7 +218,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
              K1 == b_k0_n_k1_grid_desc.GetLength(I2)))
             return false;

-        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0))
+        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
             return false;

         // check M01, N01
@@ -245,6 +246,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         return grid_size;
     }

+    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
+    {
+        const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1;
+
+        return has_main_k0_block_loop;
+    }
+
     __host__ __device__ static constexpr auto
     MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
     {
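The flag computed by CalculateHasMainK0BlockLoop() is consumed at compile time: the host driver (see driver_gemm_xdlops_v2r3.hpp below) instantiates the kernel with HasMainKBlockLoop set to true or false, and Run() wraps the main K0 loop in if constexpr so the loop is compiled out when it can never execute. A minimal standalone sketch of that pattern, using simplified, hypothetical names rather than the actual CK types, looks like this:

    #include <cstdio>

    // Simplified stand-in for K0PerBlock; not the CK compile-time Number<> machinery.
    constexpr int KPerBlockTile = 4;

    constexpr bool CalculateHasMainKBlockLoop(int k0) { return (k0 / KPerBlockTile) > 1; }

    template <bool HasMainKBlockLoop>
    void RunGemmTile(int k0)
    {
        int k0_block_data_begin = 0;

        if constexpr(HasMainKBlockLoop)
        {
            // Main loop: only compiled into this instantiation when the flag is true.
            do
            {
                k0_block_data_begin += KPerBlockTile;
            } while(k0_block_data_begin < (k0 - KPerBlockTile));
        }

        // Tail iteration: always present, mirroring the "// tail" block in the kernel.
        std::printf("tail starts at k0_block_data_begin = %d\n", k0_block_data_begin);
    }

    int main()
    {
        const int k0 = 8;

        // Host-side dispatch, mirroring the driver change: pick the instantiation at runtime.
        if(CalculateHasMainKBlockLoop(k0))
            RunGemmTile<true>(k0);
        else
            RunGemmTile<false>(k0);

        return 0;
    }

The point of the pattern is that the false instantiation contains no loop code at all, rather than a run-time branch inside the kernel.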
@@ -255,13 +263,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             if constexpr(ABlockLdsExtraM)
             {
                 return make_naive_tensor_descriptor(
-                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
+                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                     make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
             }
             else
             {
                 return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
             }
         }();

@@ -270,13 +278,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             if constexpr(BBlockLdsExtraN)
             {
                 return make_naive_tensor_descriptor(
-                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                     make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
             }
             else
             {
                 return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
             }
         }();
@@ -334,6 +342,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
     using CBlockClusterAdaptor      = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1));

+    template <bool HasMainKBlockLoop>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
@@ -371,13 +380,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             if constexpr(ABlockLdsExtraM)
             {
                 return make_naive_tensor_descriptor(
-                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
+                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                     make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
             }
             else
             {
                 return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
             }
         }();

@@ -386,13 +395,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             if constexpr(BBlockLdsExtraN)
             {
                 return make_naive_tensor_descriptor(
-                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                     make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
             }
             else
             {
                 return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
             }
         }();
@@ -400,7 +409,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         auto a_blockwise_copy =
             BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
-                                            Sequence<KPerBlock, MPerBlock, K1>,
+                                            Sequence<K0PerBlock, MPerBlock, K1>,
                                             ABlockTransferThreadSliceLengths_K0_M_K1,
                                             ABlockTransferThreadClusterLengths_K0_M_K1,
                                             ABlockTransferThreadClusterArrangeOrder,

@@ -426,7 +435,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         auto b_blockwise_copy =
             BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
-                                            Sequence<KPerBlock, NPerBlock, K1>,
+                                            Sequence<K0PerBlock, NPerBlock, K1>,
                                             BBlockTransferThreadSliceLengths_K0_N_K1,
                                             BBlockTransferThreadClusterLengths_K0_N_K1,
                                             BBlockTransferThreadClusterArrangeOrder,
@@ -450,8 +459,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         // GEMM definition
         //   c_mtx += transpose(a_mtx) * b_mtx
-        //     a_mtx[KPerBlock, MPerBlock] is in LDS
-        //     b_mtx[KPerBlock, NPerBlock] is in LDS
+        //     a_mtx[K0PerBlock, MPerBlock] is in LDS
+        //     b_mtx[K0PerBlock, NPerBlock] is in LDS
         //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         //       register
         // sanity check
@@ -477,8 +486,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         FloatAB* p_a_block = p_shared_block;
         FloatAB* p_b_block = p_shared_block + a_block_space_size;

-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

         // hack to control index calculation when iterating over A and B matrix for threadwise copy
         constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
@@ -504,32 +513,37 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         }

         // main body
-        index_t k_block_data_begin = 0;
+        index_t k0_block_data_begin = 0;

-        do
+        if constexpr(HasMainKBlockLoop)
         {
-            a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
-                                                a_block_slice_copy_step,
-                                                a_k0_m_k1_grid_move_slice_window_step_hack);
-            b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
-                                                b_block_slice_copy_step,
-                                                b_k0_n_k1_grid_move_slice_window_step_hack);
+            do
+            {
+                a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
+                                                    a_block_slice_copy_step,
+                                                    a_k0_m_k1_grid_move_slice_window_step_hack);
+                b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
+                                                    b_block_slice_copy_step,
+                                                    b_k0_n_k1_grid_move_slice_window_step_hack);

-            a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
+                a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);

-            block_sync_lds();
+                block_sync_lds();

-            b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
+                b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);

-            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

-            block_sync_lds();
+                block_sync_lds();

-            a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
-            b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
+                a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
+                b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);

-            k_block_data_begin += KPerBlock;
-        } while(k_block_data_begin < (K0 - KPerBlock));
+                k0_block_data_begin += K0PerBlock;
+            } while(k0_block_data_begin < (K0 - K0PerBlock));
+        }

         // tail
         {
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp

@@ -19,7 +19,8 @@ template <typename GridwiseGemm,
           typename ABK0MK1GridDesc,
           typename BBK0NK1GridDesc,
           typename CM0N0M1N1M2M3M4N2GridDesc,
-          typename CBlockClusterAdaptor>
+          typename CBlockClusterAdaptor,
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -37,14 +38,14 @@ __global__ void
     __shared__ FloatAB p_shared_block[shared_block_size];

-    GridwiseGemm::Run(p_a_grid,
-                      p_b_grid,
-                      p_c_grid,
-                      p_shared_block,
-                      a_b_k0_m_k1_grid_desc,
-                      b_b_k0_n_k1_grid_desc,
-                      c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
-                      c_block_cluster_adaptor);
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
+                                                  p_b_grid,
+                                                  p_c_grid,
+                                                  p_shared_block,
+                                                  a_b_k0_m_k1_grid_desc,
+                                                  b_b_k0_n_k1_grid_desc,
+                                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                                                  c_block_cluster_adaptor);
 }
 #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
 template <typename GridwiseGemm,
@@ -53,7 +54,8 @@ template <typename GridwiseGemm,
           typename ABK0MK1GridDesc,
           typename BBK0NK1GridDesc,
           typename CM0N0M1N1M2M3M4N2GridDesc,
-          typename CBlockClusterAdaptor>
+          typename CBlockClusterAdaptor,
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -81,14 +83,14 @@ __global__ void
     __shared__ FloatAB p_shared_block[shared_block_size];

-    GridwiseGemm::Run(p_a_grid,
-                      p_b_grid,
-                      p_c_grid,
-                      p_shared_block,
-                      a_b_k0_m_k1_grid_desc,
-                      b_b_k0_n_k1_grid_desc,
-                      c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
-                      c_block_cluster_adaptor);
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
+                                                  p_b_grid,
+                                                  p_c_grid,
+                                                  p_shared_block,
+                                                  a_b_k0_m_k1_grid_desc,
+                                                  b_b_k0_n_k1_grid_desc,
+                                                  c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                                                  c_block_cluster_adaptor);
 }
 #endif
@@ -102,7 +104,7 @@ template <index_t BlockSize,
           typename CMNGridDesc,
           index_t MPerBlock,
           index_t NPerBlock,
-          index_t KPerBlock,
+          index_t K0PerBlock,
           index_t MPerXDL,
           index_t NPerXDL,
           index_t K1Value,
@@ -158,13 +160,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
             if constexpr(ABlockLdsExtraM)
             {
                 return make_naive_tensor_descriptor(
-                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
+                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                     make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
             }
             else
             {
                 return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
             }
         }();

@@ -173,13 +175,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
             if constexpr(BBlockLdsExtraN)
             {
                 return make_naive_tensor_descriptor(
-                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                     make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
             }
             else
             {
                 return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
             }
         }();
@@ -220,7 +222,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
              KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)))
             return false;

-        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0))
+        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
             return false;

         // check M01, N01
@@ -248,6 +250,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         return grid_size;
     }

+    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
+    {
+        const bool has_main_k0_block_loop = K0 > K0PerBlock;
+
+        return has_main_k0_block_loop;
+    }
+
     __host__ __device__ static constexpr auto
     MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
     {
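Note the predicate here differs slightly from the v2r3 version: v2r4 uses K0 > K0PerBlock while v2r3 uses (K0 / K0PerBlock) > 1. Since CheckValidity() in both files rejects problems where K0 % K0PerBlock != 0, the two forms agree for every K0 that reaches the kernel. A quick illustrative check, not part of the commit:

    // Illustrative only: compare both predicates on K0 values that are multiples of K0PerBlock.
    constexpr bool SameDecision(int K0, int K0PerBlock)
    {
        return ((K0 / K0PerBlock) > 1) == (K0 > K0PerBlock);
    }

    static_assert(SameDecision(4, 4), "one tile: neither form enables the main loop");
    static_assert(SameDecision(8, 4), "two tiles: both forms enable the main loop");
    static_assert(SameDecision(64, 8), "many tiles: both forms enable the main loop");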
@@ -258,13 +267,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
             if constexpr(ABlockLdsExtraM)
             {
                 return make_naive_tensor_descriptor(
-                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
+                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                     make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
             }
             else
             {
                 return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
             }
         }();

@@ -273,13 +282,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
             if constexpr(BBlockLdsExtraN)
             {
                 return make_naive_tensor_descriptor(
-                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                     make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
             }
             else
             {
                 return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
             }
         }();
@@ -338,6 +347,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
     using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
     using CBlockClusterAdaptor      = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));

+    template <bool HasMainKBlockLoop>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
...
...
@@ -390,8 +400,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
             if constexpr(ABlockLdsExtraM)
             {
                 return make_naive_tensor_descriptor(
-                    make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
-                    make_tuple(Number<KPerBlock>{} * Number<MPerBlock + 1>{} * K1,
+                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
+                    make_tuple(Number<K0PerBlock>{} * Number<MPerBlock + 1>{} * K1,
                                Number<MPerBlock + 1>{} * K1,
                                K1,
                                I1));

@@ -399,7 +409,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
             else
             {
                 return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<MPerBlock>{}, K1),
+                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                     max_lds_align);
             }
         }();
@@ -408,13 +418,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
             if constexpr(BBlockLdsExtraN)
             {
                 return make_naive_tensor_descriptor(
-                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                     make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
             }
             else
             {
                 return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
             }
         }();
@@ -422,8 +432,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
             if constexpr(BBlockLdsExtraN)
             {
                 return make_naive_tensor_descriptor(
-                    make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
-                    make_tuple(Number<KPerBlock>{} * Number<NPerBlock + 1>{} * K1,
+                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
+                    make_tuple(Number<K0PerBlock>{} * Number<NPerBlock + 1>{} * K1,
                                Number<NPerBlock + 1>{} * K1,
                                K1,
                                I1));

@@ -431,7 +441,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
             else
             {
                 return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<1>{}, Number<KPerBlock>{}, Number<NPerBlock>{}, K1),
+                    make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
                     max_lds_align);
             }
         }();
@@ -439,7 +449,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         auto a_blockwise_copy =
             BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
-                                            Sequence<1, KPerBlock, MPerBlock, K1>,
+                                            Sequence<1, K0PerBlock, MPerBlock, K1>,
                                             ABlockTransferThreadSliceLengths_K0_M_K1,
                                             ABlockTransferThreadClusterLengths_K0_M_K1,
                                             ABlockTransferThreadClusterArrangeOrder,

@@ -466,7 +476,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         auto b_blockwise_copy =
             BlockwiseTensorSliceTransfer_v4<BlockSize,
                                             InMemoryDataOperationEnum_t::Set,
-                                            Sequence<1, KPerBlock, NPerBlock, K1>,
+                                            Sequence<1, K0PerBlock, NPerBlock, K1>,
                                             BBlockTransferThreadSliceLengths_K0_N_K1,
                                             BBlockTransferThreadClusterLengths_K0_N_K1,
                                             BBlockTransferThreadClusterArrangeOrder,
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// a_mtx[K
0
PerBlock, MPerBlock] is in LDS
// b_mtx[K
0
PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
...
...
@@ -518,8 +528,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
FloatAB
*
p_a_block
=
p_shared_block
;
FloatAB
*
p_b_block
=
p_shared_block
+
a_block_space_size
;
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
0
,
KPerBlock
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
0
,
KPerBlock
,
0
,
0
);
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
0
,
K
0
PerBlock
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
0
,
K
0
PerBlock
,
0
,
0
);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
AGridStepHacks
{};
...
...
@@ -546,31 +556,35 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
// main body
index_t
k_block_data_begin
=
0
;
do
if
constexpr
(
HasMainKBlockLoop
)
{
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_b_k0_m_k1_grid_desc
,
a_block_slice_copy_step
,
a_k0_m_k1_grid_move_slice_window_step_hack
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_b_k0_n_k1_grid_desc
,
b_block_slice_copy_step
,
b_k0_n_k1_grid_move_slice_window_step_hack
);
do
{
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_b_k0_m_k1_grid_desc
,
a_block_slice_copy_step
,
a_k0_m_k1_grid_move_slice_window_step_hack
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_b_k0_n_k1_grid_desc
,
b_block_slice_copy_step
,
b_k0_n_k1_grid_move_slice_window_step_hack
);
a_blockwise_copy
.
RunRead
(
a_b_k0_m_k1_grid_desc
,
a_grid_buf
,
a_k0_m_k1_grid_step_hacks
);
a_blockwise_copy
.
RunRead
(
a_b_k0_m_k1_grid_desc
,
a_grid_buf
,
a_k0_m_k1_grid_step_hacks
);
block_sync_lds
();
block_sync_lds
();
b_blockwise_copy
.
RunRead
(
b_b_k0_n_k1_grid_desc
,
b_grid_buf
,
b_k0_n_k1_grid_step_hacks
);
b_blockwise_copy
.
RunRead
(
b_b_k0_n_k1_grid_desc
,
b_grid_buf
,
b_k0_n_k1_grid_step_hacks
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_b_k0_m_k1_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_b_k0_n_k1_block_desc
,
b_block_buf
);
a_blockwise_copy
.
RunWrite
(
a_b_k0_m_k1_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_b_k0_n_k1_block_desc
,
b_block_buf
);
k_block_data_begin
+=
KPerBlock
;
}
while
(
k_block_data_begin
<
(
K0
-
KPerBlock
));
k_block_data_begin
+=
K0PerBlock
;
}
while
(
k_block_data_begin
<
(
K0
-
K0PerBlock
));
}
// tail
{
...
...
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp

@@ -95,13 +95,11 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_
     const auto GemmN      = Y * X * C;
     const auto GemmKTotal = N * Ho * Wo;

-    const auto GemmK  = GemmKTotal / GemmK1;
     const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);

     const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
-    const index_t BatchLen   = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch));
-    const index_t GemmK0     = BatchLen * GemmKPerBlock;
-    const index_t GemmKPad   = GemmKBatch * GemmK0 * GemmK1;
+    const index_t GemmK0 =
+        math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;
+    const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;

     std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
               << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad
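The rewritten GemmK0 rounds GemmKTotal up against a whole GemmK1 * GemmKPerBlock * GemmKBatch tile before multiplying back by GemmKPerBlock, so the padded length GemmKPad = GemmKBatch * GemmK0 * GemmK1 always covers GemmKTotal and stays a multiple of GemmKPerBlock per batch. A small worked example with hypothetical sizes, using a plain ceiling-division helper in place of CK's math::integer_divide_ceil:

    #include <cassert>

    // Plain stand-in for math::integer_divide_ceil from the CK utilities.
    constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

    int main()
    {
        // Hypothetical sizes, not taken from the commit.
        const int GemmKTotal    = 1000; // N * Ho * Wo
        const int GemmK1        = 4;
        const int GemmKPerBlock = 8;
        const int GemmKBatch    = 3;

        const int GemmK0 =
            integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;
        const int GemmKPad = GemmKBatch * GemmK0 * GemmK1;

        // 1000 / (4 * 8 * 3) = 1000 / 96 -> ceil = 11, so GemmK0 = 88 and GemmKPad = 3 * 88 * 4 = 1056 >= 1000.
        assert(GemmK0 == 88);
        assert(GemmKPad == 1056 && GemmKPad >= GemmKTotal);

        return 0;
    }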
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp

@@ -123,13 +123,11 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
     const auto GemmN      = K;
     const auto GemmKTotal = N * Ho * Wo;

-    const auto GemmK  = GemmKTotal / GemmK1;
     const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);

     const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
-    const index_t BatchLen   = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch));
-    const index_t GemmK0     = BatchLen * GemmKPerBlock;
-    const index_t GemmKPad   = GemmKBatch * GemmK0 * GemmK1;
+    const index_t GemmK0 =
+        math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;
+    const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;

     std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
               << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp

@@ -107,8 +107,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
     constexpr index_t GemmCThreadTransferDstScalarPerVector       = 1;
-#elif 0
-    // [M, N, K0, K1] = [128, 128, 4, 4], C 64, for fp32
+#elif 1
+    // [M, N, K0, K1] = [128, 128, 4, 4], C 64, for fp32 and fp16
     constexpr index_t BlockSize = 256;

     constexpr index_t GemmMPerBlock = 128;
@@ -291,13 +291,11 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_
     const auto GemmN      = Y * X * C;
     const auto GemmKTotal = N * Ho * Wo;

-    const auto GemmK  = GemmKTotal / GemmK1;
     const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);

     const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
-    const index_t BatchLen   = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch));
-    const index_t GemmK0     = BatchLen * GemmKPerBlock;
-    const index_t GemmKPad   = GemmKBatch * GemmK0 * GemmK1;
+    const index_t GemmK0 =
+        math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;
+    const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;

     std::cout << "GemmKTotal: " << GemmKTotal << " GrideSizeMN: " << GridMN
               << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " gemmKPad: " << GemmKPad
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp

@@ -160,7 +160,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
     constexpr index_t GemmCThreadTransferDstScalarPerVector       = 1;
-#elif 0
+#elif 1
     // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
     constexpr index_t BlockSize = 256;
@@ -188,7 +188,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
     constexpr index_t GemmCThreadTransferDstScalarPerVector       = 1;
-#elif 1
+#elif 0
     // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
     constexpr index_t BlockSize = 256;
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp

@@ -148,28 +148,61 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
    const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);

-   const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
-                                               FloatAB,
-                                               FloatC,
-                                               remove_reference_t<AK0MK1GridDesc>,
-                                               remove_reference_t<BK0NK1GridDesc>,
-                                               remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
-                                               remove_reference_t<CBlockClusterAdaptor>>;
+   const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
+
+   const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+
+   float ave_time = 0;

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
-   float ave_time = launch_and_time_kernel(kernel,
-                                           nrepeat,
-                                           dim3(grid_size),
-                                           dim3(BlockSize),
-                                           0,
-                                           p_a_grid,
-                                           p_b_grid,
-                                           p_c_grid,
-                                           a_k0_m_k1_grid_desc,
-                                           b_k0_n_k1_grid_desc,
-                                           c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
-                                           c_block_cluster_adaptor);
+   if(has_main_k0_block_loop)
+   {
+       const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
+                                                   FloatAB,
+                                                   FloatC,
+                                                   remove_reference_t<AK0MK1GridDesc>,
+                                                   remove_reference_t<BK0NK1GridDesc>,
+                                                   remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
+                                                   remove_reference_t<CBlockClusterAdaptor>,
+                                                   true>;
+
+       ave_time = launch_and_time_kernel(kernel,
+                                         nrepeat,
+                                         dim3(grid_size),
+                                         dim3(BlockSize),
+                                         0,
+                                         p_a_grid,
+                                         p_b_grid,
+                                         p_c_grid,
+                                         a_k0_m_k1_grid_desc,
+                                         b_k0_n_k1_grid_desc,
+                                         c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                                         c_block_cluster_adaptor);
+   }
+   else
+   {
+       const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
+                                                   FloatAB,
+                                                   FloatC,
+                                                   remove_reference_t<AK0MK1GridDesc>,
+                                                   remove_reference_t<BK0NK1GridDesc>,
+                                                   remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
+                                                   remove_reference_t<CBlockClusterAdaptor>,
+                                                   false>;
+
+       ave_time = launch_and_time_kernel(kernel,
+                                         nrepeat,
+                                         dim3(grid_size),
+                                         dim3(BlockSize),
+                                         0,
+                                         p_a_grid,
+                                         p_b_grid,
+                                         p_c_grid,
+                                         a_k0_m_k1_grid_desc,
+                                         b_k0_n_k1_grid_desc,
+                                         c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                                         c_block_cluster_adaptor);
+   }
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
    DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
    DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
@@ -181,20 +214,61 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
    c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
    c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);

-   float ave_time = launch_and_time_kernel(
-       kernel,
-       nrepeat,
-       dim3(grid_size),
-       dim3(BlockSize),
-       0,
-       p_a_grid,
-       p_b_grid,
-       p_c_grid,
-       cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
-       cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
-       cast_pointer_to_constant_address_space(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
-       cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+   if(has_main_k0_block_loop)
+   {
+       const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
+                                                   FloatAB,
+                                                   FloatC,
+                                                   remove_reference_t<AK0MK1GridDesc>,
+                                                   remove_reference_t<BK0NK1GridDesc>,
+                                                   remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
+                                                   remove_reference_t<CBlockClusterAdaptor>,
+                                                   true>;
+
+       ave_time = launch_and_time_kernel(
+           kernel,
+           nrepeat,
+           dim3(grid_size),
+           dim3(BlockSize),
+           0,
+           p_a_grid,
+           p_b_grid,
+           p_c_grid,
+           cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+           cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+           cast_pointer_to_constant_address_space(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
+           cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+   }
+   else
+   {
+       const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
+                                                   FloatAB,
+                                                   FloatC,
+                                                   remove_reference_t<AK0MK1GridDesc>,
+                                                   remove_reference_t<BK0NK1GridDesc>,
+                                                   remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
+                                                   remove_reference_t<CBlockClusterAdaptor>,
+                                                   false>;
+
+       ave_time = launch_and_time_kernel(
+           kernel,
+           nrepeat,
+           dim3(grid_size),
+           dim3(BlockSize),
+           0,
+           p_a_grid,
+           p_b_grid,
+           p_c_grid,
+           cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+           cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+           cast_pointer_to_constant_address_space(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
+           cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+   }
    }
#endif

    return ave_time;
}
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp

@@ -156,27 +156,58 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
        std::cout << "gridSize : " << grid_size << std::endl;
    }

-   const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
-                                               FloatAB,
-                                               FloatC,
-                                               remove_reference_t<ABK0MK1GridDesc>,
-                                               remove_reference_t<BBK0NK1GridDesc>,
-                                               remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
-                                               remove_reference_t<CBlockClusterAdaptor>>;
+   const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
+
+   const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+
+   float ave_time = 0;

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
-   float ave_time = launch_and_time_kernel(kernel,
-                                           nrepeat,
-                                           dim3(grid_size),
-                                           dim3(BlockSize),
-                                           0,
-                                           p_a_grid,
-                                           p_b_grid,
-                                           p_c_grid,
-                                           a_b_k0_m_k1_grid_desc,
-                                           b_b_k0_n_k1_grid_desc,
-                                           c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
-                                           c_block_cluster_adaptor);
+   if(has_main_k0_block_loop)
+   {
+       const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
+                                                   FloatAB,
+                                                   FloatC,
+                                                   remove_reference_t<ABK0MK1GridDesc>,
+                                                   remove_reference_t<BBK0NK1GridDesc>,
+                                                   remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
+                                                   remove_reference_t<CBlockClusterAdaptor>,
+                                                   true>;
+
+       ave_time = launch_and_time_kernel(kernel,
+                                         nrepeat,
+                                         dim3(grid_size),
+                                         dim3(BlockSize),
+                                         0,
+                                         p_a_grid,
+                                         p_b_grid,
+                                         p_c_grid,
+                                         a_b_k0_m_k1_grid_desc,
+                                         b_b_k0_n_k1_grid_desc,
+                                         c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                                         c_block_cluster_adaptor);
+   }
+   else
+   {
+       const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
+                                                   FloatAB,
+                                                   FloatC,
+                                                   remove_reference_t<ABK0MK1GridDesc>,
+                                                   remove_reference_t<BBK0NK1GridDesc>,
+                                                   remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
+                                                   remove_reference_t<CBlockClusterAdaptor>,
+                                                   false>;
+
+       ave_time = launch_and_time_kernel(kernel,
+                                         nrepeat,
+                                         dim3(grid_size),
+                                         dim3(BlockSize),
+                                         0,
+                                         p_a_grid,
+                                         p_b_grid,
+                                         p_c_grid,
+                                         a_b_k0_m_k1_grid_desc,
+                                         b_b_k0_n_k1_grid_desc,
+                                         c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                                         c_block_cluster_adaptor);
+   }
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
    DeviceMem a_b_k0_m_k1_grid_desc_dev_buf(sizeof(ABK0MK1GridDesc));
@@ -189,20 +220,58 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
    c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
    c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);

-   float ave_time = launch_and_time_kernel(
-       kernel,
-       nrepeat,
-       dim3(grid_size),
-       dim3(BlockSize),
-       0,
-       p_a_grid,
-       p_b_grid,
-       p_c_grid,
-       cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
-       cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
-       cast_pointer_to_constant_address_space(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
-       cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+   if(has_main_k0_block_loop)
+   {
+       const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
+                                                   FloatAB,
+                                                   FloatC,
+                                                   remove_reference_t<ABK0MK1GridDesc>,
+                                                   remove_reference_t<BBK0NK1GridDesc>,
+                                                   remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
+                                                   remove_reference_t<CBlockClusterAdaptor>,
+                                                   true>;
+
+       ave_time = launch_and_time_kernel(
+           kernel,
+           nrepeat,
+           dim3(grid_size),
+           dim3(BlockSize),
+           0,
+           p_a_grid,
+           p_b_grid,
+           p_c_grid,
+           cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+           cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+           cast_pointer_to_constant_address_space(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
+           cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+   }
+   else
+   {
+       const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
+                                                   FloatAB,
+                                                   FloatC,
+                                                   remove_reference_t<ABK0MK1GridDesc>,
+                                                   remove_reference_t<BBK0NK1GridDesc>,
+                                                   remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
+                                                   remove_reference_t<CBlockClusterAdaptor>,
+                                                   false>;
+
+       ave_time = launch_and_time_kernel(
+           kernel,
+           nrepeat,
+           dim3(grid_size),
+           dim3(BlockSize),
+           0,
+           p_a_grid,
+           p_b_grid,
+           p_c_grid,
+           cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+           cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+           cast_pointer_to_constant_address_space(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
+           cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+   }
#endif

    return ave_time;
}