gaoqiong / composable_kernel

Commit 0530fd66, authored Jul 20, 2022 by Chao Liu (parent 05c484e2)

    update gemm multi-d

Showing 3 changed files, with 138 additions and 201 deletions:
  include/ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d.hpp                              +1   -1
  include/ck/tensor_operation/gpu/device/device_convnd_fwd_multiple_d_nwc_kxc_nwk_xdl_cshuffle.hpp   +59  -160
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp                     +78  -40

include/ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d.hpp

@@ -11,7 +11,7 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
-// GEMM:
+// Convolution Forward:
 //   input : input image A[N, Hi, Wi, C],
 //   input : weight B[K, Y, X, C],
 //   input : D0[N, Ho, Wo, K], D1[N, Ho, Wo, K], ...
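
The comment block above describes the convolution problem that this device op lowers to a GEMM. For reference, the usual implicit-GEMM size mapping for these NHWC / KYXC / NHWK layouts is sketched below; the exact formulas are not spelled out in this diff (helpers such as GetGemmKRaw are only referenced by name elsewhere in it), so treat the sketch as an assumption rather than a quote of the library.

    // Illustrative only: standalone sketch of the usual implicit-GEMM sizes for an
    // NHWC forward convolution (A = input, B = weight, E = output). The formulas
    // are an assumption, not taken verbatim from composable_kernel.
    #include <cstdint>
    #include <iostream>

    struct ConvDims
    {
        int64_t N, Hi, Wi, C; // input image A[N, Hi, Wi, C]
        int64_t K, Y, X;      // weight B[K, Y, X, C]
        int64_t Ho, Wo;       // output E[N, Ho, Wo, K]
    };

    struct GemmDims
    {
        int64_t M, N, K;
    };

    // GemmM indexes output pixels, GemmN indexes output channels,
    // GemmK reduces over the filter window and input channels.
    GemmDims ToImplicitGemm(const ConvDims& c)
    {
        return {c.N * c.Ho * c.Wo, c.K, c.Y * c.X * c.C};
    }

    int main()
    {
        const ConvDims c{128, 28, 28, 256, 512, 3, 3, 28, 28};
        const GemmDims g = ToImplicitGemm(c);
        std::cout << "GemmM=" << g.M << " GemmN=" << g.N << " GemmK=" << g.K << '\n';
    }
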

include/ck/tensor_operation/gpu/device/device_convnd_fwd_multiple_d_nwc_kxc_nwk_xdl_cshuffle.hpp

@@ -77,18 +77,18 @@ __global__ void
                                   e_grid_desc_mblock_mperblock_nblock_nperblock,
                                   block_2_etile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
     ignore = p_ds_grid;
     ignore = p_e_grid;
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = cde_element_op;
     ignore = a_grid_desc_ak0_m_ak1;
     ignore = b_grid_desc_bk0_n_bk1;
     ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = block_2_etile_map;
 #endif
 }
 } // namespace
@@ -197,18 +197,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
         const auto wei_gemmn_gemmk_grid_desc =
             matrix_padder.PadBDescriptor_N_K(wei_k_yxc_grid_desc);
 
-        const auto GemmN = wei_gemmn_gemmk_grid_desc.GetLength(I0);
-        const auto GemmK = wei_gemmn_gemmk_grid_desc.GetLength(I1);
-
-        const index_t GemmK0 = GemmK / GemmK1Number;
-
-        // wei_gemmk0_gemmn_gemmk1_grid_desc
-        return transform_tensor_descriptor(
-            wei_gemmn_gemmk_grid_desc,
-            make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                       make_pass_through_transform(GemmN)),
-            make_tuple(Sequence<1>{}, Sequence<0>{}),
-            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        return wei_gemmn_gemmk_grid_desc;
     }
 
     static auto GetOutputTensorDescriptor(index_t GemmMRaw, index_t GemmN)
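
The pattern deleted here (and in the parallel branches below) split the padded GemmK extent into (GemmK0, GemmK1Number) on the device side with make_unmerge_transform; after this commit the device op returns the plain N_K / M_K descriptor and the gridwise GEMM performs that split itself. As a standalone illustration of what the unmerge means for indexing, here is the equivalent plain integer arithmetic; no composable_kernel types are used and the numbers are examples only.

    // Standalone sketch of the K -> (K0, K1) "unmerge" that the removed code used
    // to apply on the device side. Only the index arithmetic of
    // make_unmerge_transform is mirrored here; values are illustrative.
    #include <cassert>
    #include <cstdint>
    #include <iostream>

    int main()
    {
        const int64_t GemmK  = 1024; // total reduction length
        const int64_t GemmK1 = 8;    // inner (vector) length, e.g. GemmK1Number
        assert(GemmK % GemmK1 == 0);
        const int64_t GemmK0 = GemmK / GemmK1;

        // A flat K index and its unmerged (k0, k1) coordinates address the same
        // element: k == k0 * GemmK1 + k1.
        const int64_t k  = 777;
        const int64_t k0 = k / GemmK1;
        const int64_t k1 = k % GemmK1;
        assert(k == k0 * GemmK1 + k1 && k0 < GemmK0);

        std::cout << "k=" << k << " -> (k0=" << k0 << ", k1=" << k1 << ")\n";
    }
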
@@ -250,18 +239,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
             const auto in_gemmm_gemmk_grid_desc =
                 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
 
-            const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-            const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-            const auto GemmK0 = GemmK / GemmK1Number;
-
-            // in_gemmk0_gemmm_gemmk1_grid_desc
-            return transform_tensor_descriptor(
-                in_gemmm_gemmk_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+            return in_gemmm_gemmk_grid_desc;
         }
         else if constexpr(ConvForwardSpecialization ==
                           ConvolutionForwardSpecialization::Filter1x1Pad0)
@@ -286,19 +264,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
             const auto in_gemmm_gemmk_grid_desc =
                 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
 
-            const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-            const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-            const auto GemmK0 = GemmK / GemmK1Number;
-
-            const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmm_gemmk_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return in_gemmk0_gemmm_gemmk1_grid_desc;
+            return in_gemmm_gemmk_grid_desc;
         }
         else
         {
@@ -337,19 +303,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
             const auto in_gemmm_gemmk_grid_desc =
                 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
 
-            const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-            const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-            const auto GemmK0 = GemmK / GemmK1Number;
-
-            const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmm_gemmk_grid_desc,
-                make_tuple(make_pass_through_transform(GemmM),
-                           make_unmerge_transform(make_tuple(GemmK0, GemmK1Number))),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
-
-            return in_gemmk0_gemmm_gemmk1_grid_desc;
+            return in_gemmm_gemmk_grid_desc;
         }
     }
@@ -384,18 +338,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
             const auto in_gemmm_gemmk_grid_desc =
                 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
 
-            const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-            const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-            const auto GemmK0 = GemmK / GemmK1Number;
-
-            // in_gemmk0_gemmm_gemmk1_grid_desc
-            return transform_tensor_descriptor(
-                in_gemmm_gemmk_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+            return in_gemmm_gemmk_grid_desc;
         }
         else if constexpr(ConvForwardSpecialization ==
                           ConvolutionForwardSpecialization::Filter1x1Pad0)
@@ -422,19 +365,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
             const auto in_gemmm_gemmk_grid_desc =
                 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
 
-            const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-            const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-            const auto GemmK0 = GemmK / GemmK1Number;
-
-            const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmm_gemmk_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return in_gemmk0_gemmm_gemmk1_grid_desc;
+            return in_gemmm_gemmk_grid_desc;
         }
         else
        {
@@ -482,19 +413,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
             const auto in_gemmm_gemmk_grid_desc =
                 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
 
-            const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-            const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-            const auto GemmK0 = GemmK / GemmK1Number;
-
-            const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmm_gemmk_grid_desc,
-                make_tuple(make_pass_through_transform(GemmM),
-                           make_unmerge_transform(make_tuple(GemmK0, GemmK1Number))),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
-
-            return in_gemmk0_gemmm_gemmk1_grid_desc;
+            return in_gemmm_gemmk_grid_desc;
         }
     }
@@ -532,18 +451,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
             const auto in_gemmm_gemmk_grid_desc =
                 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
 
-            const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-            const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-            const auto GemmK0 = GemmK / GemmK1Number;
-
-            // in_gemmk0_gemmm_gemmk1_grid_desc
-            return transform_tensor_descriptor(
-                in_gemmm_gemmk_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+            return in_gemmm_gemmk_grid_desc;
         }
         else if constexpr(ConvForwardSpecialization ==
                           ConvolutionForwardSpecialization::Filter1x1Pad0)
@@ -573,19 +481,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
             const auto in_gemmm_gemmk_grid_desc =
                 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
 
-            const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-            const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-            const auto GemmK0 = GemmK / GemmK1Number;
-
-            const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmm_gemmk_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return in_gemmk0_gemmm_gemmk1_grid_desc;
+            return in_gemmm_gemmk_grid_desc;
         }
         else
        {
@@ -646,19 +542,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
             const auto in_gemmm_gemmk_grid_desc =
                 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
 
-            const auto GemmM = in_gemmm_gemmk_grid_desc.GetLength(I0);
-            const auto GemmK = in_gemmm_gemmk_grid_desc.GetLength(I1);
-
-            const auto GemmK0 = GemmK / GemmK1Number;
-
-            const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmm_gemmk_grid_desc,
-                make_tuple(make_pass_through_transform(GemmM),
-                           make_unmerge_transform(make_tuple(GemmK0, GemmK1Number))),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
-
-            return in_gemmk0_gemmm_gemmk1_grid_desc;
+            return in_gemmm_gemmk_grid_desc;
         }
     }
@@ -696,11 +580,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
         const index_t GemmNRaw = K;
         const index_t GemmKRaw = GetGemmKRaw(C, filter_spatial_lengths);
 
-        // TODO: remove
-        assert(GemmKRaw % GemmK1Number == 0);
-
         // A:
-        const auto in_gemmk0_gemmm_gemmk1_grid_desc =
+        const auto in_gemmm_gemmk_grid_desc =
             GetInputTensorDescriptor<NDimSpatial>(N,
                                                   C,
                                                   GemmMRaw,
@@ -714,15 +595,13 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
                                                   input_right_pads);
 
         // B:
-        const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
+        const auto wei_gemmn_gemmk_grid_desc =
             GetWeightTensorDescriptor(GemmNRaw, GemmKRaw);
 
         // E:
         const auto out_gemmm_gemmn_grid_desc = GetOutputTensorDescriptor(GemmMRaw, GemmNRaw);
 
-        return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
-                          wei_gemmk0_gemmn_gemmk1_grid_desc,
+        return make_tuple(in_gemmm_gemmk_grid_desc,
+                          wei_gemmn_gemmk_grid_desc,
                           out_gemmm_gemmn_grid_desc);
     }
 
     template <index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
@@ -748,9 +627,9 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     using ABEGridDescs = decltype(GetABEGridDesc<NDimSpatial>());
 
-    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(ABEGridDescs{}[I0])>;
-    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(ABEGridDescs{}[I1])>;
+    using AGridDesc_M_K = remove_cvref_t<decltype(ABEGridDescs{}[I0])>;
+    using BGridDesc_N_K = remove_cvref_t<decltype(ABEGridDescs{}[I1])>;
     using EGridDesc_M_N = remove_cvref_t<decltype(ABEGridDescs{}[I2])>;
 
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
@@ -763,8 +642,8 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
         BElementwiseOperation,
         CDEElementwiseOperation,
         InMemoryDataOperationEnum::Set,
-        AGridDesc_AK0_M_AK1,
-        BGridDesc_BK0_N_BK1,
+        AGridDesc_M_K,
+        BGridDesc_N_K,
         EGridDesc_M_N,
         NumGemmKPrefetchStage,
         BlockSize,
@@ -799,11 +678,12 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
         CDEBlockTransferScalarPerVector_NPerBlock,
         LoopSched>;
 
-#if 0
-    using Block2ETileMap = BlockToCTileMap_M00_N0_M01<MPerBlock, NPerBlock, EGridDesc_M_N>;
-#else
+    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
+        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
+    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
+        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
+
     using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
-#endif
 
     // Argument
     struct Argument : public BaseArgument
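
The two new aliases above let the device op deduce the K0/M/K1-split descriptor types from the gridwise GEMM's own factory functions instead of declaring them independently. A small self-contained sketch of that decltype-deduction pattern follows; the structs and values are placeholders, not the real composable_kernel descriptor classes.

    // Standalone sketch of deriving a transformed-descriptor type from a static
    // factory with decltype, mirroring
    // "remove_cvref_t<decltype(GridwiseGemm::MakeDefault...(Desc{}))>".
    #include <type_traits>

    struct DescMK { long M; long K; };              // stand-in for AGridDesc_M_K
    struct DescK0MK1 { long K0; long M; long K1; }; // stand-in for AGridDesc_AK0_M_AK1

    struct Gridwise
    {
        static constexpr long AK1 = 8;
        static constexpr DescK0MK1 MakeDefaultAGridDescriptor_AK0_M_AK1(DescMK d)
        {
            return {d.K / AK1, d.M, AK1};
        }
    };

    // The alias is computed from the factory's return type, so device-side code
    // stays in sync with whatever layout the gridwise GEMM chooses.
    using AGridDesc_AK0_M_AK1 =
        std::remove_cv_t<std::remove_reference_t<decltype(
            Gridwise::MakeDefaultAGridDescriptor_AK0_M_AK1(DescMK{}))>>;

    static_assert(std::is_same_v<AGridDesc_AK0_M_AK1, DescK0MK1>);

    int main() { return 0; }
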
@@ -856,9 +736,14 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
                 input_left_pads,
                 input_right_pads);
 
-            a_grid_desc_ak0_m_ak1_ = descs[I0];
-            b_grid_desc_bk0_n_bk1_ = descs[I1];
+            const auto a_grid_desc_m_k = descs[I0];
+            const auto b_grid_desc_n_k = descs[I1];
             e_grid_desc_m_n_ = descs[I2];
 
+            a_grid_desc_ak0_m_ak1_ =
+                GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
+            b_grid_desc_bk0_n_bk1_ =
+                GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
+
             block_2_etile_map_ = Block2ETileMap{e_grid_desc_m_n_};
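
In the updated Argument constructor the raw M_K / N_K descriptors from the descriptor factory are converted once and stored, and block_2_etile_map_ is rebuilt from e_grid_desc_m_n_. The number of workgroups that map covers is the number of MPerBlock x NPerBlock output tiles (MBlock = M / MPerBlock, NBlock = N / NPerBlock, as in the E-descriptor transform later in this commit). A quick standalone check of that arithmetic with made-up sizes:

    // Standalone check of the output-tile count the Block2ETileMap is built over:
    // MBlock = M / MPerBlock, NBlock = N / NPerBlock. Sizes are made up and chosen
    // to divide evenly; they are not defaults taken from the library.
    #include <cassert>
    #include <cstdint>
    #include <iostream>

    int main()
    {
        const int64_t M = 3072, N = 2048;           // padded GEMM output extents
        const int64_t MPerBlock = 256, NPerBlock = 128;

        assert(M % MPerBlock == 0 && N % NPerBlock == 0);
        const int64_t MBlock = M / MPerBlock; // 12
        const int64_t NBlock = N / NPerBlock; // 16
        std::cout << "output tiles (and workgroups): " << MBlock * NBlock << '\n';
    }
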
@@ -917,7 +802,7 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
     {
-#if 0
+#if 1
         {
             std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
                       << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
@@ -1010,6 +895,20 @@ struct DeviceConvNdFwdMultipleD_NwcKxcNwk_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
+#if 1
+        {
+            std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
+                      << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
+                      << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
+                      << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
+            std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
+                      << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
+                      << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
+                      << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
+            std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
+                      << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+        }
+#endif
         if(ck::get_device_name() == "gfx908")
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||

include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp

@@ -18,8 +18,8 @@
 namespace ck {
 
 // GEMM:
-//   input : A[AK0, M, AK1]
-//   input : B[AK0, N, AK1]
+//   input : A[AK0PerBlock, M, AK1]
+//   input : B[AK0PerBlock, N, AK1]
 //   input : D0[M, N], D1[M, N], ...
 //   output : E[M, N]
 //   C = a_op(A) * b_op(B)
@@ -35,8 +35,8 @@ template <typename FloatAB,
           typename BElementwiseOperation,
           typename CDEElementwiseOperation,
           InMemoryDataOperationEnum EGlobalMemoryDataOperation,
-          typename AGridDesc_AK0_M_AK1,
-          typename BGridDesc_BK0_N_BK1,
+          typename AGridDesc_M_K,
+          typename BGridDesc_N_K,
           typename EGridDesc_M_N,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
@@ -84,10 +84,10 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
     static constexpr auto I7 = Number<7>{};
 
     // K1 should be Number<...>
-    static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
-    static constexpr auto AK1 = Number<AK1Value>{};
-    static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
-    static constexpr auto BK1 = Number<BK1Value>{};
+    static constexpr auto AK1         = Number<AK1Value>{};
+    static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
+    static constexpr auto BK1         = Number<BK1Value>{};
+    static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
 
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
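
The renaming above makes explicit that these are per-block quantities: AK0PerBlock counts the AK1-wide slices of each KPerBlock tile of A, and BK0PerBlock does the same for B. A compile-time sketch of that arithmetic follows; the concrete numbers are illustrative, not defaults taken from the library.

    // Compile-time sketch of the per-block K split behind AK0PerBlock/BK0PerBlock.
    // The numbers are illustrative only.
    #include <cstdio>

    constexpr int KPerBlock = 32; // K tile a workgroup consumes per main-loop step
    constexpr int AK1Value  = 8;  // innermost (vectorized) K length for A
    constexpr int BK1Value  = 4;  // innermost (vectorized) K length for B

    constexpr int AK0PerBlock = KPerBlock / AK1Value; // 4 slices of A per K tile
    constexpr int BK0PerBlock = KPerBlock / BK1Value; // 8 slices of B per K tile

    // Each block therefore stages an A tile of shape [AK0PerBlock, MPerBlock, AK1]
    // and a B tile of shape [BK0PerBlock, NPerBlock, BK1] in LDS.
    static_assert(AK0PerBlock * AK1Value == KPerBlock);
    static_assert(BK0PerBlock * BK1Value == KPerBlock);

    int main()
    {
        std::printf("AK0PerBlock=%d BK0PerBlock=%d\n", AK0PerBlock, BK0PerBlock);
    }
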
@@ -97,7 +97,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
     {
         // A matrix in LDS memory, dst of blockwise copy
         return make_naive_tensor_descriptor(
-            make_tuple(AK0, Number<MPerBlock>{}, AK1),
+            make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
             make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
     }
@@ -105,7 +105,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
     {
         // B matrix in LDS memory, dst of blockwise copy
         return make_naive_tensor_descriptor(
-            make_tuple(BK0, Number<NPerBlock>{}, BK1),
+            make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
             make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
     }
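
Both LDS descriptors above are naive (dims, strides) descriptors whose leading stride is widened by ABlockLdsExtraM / BBlockLdsExtraN. The standalone sketch below reproduces only the A-tile offset arithmetic implied by those strides; the sizes are made up for illustration.

    // Standalone sketch of the A-tile LDS offset math implied by the descriptor
    // above: dims (AK0PerBlock, MPerBlock, AK1) with strides
    // ((MPerBlock + ABlockLdsExtraM) * AK1, AK1, 1). Values are illustrative.
    #include <cstdint>
    #include <iostream>

    int main()
    {
        const int64_t MPerBlock       = 128;
        const int64_t AK1             = 8;
        const int64_t ABlockLdsExtraM = 1; // padding; widens the k0 stride only

        const int64_t stride_k0 = (MPerBlock + ABlockLdsExtraM) * AK1;
        const int64_t stride_m  = AK1;
        const int64_t stride_k1 = 1;

        auto offset = [&](int64_t k0, int64_t m, int64_t k1) {
            return k0 * stride_k0 + m * stride_m + k1 * stride_k1;
        };

        // Neighbouring k0 slices sit (MPerBlock + ABlockLdsExtraM) * AK1 elements
        // apart, i.e. AK1 elements more than a fully dense layout would need.
        std::cout << "offset(1,0,0) - offset(0,0,0) = "
                  << offset(1, 0, 0) - offset(0, 0, 0) << '\n';
    }
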
@@ -164,8 +164,65 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
         c_block_size * sizeof(FloatCShuffle));
     }
 
+    __host__ __device__ static constexpr auto
+    MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
+    {
+        const auto M = a_grid_desc_m_k.GetLength(I0);
+        const auto K = a_grid_desc_m_k.GetLength(I1);
+
+        const auto AK0 = K / AK1;
+
+        return transform_tensor_descriptor(
+            a_grid_desc_m_k,
+            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
+                       make_pass_through_transform(M)),
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    }
+
+    __host__ __device__ static constexpr auto
+    MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
+    {
+        const auto N = b_grid_desc_n_k.GetLength(I0);
+        const auto K = b_grid_desc_n_k.GetLength(I1);
+
+        const auto BK0 = K / BK1;
+
+        return transform_tensor_descriptor(
+            b_grid_desc_n_k,
+            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
+                       make_pass_through_transform(N)),
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    }
+
+    __host__ __device__ static constexpr auto
+    MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
+    {
+        const auto M = e_grid_desc_m_n.GetLength(I0);
+        const auto N = e_grid_desc_m_n.GetLength(I1);
+
+        const auto MBlock = M / MPerBlock;
+        const auto NBlock = N / NPerBlock;
+
+        const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
+            e_grid_desc_m_n,
+            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
+                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
+
+        return e_grid_desc_mblock_mperblock_nblock_nperblock;
+    }
+
+    // return block_id to E matrix tile idx (m0, n0) mapping
+    __host__ __device__ static constexpr auto
+    MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
+    {
+        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
+            e_grid_desc_m_n);
+    }
+
     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
-    template <typename Block2ETileMap>
+    template <typename AGridDesc_AK0_M_AK1, typename BGridDesc_BK0_N_BK1, typename Block2ETileMap>
     __host__ __device__ static constexpr bool
     CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                   const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
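
MakeDefaultBlock2ETileMap, added above, wraps BlockToCTileMap_M00_N0_M01Adapt, which maps a flat workgroup id to an (m0, n0) output-tile index over the E grid. The sketch below is a deliberately simplified stand-in with a plain row-major mapping and hypothetical member names; the real adapter swizzles the order for locality.

    // Simplified stand-in for a block-id -> (m0, n0) E-tile map. The member names
    // and row-major order are illustrative, not the library's implementation.
    #include <cstdint>
    #include <iostream>
    #include <utility>

    struct SimpleBlock2ETileMap
    {
        int64_t MBlock; // number of MPerBlock tiles along M
        int64_t NBlock; // number of NPerBlock tiles along N

        int64_t CalculateGridSize() const { return MBlock * NBlock; }

        std::pair<int64_t, int64_t> CalculateTileIdx(int64_t block_id) const
        {
            return {block_id / NBlock, block_id % NBlock}; // (m0, n0)
        }
    };

    int main()
    {
        const SimpleBlock2ETileMap map{/*MBlock=*/12, /*NBlock=*/8};
        const auto [m0, n0] = map.CalculateTileIdx(/*block_id=*/35);
        std::cout << "grid=" << map.CalculateGridSize()
                  << " block 35 -> tile(" << m0 << ", " << n0 << ")\n";
    }
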
@@ -210,32 +267,10 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
         return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }
 
-    __host__ __device__ static constexpr auto
-    MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
-    {
-        const auto M = e_grid_desc_m_n.GetLength(I0);
-        const auto N = e_grid_desc_m_n.GetLength(I1);
-
-        const auto MBlock = M / MPerBlock;
-        const auto NBlock = N / NPerBlock;
-
-        const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
-            e_grid_desc_m_n,
-            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
-                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
-
-        return e_grid_desc_mblock_mperblock_nblock_nperblock;
-    }
-
-    // return block_id to E matrix tile idx (m0, n0) mapping
-    __host__ __device__ static constexpr auto
-    MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
-    {
-        return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
-            e_grid_desc_m_n);
-    }
+    using DefaultAGridDesc_AK0_M_AK1 =
+        remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
+    using DefaultBGridDesc_BK0_N_BK1 =
+        remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
 
     using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
@@ -245,7 +280,10 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
     using DsGridPointer = decltype(MakeDsGridPointer());
 
-    template <bool HasMainKBlockLoop, typename Block2ETileMap>
+    template <bool HasMainKBlockLoop,
+              typename AGridDesc_AK0_M_AK1,
+              typename BGridDesc_BK0_N_BK1,
+              typename Block2ETileMap>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
@@ -316,7 +354,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
             AElementwiseOperation,
             ck::tensor_operation::element_wise::PassThrough,
             InMemoryDataOperationEnum::Set,
-            Sequence<AK0, MPerBlock, AK1>,
+            Sequence<AK0PerBlock, MPerBlock, AK1>,
             ABlockTransferThreadClusterLengths_AK0_M_AK1,
             ABlockTransferThreadClusterArrangeOrder,
             FloatAB,
@@ -347,7 +385,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
             BElementwiseOperation,
             ck::tensor_operation::element_wise::PassThrough,
             InMemoryDataOperationEnum::Set,
-            Sequence<BK0, NPerBlock, BK1>,
+            Sequence<BK0PerBlock, NPerBlock, BK1>,
             BBlockTransferThreadClusterLengths_BK0_N_BK1,
             BBlockTransferThreadClusterArrangeOrder,
             FloatAB,