Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
dbc971be
Commit
dbc971be
authored
Jul 16, 2022
by
wangshaojie6
Browse files
wip for gridwise
parent
6985af40
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
48 additions
and
16 deletions
+48
-16
include/ck/tensor_operation/gpu/grid/gridwise_gemm_gemm_xdl_skip_lds.hpp
...or_operation/gpu/grid/gridwise_gemm_gemm_xdl_skip_lds.hpp
+48
-16
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_gemm_xdl_skip_lds.hpp
View file @
dbc971be
...
@@ -74,6 +74,8 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
...
@@ -74,6 +74,8 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
K0PerBlock
=
Number
<
KPerBlock
/
AK1
>
{};
static
constexpr
auto
BaseMultK0
=
2
;
static
constexpr
auto
BaseMultK0
=
2
;
static
constexpr
auto
MultiK0
=
BaseMultK0
*
1
;
static
constexpr
auto
MultiK0
=
BaseMultK0
*
1
;
...
@@ -81,30 +83,36 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
...
@@ -81,30 +83,36 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
static
constexpr
auto
K1
=
Number
<
AK1
>
{};
static
constexpr
auto
K1
=
Number
<
AK1
>
{};
static
constexpr
index_t
WaveSize
=
64
;
static
constexpr
index_t
WaveSize
=
64
;
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MXdlPerWave
*
MPerXDL
);
static
constexpr
index_t
M0Waves
=
M0PerBlock
/
(
M0XdlPerWave
*
M0PerXDL
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NXdlPerWave
*
NPerXDL
);
static
constexpr
index_t
N0Waves
=
N0PerBlock
/
(
N0XdlPerWave
*
N0PerXDL
);
static
constexpr
auto
xdlops_gemm0
=
XdlopsGemm
<
FloatAB
,
M0PerXDL
,
N0PerXDL
,
K1
>
{};
static
constexpr
index_t
K0PerThread0
=
K0PerBlock
/
xdlops_gemm0
.
K0PerXdlops
;
static
constexpr
index_t
M1Waves
=
M1PerBlock
/
(
M1XdlPerWave
*
M1PerXDL
);
static
constexpr
index_t
N1Waves
=
N1PerBlock
/
(
N1XdlPerWave
*
N1PerXDL
);
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
K1
>
{};
static
constexpr
auto
xdlops_gemm
1
=
XdlopsGemm
<
FloatAB
,
M
1
PerXDL
,
N
1
PerXDL
,
K1
>
{};
static
constexpr
index_t
K0PerThread
=
K0PerBlock
/
xdlops_gemm
.
K0PerXdlops
;
static
constexpr
index_t
K0PerThread
1
=
K0PerBlock
/
xdlops_gemm
1
.
K0PerXdlops
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
{
{
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
max_lds_align
=
A
K1
;
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_k0_m_k1
=
[
&
]()
{
constexpr
auto
a_block_desc_k0_m_k1
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
if
constexpr
(
ABlockLdsExtraM
)
{
{
return
make_naive_tensor_descriptor
(
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
*
MultiK0
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K0PerBlock
*
MultiK0
>
{},
Number
<
M
0
PerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
make_tuple
(
Number
<
M
0
PerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
}
else
else
{
{
return
make_naive_tensor_descriptor_aligned
(
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
*
MultiK0
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K0PerBlock
*
MultiK0
>
{},
Number
<
M
0
PerBlock
>
{},
K1
),
max_lds_align
);
max_lds_align
);
}
}
}();
}();
...
@@ -112,10 +120,34 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
...
@@ -112,10 +120,34 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
return
a_block_desc_k0_m_k1
;
return
a_block_desc_k0_m_k1
;
}
}
__host__
__device__
static
constexpr
auto
GetB1BlockDescriptor_K0PerBlock_NPerBlock_K1
()
{
constexpr
auto
max_lds_align
=
B1K1
;
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
b1_block_desc_k0_n_k1
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
*
MultiK0
>
{},
Number
<
N1PerBlock
>
{},
K1
),
make_tuple
(
Number
<
N1PerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
*
MultiK0
>
{},
Number
<
N1PerBlock
>
{},
K1
),
max_lds_align
);
}
}();
return
b1_block_desc_k0_n_k1
;
}
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
{
// LDS allocation for A and B: be careful of alignment
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_desc_k0_m_k1
=
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
();
constexpr
auto
a_block_desc_k0_m_k1
=
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
();
constexpr
auto
b1_block_desc_k0_n_k1
=
GetB1BlockDescriptor_K0PerBlock_NPerBlock_K1
();
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
max_lds_align
=
K1
;
...
@@ -179,7 +211,7 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
...
@@ -179,7 +211,7 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
const
index_t
grid_size
=
(
M
/
M
0
PerBlock
)
*
(
N
/
N
1
PerBlock
);
return
grid_size
;
return
grid_size
;
}
}
...
@@ -193,21 +225,21 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
...
@@ -193,21 +225,21 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
}
}
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3
(
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
)
MakeB
0
GridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3
(
const
B
0
GridDesc_K0_N_K1
&
b
0
_grid_desc_k0_n_k1
)
{
{
const
auto
K0
=
b_grid_desc_k0_n_k1
.
GetLength
(
I0
);
const
auto
K0
=
b
0
_grid_desc_k0_n_k1
.
GetLength
(
I0
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
N
=
b
0
_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1
=
transform_tensor_descriptor
(
const
auto
b
0
_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1
=
transform_tensor_descriptor
(
b_grid_desc_k0_n_k1
,
b
0
_grid_desc_k0_n_k1
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
/
K0PerBlock
,
xdlops_gemm
.
K0PerXdlops
,
K0PerThread
)),
make_tuple
(
K0
/
K0PerBlock
,
xdlops_gemm
0
.
K0PerXdlops
,
K0PerThread
)),
make_unmerge_transform
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
N
/
(
NXdlPerWave
*
NWaves
*
NPerXDL
),
NXdlPerWave
,
NWaves
,
NPerXDL
)),
N
/
(
NXdlPerWave
*
NWaves
*
NPerXDL
),
NXdlPerWave
,
NWaves
,
NPerXDL
)),
make_pass_through_transform
(
K1
)),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
,
6
>
{},
Sequence
<
7
>
{}));
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
,
6
>
{},
Sequence
<
7
>
{}));
return
b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1
;
return
b
0
_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1
;
}
}
__device__
static
auto
GetWaveIdx
()
__device__
static
auto
GetWaveIdx
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment