Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
cf360b72
Commit
cf360b72
authored
May 19, 2022
by
ltqin
Browse files
create b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 from parameter
parent
8d4b51ca
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
4 deletions
+17
-4
example/01_gemm/gemm_xdl_skip_lds_fp16.cpp
example/01_gemm/gemm_xdl_skip_lds_fp16.cpp
+1
-1
include/ck/tensor_operation/gpu/device/device_gemm_xdl_skip_lds.hpp
.../tensor_operation/gpu/device/device_gemm_xdl_skip_lds.hpp
+9
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_lds_v2r3.hpp
...operation/gpu/grid/gridwise_gemm_xdlops_skip_lds_v2r3.hpp
+7
-3
No files found.
example/01_gemm/gemm_xdl_skip_lds_fp16.cpp
View file @
cf360b72
...
@@ -84,7 +84,7 @@ int main(int argc, char* argv[])
...
@@ -84,7 +84,7 @@ int main(int argc, char* argv[])
// GEMM shape
// GEMM shape
#if NORMAL_CONFIG
#if NORMAL_CONFIG
ck
::
index_t
M
=
3840
;
ck
::
index_t
M
=
256
;
ck
::
index_t
N
=
4096
;
ck
::
index_t
N
=
4096
;
ck
::
index_t
K
=
4096
;
ck
::
index_t
K
=
4096
;
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_skip_lds.hpp
View file @
cf360b72
...
@@ -265,6 +265,9 @@ struct DeviceGemmXdlSkipLds
...
@@ -265,6 +265,9 @@ struct DeviceGemmXdlSkipLds
block_2_ctile_map_
=
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_
=
GridwiseGemm
::
MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3
(
b_grid_desc_k0_n_k1_
);
}
}
}
}
...
@@ -275,6 +278,8 @@ struct DeviceGemmXdlSkipLds
...
@@ -275,6 +278,8 @@ struct DeviceGemmXdlSkipLds
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
typename
GridwiseGemm
::
BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_
;
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
...
@@ -331,6 +336,7 @@ struct DeviceGemmXdlSkipLds
...
@@ -331,6 +336,7 @@ struct DeviceGemmXdlSkipLds
CDataType
,
CDataType
,
remove_reference_t
<
DeviceGemmXdlSkipLds
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdlSkipLds
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdlSkipLds
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdlSkipLds
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
...
@@ -348,6 +354,7 @@ struct DeviceGemmXdlSkipLds
...
@@ -348,6 +354,7 @@ struct DeviceGemmXdlSkipLds
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
...
@@ -362,6 +369,7 @@ struct DeviceGemmXdlSkipLds
...
@@ -362,6 +369,7 @@ struct DeviceGemmXdlSkipLds
CDataType
,
CDataType
,
remove_reference_t
<
DeviceGemmXdlSkipLds
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdlSkipLds
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdlSkipLds
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdlSkipLds
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
...
@@ -379,6 +387,7 @@ struct DeviceGemmXdlSkipLds
...
@@ -379,6 +387,7 @@ struct DeviceGemmXdlSkipLds
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_lds_v2r3.hpp
View file @
cf360b72
...
@@ -18,6 +18,7 @@ template <typename GridwiseGemm,
...
@@ -18,6 +18,7 @@ template <typename GridwiseGemm,
typename
FloatC
,
typename
FloatC
,
typename
AGridDesc_K0_M_K1
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
BGridDesc_K0_N_K1
,
typename
BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3
,
typename
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
typename
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
...
@@ -34,6 +35,7 @@ __global__ void
...
@@ -34,6 +35,7 @@ __global__ void
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
const
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
AElementwiseOperation
a_element_op
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
BElementwiseOperation
b_element_op
,
...
@@ -49,6 +51,7 @@ __global__ void
...
@@ -49,6 +51,7 @@ __global__ void
p_shared
,
p_shared
,
a_grid_desc_k0_m_k1
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
b_grid_desc_k0_n_k1
,
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
@@ -60,6 +63,7 @@ __global__ void
...
@@ -60,6 +63,7 @@ __global__ void
ignore
=
p_c_grid
;
ignore
=
p_c_grid
;
ignore
=
a_grid_desc_k0_m_k1
;
ignore
=
a_grid_desc_k0_m_k1
;
ignore
=
b_grid_desc_k0_n_k1
;
ignore
=
b_grid_desc_k0_n_k1
;
ignore
=
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
;
ignore
=
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
;
ignore
=
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
;
ignore
=
a_element_op
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
b_element_op
;
...
@@ -482,6 +486,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
...
@@ -482,6 +486,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
using
DefaultBlock2CTileMap
=
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
));
using
DefaultBlock2CTileMap
=
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
));
using
BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3
=
decltype
(
MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3
(
BGridDesc_K0_N_K1
{}));
template
<
bool
HasMainK0BlockLoop
,
typename
Block2CTileMap
=
DefaultBlock2CTileMap
>
template
<
bool
HasMainK0BlockLoop
,
typename
Block2CTileMap
=
DefaultBlock2CTileMap
>
__device__
static
void
__device__
static
void
...
@@ -491,6 +497,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
...
@@ -491,6 +497,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
void
*
__restrict__
p_shared
,
void
*
__restrict__
p_shared
,
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
const
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
&
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
&
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
AElementwiseOperation
&
a_element_op
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
...
@@ -575,9 +582,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
...
@@ -575,9 +582,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_lds_v2r3
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
.
GetElementSpaceSize
(),
b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3
.
GetElementSpaceSize
(),
true
>
{};
true
>
{};
auto
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
=
MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3
(
b_grid_desc_k0_n_k1
);
const
auto
wave_id
=
GetWaveIdx
();
const
auto
wave_id
=
GetWaveIdx
();
const
auto
wave_k_n_id
=
GetWaveKNIdx
(
wave_id
[
I2
]);
const
auto
wave_k_n_id
=
GetWaveKNIdx
(
wave_id
[
I2
]);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment