Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f27232c5
Commit
f27232c5
authored
May 17, 2023
by
Po-Yen, Chen
Browse files
Add macro to switch descriptor opt
parent
a3abce47
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
20 additions
and
0 deletions
+20
-0
include/ck/ck.hpp
include/ck/ck.hpp
+2
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
+9
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+9
-0
No files found.
include/ck/ck.hpp
View file @
f27232c5
...
@@ -12,6 +12,8 @@
...
@@ -12,6 +12,8 @@
#define ENABLE_DUMP_CLOCK 1
#define ENABLE_DUMP_CLOCK 1
#define ENABLE_DESC_OPT 1
// constant address space for kernel parameter
// constant address space for kernel parameter
// https://llvm.org/docs/AMDGPUUsage.html#address-spaces
// https://llvm.org/docs/AMDGPUUsage.html#address-spaces
#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4)))
#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4)))
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
View file @
f27232c5
...
@@ -692,12 +692,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -692,12 +692,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
#endif
#if ENABLE_DESC_OPT
const
auto
a_grid_desc_ak0_m_ak1
=
readfirstlane
(
MakeAGridDescriptor_AK0_M_AK1
(
const
auto
a_grid_desc_ak0_m_ak1
=
readfirstlane
(
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
));
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
));
const
auto
b_grid_desc_bk0_n_bk1
=
readfirstlane
(
MakeBGridDescriptor_BK0_N_BK1
(
const
auto
b_grid_desc_bk0_n_bk1
=
readfirstlane
(
MakeBGridDescriptor_BK0_N_BK1
(
problem
.
K
,
problem
.
KPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideB
,
problem
.
BK0
));
problem
.
K
,
problem
.
KPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideB
,
problem
.
BK0
));
const
auto
c_grid_desc_m_n
=
readfirstlane
(
MakeCGridDescriptor_M_N
(
const
auto
c_grid_desc_m_n
=
readfirstlane
(
MakeCGridDescriptor_M_N
(
problem
.
M
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
));
problem
.
M
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
));
#else
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
const
auto
b_grid_desc_bk0_n_bk1
=
MakeBGridDescriptor_BK0_N_BK1
(
problem
.
K
,
problem
.
KPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideB
,
problem
.
BK0
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
problem
.
M
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
);
#endif
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
f27232c5
...
@@ -510,12 +510,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -510,12 +510,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
#endif
#if ENABLE_DESC_OPT
const
auto
a_grid_desc_k0_m_k1
=
readfirstlane
(
const
auto
a_grid_desc_k0_m_k1
=
readfirstlane
(
MakeAGridDescriptor_K0_M_K1
(
karg
.
M
,
karg
.
MPadded
,
karg
.
K
,
karg
.
K0
,
karg
.
StrideA
));
MakeAGridDescriptor_K0_M_K1
(
karg
.
M
,
karg
.
MPadded
,
karg
.
K
,
karg
.
K0
,
karg
.
StrideA
));
const
auto
b_grid_desc_k0_n_k1
=
readfirstlane
(
const
auto
b_grid_desc_k0_n_k1
=
readfirstlane
(
MakeBGridDescriptor_K0_N_K1
(
karg
.
K
,
karg
.
N
,
karg
.
NPadded
,
karg
.
K0
,
karg
.
StrideB
));
MakeBGridDescriptor_K0_N_K1
(
karg
.
K
,
karg
.
N
,
karg
.
NPadded
,
karg
.
K0
,
karg
.
StrideB
));
const
auto
c_grid_desc_m_n
=
readfirstlane
(
const
auto
c_grid_desc_m_n
=
readfirstlane
(
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
MPadded
,
karg
.
N
,
karg
.
NPadded
,
karg
.
StrideC
));
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
MPadded
,
karg
.
N
,
karg
.
NPadded
,
karg
.
StrideC
));
#else
const
auto
a_grid_desc_k0_m_k1
=
MakeAGridDescriptor_K0_M_K1
(
karg
.
M
,
karg
.
MPadded
,
karg
.
K
,
karg
.
K0
,
karg
.
StrideA
);
const
auto
b_grid_desc_k0_n_k1
=
MakeBGridDescriptor_K0_N_K1
(
karg
.
K
,
karg
.
N
,
karg
.
NPadded
,
karg
.
K0
,
karg
.
StrideB
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
MPadded
,
karg
.
N
,
karg
.
NPadded
,
karg
.
StrideC
);
#endif
const
auto
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
const
auto
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n
);
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment