Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d5297aba
"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "3c17560daa4d9cf05b4104b1c54da88f6873f79a"
Unverified
Commit
d5297aba
authored
Oct 21, 2021
by
Chao Liu
Committed by
GitHub
Oct 21, 2021
Browse files
fix bug in gridwise gemm xdlops v2r3 (#45)
parent
c3018794
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
177 additions
and
89 deletions
+177
-89
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
+68
-54
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
...on_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
+2
-2
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
+107
-33
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
View file @
d5297aba
...
@@ -19,7 +19,8 @@ template <typename GridwiseGemm,
...
@@ -19,7 +19,8 @@ template <typename GridwiseGemm,
typename
AK0MK1GridDesc
,
typename
AK0MK1GridDesc
,
typename
BK0NK1GridDesc
,
typename
BK0NK1GridDesc
,
typename
CM0N0M1N1M2M3M4N2GridDesc
,
typename
CM0N0M1N1M2M3M4N2GridDesc
,
typename
CBlockClusterAdaptor
>
typename
CBlockClusterAdaptor
,
bool
HasMainKBlockLoop
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
...
@@ -37,14 +38,14 @@ __global__ void
...
@@ -37,14 +38,14 @@ __global__ void
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
GridwiseGemm
::
Run
(
p_a_grid
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>
(
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
p_shared_block
,
p_shared_block
,
a_k0_m_k1_grid_desc
,
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_block_cluster_adaptor
);
c_block_cluster_adaptor
);
}
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
...
@@ -81,14 +82,14 @@ __global__ void
...
@@ -81,14 +82,14 @@ __global__ void
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
GridwiseGemm
::
Run
(
p_a_grid
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>
(
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
p_shared_block
,
p_shared_block
,
a_k0_m_k1_grid_desc
,
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_block_cluster_adaptor
);
c_block_cluster_adaptor
);
}
}
#endif
#endif
...
@@ -102,7 +103,7 @@ template <index_t BlockSize,
...
@@ -102,7 +103,7 @@ template <index_t BlockSize,
typename
CMNGridDesc
,
typename
CMNGridDesc
,
index_t
MPerBlock
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
K
0
PerBlock
,
index_t
MPerXDL
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
NPerXDL
,
index_t
K1Value
,
index_t
K1Value
,
...
@@ -158,13 +159,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -158,13 +159,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
if
constexpr
(
ABlockLdsExtraM
)
if
constexpr
(
ABlockLdsExtraM
)
{
{
return
make_naive_tensor_descriptor
(
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
}
else
else
{
{
return
make_naive_tensor_descriptor_aligned
(
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}
}();
}();
...
@@ -173,13 +174,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -173,13 +174,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
if
constexpr
(
BBlockLdsExtraN
)
if
constexpr
(
BBlockLdsExtraN
)
{
{
return
make_naive_tensor_descriptor
(
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
}
else
else
{
{
return
make_naive_tensor_descriptor_aligned
(
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}
}();
}();
...
@@ -217,7 +218,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -217,7 +218,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
K1
==
b_k0_n_k1_grid_desc
.
GetLength
(
I2
)))
K1
==
b_k0_n_k1_grid_desc
.
GetLength
(
I2
)))
return
false
;
return
false
;
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
KPerBlock
==
0
))
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
K
0
PerBlock
==
0
))
return
false
;
return
false
;
// check M01, N01
// check M01, N01
...
@@ -245,6 +246,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -245,6 +246,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return
grid_size
;
return
grid_size
;
}
}
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0
)
{
const
bool
has_main_k0_block_loop
=
(
K0
/
K0PerBlock
)
>
1
;
return
has_main_k0_block_loop
;
}
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
const
CMNGridDesc
&
c_m_n_grid_desc
)
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
const
CMNGridDesc
&
c_m_n_grid_desc
)
{
{
...
@@ -255,13 +263,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -255,13 +263,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
if
constexpr
(
ABlockLdsExtraM
)
if
constexpr
(
ABlockLdsExtraM
)
{
{
return
make_naive_tensor_descriptor
(
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
}
else
else
{
{
return
make_naive_tensor_descriptor_aligned
(
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}
}();
}();
...
@@ -270,13 +278,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -270,13 +278,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
if
constexpr
(
BBlockLdsExtraN
)
if
constexpr
(
BBlockLdsExtraN
)
{
{
return
make_naive_tensor_descriptor
(
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
}
else
else
{
{
return
make_naive_tensor_descriptor_aligned
(
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}
}();
}();
...
@@ -334,6 +342,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -334,6 +342,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
CMNGridDesc
{}));
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
CMNGridDesc
{}));
using
CBlockClusterAdaptor
=
decltype
(
MakeCBlockClusterAdaptor
(
CMNGridDesc
{},
1
,
1
));
using
CBlockClusterAdaptor
=
decltype
(
MakeCBlockClusterAdaptor
(
CMNGridDesc
{},
1
,
1
));
template
<
bool
HasMainKBlockLoop
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
...
@@ -371,13 +380,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -371,13 +380,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
if
constexpr
(
ABlockLdsExtraM
)
if
constexpr
(
ABlockLdsExtraM
)
{
{
return
make_naive_tensor_descriptor
(
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
}
else
else
{
{
return
make_naive_tensor_descriptor_aligned
(
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}
}();
}();
...
@@ -386,13 +395,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -386,13 +395,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
if
constexpr
(
BBlockLdsExtraN
)
if
constexpr
(
BBlockLdsExtraN
)
{
{
return
make_naive_tensor_descriptor
(
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
}
else
else
{
{
return
make_naive_tensor_descriptor_aligned
(
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
KPerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
make_tuple
(
Number
<
K
0
PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}
}();
}();
...
@@ -400,7 +409,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -400,7 +409,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
auto
a_blockwise_copy
=
auto
a_blockwise_copy
=
BlockwiseTensorSliceTransfer_v4
<
BlockSize
,
BlockwiseTensorSliceTransfer_v4
<
BlockSize
,
InMemoryDataOperationEnum_t
::
Set
,
InMemoryDataOperationEnum_t
::
Set
,
Sequence
<
KPerBlock
,
MPerBlock
,
K1
>
,
Sequence
<
K
0
PerBlock
,
MPerBlock
,
K1
>
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
...
@@ -426,7 +435,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -426,7 +435,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
auto
b_blockwise_copy
=
auto
b_blockwise_copy
=
BlockwiseTensorSliceTransfer_v4
<
BlockSize
,
BlockwiseTensorSliceTransfer_v4
<
BlockSize
,
InMemoryDataOperationEnum_t
::
Set
,
InMemoryDataOperationEnum_t
::
Set
,
Sequence
<
KPerBlock
,
NPerBlock
,
K1
>
,
Sequence
<
K
0
PerBlock
,
NPerBlock
,
K1
>
,
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
...
@@ -450,8 +459,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -450,8 +459,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// GEMM definition
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// a_mtx[K
0
PerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// b_mtx[K
0
PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// register
// sanity check
// sanity check
...
@@ -477,8 +486,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -477,8 +486,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
FloatAB
*
p_a_block
=
p_shared_block
;
FloatAB
*
p_a_block
=
p_shared_block
;
FloatAB
*
p_b_block
=
p_shared_block
+
a_block_space_size
;
FloatAB
*
p_b_block
=
p_shared_block
+
a_block_space_size
;
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
,
0
,
0
);
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
K
0
PerBlock
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
K
0
PerBlock
,
0
,
0
);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
AGridStepHacks
{};
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
AGridStepHacks
{};
...
@@ -504,32 +513,37 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -504,32 +513,37 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
}
// main body
// main body
index_t
k_block_data_begin
=
0
;
index_t
k
0
_block_data_begin
=
0
;
do
if
constexpr
(
HasMainKBlockLoop
)
{
{
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k0_m_k1_grid_desc
,
do
a_block_slice_copy_step
,
{
a_k0_m_k1_grid_move_slice_window_step_hack
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k0_m_k1_grid_desc
,
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k0_n_k1_grid_desc
,
a_block_slice_copy_step
,
b_block_slice_copy_step
,
a_k0_m_k1_grid_move_slice_window_step_hack
);
b_k0_n_k1_grid_move_slice_window_step_hack
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k0_n_k1_grid_desc
,
b_block_slice_copy_step
,
b_k0_n_k1_grid_move_slice_window_step_hack
);
a_blockwise_copy
.
RunRead
(
a_k0_m_k1_grid_desc
,
a_grid_buf
,
a_k0_m_k1_grid_step_hacks
);
a_blockwise_copy
.
RunRead
(
a_k0_m_k1_grid_desc
,
a_grid_buf
,
a_k0_m_k1_grid_step_hacks
);
block_sync_lds
();
block_sync_lds
();
b_blockwise_copy
.
RunRead
(
b_k0_n_k1_grid_desc
,
b_grid_buf
,
b_k0_n_k1_grid_step_hacks
);
b_blockwise_copy
.
RunRead
(
b_k0_n_k1_grid_desc
,
b_grid_buf
,
b_k0_n_k1_grid_step_hacks
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
block_sync_lds
();
a_blockwise_copy
.
RunWrite
(
a_k0_m_k1_block_desc
,
a_block_buf
);
a_blockwise_copy
.
RunWrite
(
a_k0_m_k1_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_k0_n_k1_block_desc
,
b_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_k0_n_k1_block_desc
,
b_block_buf
);
k_block_data_begin
+=
KPerBlock
;
k0_block_data_begin
+=
K0PerBlock
;
}
while
(
k_block_data_begin
<
(
K0
-
KPerBlock
));
}
while
(
k0_block_data_begin
<
(
K0
-
K0PerBlock
));
}
// tail
// tail
{
{
...
...
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
View file @
d5297aba
...
@@ -160,7 +160,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
...
@@ -160,7 +160,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif
0
#elif
1
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -188,7 +188,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
...
@@ -188,7 +188,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif
1
#elif
0
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
// [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
...
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
View file @
d5297aba
...
@@ -148,28 +148,61 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
...
@@ -148,28 +148,61 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
c_m_n_grid_desc
);
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
c_m_n_grid_desc
);
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
const
auto
K0
=
a_k0_m_k1_grid_desc
.
GetLength
(
I0
);
FloatAB
,
FloatC
,
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
remove_reference_t
<
AK0MK1GridDesc
>
,
remove_reference_t
<
BK0NK1GridDesc
>
,
float
ave_time
=
0
;
remove_reference_t
<
CM0N0M1N1M2M3M4N2GridDesc
>
,
remove_reference_t
<
CBlockClusterAdaptor
>>
;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float
ave_time
=
launch_and_time_kernel
(
kernel
,
if
(
has_main_k0_block_loop
)
nrepeat
,
{
dim3
(
grid_size
),
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
dim3
(
BlockSize
),
FloatAB
,
0
,
FloatC
,
p_a_grid
,
remove_reference_t
<
AK0MK1GridDesc
>
,
p_b_grid
,
remove_reference_t
<
BK0NK1GridDesc
>
,
p_c_grid
,
remove_reference_t
<
CM0N0M1N1M2M3M4N2GridDesc
>
,
a_k0_m_k1_grid_desc
,
remove_reference_t
<
CBlockClusterAdaptor
>
,
b_k0_n_k1_grid_desc
,
true
>
;
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_block_cluster_adaptor
);
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_block_cluster_adaptor
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
FloatAB
,
FloatC
,
remove_reference_t
<
AK0MK1GridDesc
>
,
remove_reference_t
<
BK0NK1GridDesc
>
,
remove_reference_t
<
CM0N0M1N1M2M3M4N2GridDesc
>
,
remove_reference_t
<
CBlockClusterAdaptor
>
,
false
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
c_block_cluster_adaptor
);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem
a_k0_m_k1_grid_desc_dev_buf
(
sizeof
(
AK0MK1GridDesc
));
DeviceMem
a_k0_m_k1_grid_desc_dev_buf
(
sizeof
(
AK0MK1GridDesc
));
DeviceMem
b_k0_n_k1_grid_desc_dev_buf
(
sizeof
(
BK0NK1GridDesc
));
DeviceMem
b_k0_n_k1_grid_desc_dev_buf
(
sizeof
(
BK0NK1GridDesc
));
...
@@ -181,20 +214,61 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
...
@@ -181,20 +214,61 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
.
ToDevice
(
&
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
.
ToDevice
(
&
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
);
c_block_cluster_adaptor_dev_buf
.
ToDevice
(
&
c_block_cluster_adaptor
);
c_block_cluster_adaptor_dev_buf
.
ToDevice
(
&
c_block_cluster_adaptor
);
float
ave_time
=
launch_and_time_kernel
(
if
(
has_main_k0_block_loop
)
kernel
,
{
nrepeat
,
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
dim3
(
grid_size
),
FloatAB
,
dim3
(
BlockSize
),
FloatC
,
0
,
remove_reference_t
<
AK0MK1GridDesc
>
,
p_a_grid
,
remove_reference_t
<
BK0NK1GridDesc
>
,
p_b_grid
,
remove_reference_t
<
CM0N0M1N1M2M3M4N2GridDesc
>
,
p_c_grid
,
remove_reference_t
<
CBlockClusterAdaptor
>
,
cast_pointer_to_constant_address_space
(
a_k0_m_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
true
>
;
cast_pointer_to_constant_address_space
(
b_k0_n_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
ave_time
=
launch_and_time_kernel
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
kernel
,
cast_pointer_to_constant_address_space
(
c_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
cast_pointer_to_constant_address_space
(
a_k0_m_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_k0_n_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
FloatAB
,
FloatC
,
remove_reference_t
<
AK0MK1GridDesc
>
,
remove_reference_t
<
BK0NK1GridDesc
>
,
remove_reference_t
<
CM0N0M1N1M2M3M4N2GridDesc
>
,
remove_reference_t
<
CBlockClusterAdaptor
>
,
false
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
cast_pointer_to_constant_address_space
(
a_k0_m_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_k0_n_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
}
}
#endif
#endif
return
ave_time
;
return
ave_time
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment