Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
cf95b944
Commit
cf95b944
authored
Jul 21, 2022
by
Chao Liu
Browse files
add ds
parent
0267bf47
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
4 deletions
+43
-4
include/ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d_xdl_cshuffle.hpp
...on/gpu/device/device_conv_fwd_multiple_d_xdl_cshuffle.hpp
+22
-4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+21
-0
No files found.
include/ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d_xdl_cshuffle.hpp
View file @
cf95b944
...
...
@@ -643,6 +643,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
InMemoryDataOperationEnum
::
Set
,
AGridDesc_M_K
,
BGridDesc_N_K
,
StaticallyIndexedArray
<
EGridDesc_M_N
,
NumDTensor
>
,
EGridDesc_M_N
,
NumGemmKPrefetchStage
,
BlockSize
,
...
...
@@ -761,12 +762,29 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
block_2_etile_map_
=
Block2ETileMap
{
e_grid_desc_m_n_
};
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k_
,
b_grid_desc_n_k_
,
e_grid_desc_m_n_
,
block_2_etile_map_
))
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k_
,
b_grid_desc_n_k_
,
ds_grid_desc_m_n_
,
e_grid_desc_m_n_
,
block_2_etile_map_
))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
// populate pointer and desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds
[
i
]);
ds_grid_desc_m_n_
[
i
]
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
ds_n_k_wos_lengths
[
i
],
ds_n_k_wos_strides
[
i
]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
(
i
)
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n_
[
i
]);
});
}
}
...
...
@@ -780,13 +798,11 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// tensor descriptors
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
// FIXME: don't assume D and E desc are the same type
StaticallyIndexedArray
<
EGridDesc_M_N
,
NumDTensor
>
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
// FIXME: don't assume D and E desc are the same type
StaticallyIndexedArray
<
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>
...
...
@@ -833,6 +849,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
))
{
...
...
@@ -1016,6 +1033,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// check Gridwise GEMM
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
block_2_etile_map_
);
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
cf95b944
...
...
@@ -37,6 +37,7 @@ template <typename ABDataType, // FIXME: don't assume A/B have same datatype
InMemoryDataOperationEnum
EGlobalMemoryDataOperation
,
typename
AGridDesc_M_K
,
typename
BGridDesc_N_K
,
typename
DsGridDesc_M_N
,
typename
EGridDesc_M_N
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
...
...
@@ -225,6 +226,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
template
<
typename
Block2ETileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
,
const
BGridDesc_N_K
&
b_grid_desc_n_k
,
const
DsGridDesc_M_N
&
ds_grid_desc_m_n
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
Block2ETileMap
&
block_2_etile_map
)
{
...
...
@@ -236,11 +238,29 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
// check consistency of desc
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)))
{
return
false
;
}
bool
valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
valid
=
valid
&&
(
M
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I0
)
&&
N
==
ds_grid_desc_m_n
[
i
].
GetLength
(
I1
));
});
if
(
!
valid
)
{
return
false
;
}
// check tile size
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
))
{
return
false
;
}
// check gridwise gemm pipeline
const
auto
num_k_loop
=
K
/
KPerBlock
;
...
...
@@ -250,6 +270,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
return
false
;
}
// check block-to-E-tile
if
(
!
block_2_etile_map
.
CheckValidity
(
e_grid_desc_m_n
))
{
return
false
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment